diff --git a/consumer.go b/consumer.go index 72e6942..a01d9da 100644 --- a/consumer.go +++ b/consumer.go @@ -158,33 +158,26 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e defer scanTicker.Stop() for { - select { - case <-ctx.Done(): - return nil - case <-scanTicker.C: - resp, err := c.client.GetRecords(&kinesis.GetRecordsInput{ - Limit: aws.Int64(c.maxRecords), - ShardIterator: shardIterator, - }) + resp, err := c.client.GetRecords(&kinesis.GetRecordsInput{ + Limit: aws.Int64(c.maxRecords), + ShardIterator: shardIterator, + }) - // attempt to recover from GetRecords error when expired iterator - if err != nil { - c.logger.Log("[CONSUMER] get records error:", err.Error()) + // attempt to recover from GetRecords error when expired iterator + if err != nil { + c.logger.Log("[CONSUMER] get records error:", err.Error()) - if awserr, ok := err.(awserr.Error); ok { - if _, ok := retriableErrors[awserr.Code()]; !ok { - return fmt.Errorf("get records error: %v", awserr.Message()) - } + if awserr, ok := err.(awserr.Error); ok { + if _, ok := retriableErrors[awserr.Code()]; !ok { + return fmt.Errorf("get records error: %v", awserr.Message()) } - - shardIterator, err = c.getShardIterator(ctx, c.streamName, shardID, lastSeqNum) - if err != nil { - return fmt.Errorf("get shard iterator error: %v", err) - } - - continue } + shardIterator, err = c.getShardIterator(ctx, c.streamName, shardID, lastSeqNum) + if err != nil { + return fmt.Errorf("get shard iterator error: %v", err) + } + } else { // loop over records, call callback func for _, r := range resp.Records { select { @@ -214,6 +207,14 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e shardIterator = resp.NextShardIterator } + + // Wait for next scan + select { + case <-ctx.Done(): + return nil + case <-scanTicker.C: + continue + } } }