diff --git a/consumer.go b/consumer.go index 698563d..507076b 100644 --- a/consumer.go +++ b/consumer.go @@ -8,6 +8,7 @@ import ( "log" "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/kinesis" "github.com/aws/aws-sdk-go/service/kinesis/kinesisiface" @@ -145,13 +146,21 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e ShardIterator: shardIterator, }) - // attempt to recover from GetRecords error by getting new shard iterator + // attempt to recover from GetRecords error when expired iterator if err != nil { + c.logger.Log("[CONSUMER] get records error:", err.Error()) + + if awserr, ok := err.(awserr.Error); ok { + if _, ok := retriableErrors[awserr.Code()]; !ok { + return fmt.Errorf("get records error: %v", awserr.Message()) + } + } shardIterator, err = c.getShardIterator(c.streamName, shardID, lastSeqNum) if err != nil { return fmt.Errorf("get shard iterator error: %v", err) } + continue } @@ -187,6 +196,11 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e } } +var retriableErrors = map[string]struct{}{ + kinesis.ErrCodeExpiredIteratorException: struct{}{}, + kinesis.ErrCodeProvisionedThroughputExceededException: struct{}{}, +} + func isShardClosed(nextShardIterator, currentShardIterator *string) bool { return nextShardIterator == nil || currentShardIterator == nextShardIterator } diff --git a/consumer_test.go b/consumer_test.go index 0a316ba..48696b6 100644 --- a/consumer_test.go +++ b/consumer_test.go @@ -7,6 +7,7 @@ import ( "testing" "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/service/kinesis" "github.com/aws/aws-sdk-go/service/kinesis/kinesisiface" ) @@ -276,6 +277,40 @@ func TestScanShard_ShardIsClosed(t *testing.T) { } } +func TestScanShard_GetRecordsError(t *testing.T) { + var client = &kinesisClientMock{ + getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) { + return &kinesis.GetShardIteratorOutput{ + ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"), + }, nil + }, + getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) { + return &kinesis.GetRecordsOutput{ + NextShardIterator: nil, + Records: nil, + }, awserr.New( + kinesis.ErrCodeInvalidArgumentException, + "aws error message", + fmt.Errorf("error message"), + ) + }, + } + + var fn = func(r *Record) error { + return nil + } + + c, err := New("myStreamName", WithClient(client)) + if err != nil { + t.Fatalf("new consumer error: %v", err) + } + + err = c.ScanShard(context.Background(), "myShard", fn) + if err.Error() != "get records error: aws error message" { + t.Fatalf("unexpected error: %v", err) + } +} + type kinesisClientMock struct { kinesisiface.KinesisAPI getShardIteratorMock func(*kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error)