fix nil pointer dereference on AWS errors
This commit is contained in:
parent
c2b9f79d7a
commit
e465b09624
2 changed files with 42 additions and 1 deletions
|
|
@ -276,7 +276,10 @@ func (c *Consumer) getShardIterator(ctx context.Context, streamName, shardID, se
|
|||
}
|
||||
|
||||
res, err := c.client.GetShardIterator(ctx, params)
|
||||
return res.ShardIterator, err
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return res.ShardIterator, nil
|
||||
}
|
||||
|
||||
func isRetriableError(err error) bool {
|
||||
|
|
|
|||
|
|
@ -2,9 +2,11 @@ package consumer
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/service/kinesis"
|
||||
|
|
@ -98,6 +100,42 @@ func TestScan(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestScan_GetShardIteratorError(t *testing.T) {
|
||||
mockError := errors.New("mock get shard iterator error")
|
||||
client := &kinesisClientMock{
|
||||
listShardsMock: func(ctx context.Context, params *kinesis.ListShardsInput, optFns ...func(*kinesis.Options)) (*kinesis.ListShardsOutput, error) {
|
||||
return &kinesis.ListShardsOutput{
|
||||
Shards: []types.Shard{
|
||||
{ShardId: aws.String("myShard")},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
getShardIteratorMock: func(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) {
|
||||
return nil, mockError
|
||||
},
|
||||
}
|
||||
|
||||
// use cancel func to signal shutdown
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
|
||||
var res string
|
||||
var fn = func(r *Record) error {
|
||||
res += string(r.Data)
|
||||
cancel() // simulate cancellation while processing first record
|
||||
return nil
|
||||
}
|
||||
|
||||
c, err := New("myStreamName", WithClient(client))
|
||||
if err != nil {
|
||||
t.Fatalf("new consumer error: %v", err)
|
||||
}
|
||||
|
||||
err = c.Scan(ctx, fn)
|
||||
if !errors.Is(err, mockError) {
|
||||
t.Errorf("expected an error from getShardIterator, but instead got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScanShard(t *testing.T) {
|
||||
var client = &kinesisClientMock{
|
||||
getShardIteratorMock: func(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue