diff --git a/consumer.go b/consumer.go index 74b59bb..0bdc9a2 100644 --- a/consumer.go +++ b/consumer.go @@ -192,7 +192,7 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e c.logger.Log("[CONSUMER] get records error:", err.Error()) if !isRetriableError(err) { - return fmt.Errorf("get records error: %v", err.Error()) + return fmt.Errorf("get records error: %w", err) } shardIterator, err = c.getShardIterator(ctx, c.streamName, shardID, lastSeqNum) diff --git a/consumer_test.go b/consumer_test.go index d14b11e..8391b90 100644 --- a/consumer_test.go +++ b/consumer_test.go @@ -396,6 +396,7 @@ func TestScanShard_ShardIsClosed_WithShardClosedHandler(t *testing.T) { } func TestScanShard_GetRecordsError(t *testing.T) { + getRecordsError := &types.InvalidArgumentException{Message: aws.String("aws error message")} var client = &kinesisClientMock{ getShardIteratorMock: func(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) { return &kinesis.GetShardIteratorOutput{ @@ -406,8 +407,7 @@ func TestScanShard_GetRecordsError(t *testing.T) { return &kinesis.GetRecordsOutput{ NextShardIterator: nil, Records: nil, - }, - &types.InvalidArgumentException{Message: aws.String("aws error message")} + }, getRecordsError }, } @@ -424,6 +424,10 @@ func TestScanShard_GetRecordsError(t *testing.T) { if err.Error() != "get records error: InvalidArgumentException: aws error message" { t.Fatalf("unexpected error: %v", err) } + + if !errors.Is(err, getRecordsError) { + t.Fatalf("unexpected error: %v", err) + } } type kinesisClientMock struct {