fix nil pointer dereference on AWS errors

This commit is contained in:
Jarrad Whitaker 2022-10-31 09:14:50 +11:00
parent c2b9f79d7a
commit e465b09624
2 changed files with 42 additions and 1 deletions

View file

@ -276,7 +276,10 @@ func (c *Consumer) getShardIterator(ctx context.Context, streamName, shardID, se
} }
res, err := c.client.GetShardIterator(ctx, params) res, err := c.client.GetShardIterator(ctx, params)
return res.ShardIterator, err if err != nil {
return nil, err
}
return res.ShardIterator, nil
} }
func isRetriableError(err error) bool { func isRetriableError(err error) bool {

View file

@ -2,9 +2,11 @@ package consumer
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"sync" "sync"
"testing" "testing"
"time"
"github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/kinesis" "github.com/aws/aws-sdk-go-v2/service/kinesis"
@ -98,6 +100,42 @@ func TestScan(t *testing.T) {
} }
} }
func TestScan_GetShardIteratorError(t *testing.T) {
mockError := errors.New("mock get shard iterator error")
client := &kinesisClientMock{
listShardsMock: func(ctx context.Context, params *kinesis.ListShardsInput, optFns ...func(*kinesis.Options)) (*kinesis.ListShardsOutput, error) {
return &kinesis.ListShardsOutput{
Shards: []types.Shard{
{ShardId: aws.String("myShard")},
},
}, nil
},
getShardIteratorMock: func(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) {
return nil, mockError
},
}
// use cancel func to signal shutdown
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
var res string
var fn = func(r *Record) error {
res += string(r.Data)
cancel() // simulate cancellation while processing first record
return nil
}
c, err := New("myStreamName", WithClient(client))
if err != nil {
t.Fatalf("new consumer error: %v", err)
}
err = c.Scan(ctx, fn)
if !errors.Is(err, mockError) {
t.Errorf("expected an error from getShardIterator, but instead got %v", err)
}
}
func TestScanShard(t *testing.T) { func TestScanShard(t *testing.T) {
var client = &kinesisClientMock{ var client = &kinesisClientMock{
getShardIteratorMock: func(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) { getShardIteratorMock: func(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) {