Only retry expired shard iterator errors (#95)

Fixes https://github.com/harlow/kinesis-consumer/issues/92
This commit is contained in:
Harlow Ward 2019-07-30 19:48:20 -07:00 committed by GitHub
parent 5da0865ac1
commit 35c48ef1c9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 50 additions and 1 deletions

View file

@ -8,6 +8,7 @@ import (
"log" "log"
"github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/kinesis" "github.com/aws/aws-sdk-go/service/kinesis"
"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface" "github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
@ -145,13 +146,21 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
ShardIterator: shardIterator, ShardIterator: shardIterator,
}) })
// attempt to recover from GetRecords error by getting new shard iterator // attempt to recover from GetRecords error when expired iterator
if err != nil { if err != nil {
c.logger.Log("[CONSUMER] get records error:", err.Error())
if awserr, ok := err.(awserr.Error); ok {
if _, ok := retriableErrors[awserr.Code()]; !ok {
return fmt.Errorf("get records error: %v", awserr.Message())
}
}
shardIterator, err = c.getShardIterator(c.streamName, shardID, lastSeqNum) shardIterator, err = c.getShardIterator(c.streamName, shardID, lastSeqNum)
if err != nil { if err != nil {
return fmt.Errorf("get shard iterator error: %v", err) return fmt.Errorf("get shard iterator error: %v", err)
} }
continue continue
} }
@ -187,6 +196,11 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
} }
} }
var retriableErrors = map[string]struct{}{
kinesis.ErrCodeExpiredIteratorException: struct{}{},
kinesis.ErrCodeProvisionedThroughputExceededException: struct{}{},
}
func isShardClosed(nextShardIterator, currentShardIterator *string) bool { func isShardClosed(nextShardIterator, currentShardIterator *string) bool {
return nextShardIterator == nil || currentShardIterator == nextShardIterator return nextShardIterator == nil || currentShardIterator == nextShardIterator
} }

View file

@ -7,6 +7,7 @@ import (
"testing" "testing"
"github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/service/kinesis" "github.com/aws/aws-sdk-go/service/kinesis"
"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface" "github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
) )
@ -276,6 +277,40 @@ func TestScanShard_ShardIsClosed(t *testing.T) {
} }
} }
func TestScanShard_GetRecordsError(t *testing.T) {
var client = &kinesisClientMock{
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
return &kinesis.GetShardIteratorOutput{
ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
}, nil
},
getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
return &kinesis.GetRecordsOutput{
NextShardIterator: nil,
Records: nil,
}, awserr.New(
kinesis.ErrCodeInvalidArgumentException,
"aws error message",
fmt.Errorf("error message"),
)
},
}
var fn = func(r *Record) error {
return nil
}
c, err := New("myStreamName", WithClient(client))
if err != nil {
t.Fatalf("new consumer error: %v", err)
}
err = c.ScanShard(context.Background(), "myShard", fn)
if err.Error() != "get records error: aws error message" {
t.Fatalf("unexpected error: %v", err)
}
}
type kinesisClientMock struct { type kinesisClientMock struct {
kinesisiface.KinesisAPI kinesisiface.KinesisAPI
getShardIteratorMock func(*kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) getShardIteratorMock func(*kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error)