diff --git a/consumer.go b/consumer.go index bff0b21..6967c02 100644 --- a/consumer.go +++ b/consumer.go @@ -276,7 +276,10 @@ func (c *Consumer) getShardIterator(ctx context.Context, streamName, shardID, se } res, err := c.client.GetShardIterator(ctx, params) - return res.ShardIterator, err + if err != nil { + return nil, err + } + return res.ShardIterator, nil } func isRetriableError(err error) bool { diff --git a/consumer_test.go b/consumer_test.go index 3330d32..58a4ce0 100644 --- a/consumer_test.go +++ b/consumer_test.go @@ -2,9 +2,11 @@ package consumer import ( "context" + "errors" "fmt" "sync" "testing" + "time" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/kinesis" @@ -98,6 +100,42 @@ func TestScan(t *testing.T) { } } +func TestScan_GetShardIteratorError(t *testing.T) { + mockError := errors.New("mock get shard iterator error") + client := &kinesisClientMock{ + listShardsMock: func(ctx context.Context, params *kinesis.ListShardsInput, optFns ...func(*kinesis.Options)) (*kinesis.ListShardsOutput, error) { + return &kinesis.ListShardsOutput{ + Shards: []types.Shard{ + {ShardId: aws.String("myShard")}, + }, + }, nil + }, + getShardIteratorMock: func(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) { + return nil, mockError + }, + } + + // use cancel func to signal shutdown + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + + var res string + var fn = func(r *Record) error { + res += string(r.Data) + cancel() // simulate cancellation while processing first record + return nil + } + + c, err := New("myStreamName", WithClient(client)) + if err != nil { + t.Fatalf("new consumer error: %v", err) + } + + err = c.Scan(ctx, fn) + if !errors.Is(err, mockError) { + t.Errorf("expected an error from getShardIterator, but instead got %v", err) + } +} + func TestScanShard(t *testing.T) { var client = &kinesisClientMock{ getShardIteratorMock: func(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) {