fix nil pointer dereference on AWS errors
This commit is contained in:
parent
c2b9f79d7a
commit
e465b09624
2 changed files with 42 additions and 1 deletions
|
|
@ -276,7 +276,10 @@ func (c *Consumer) getShardIterator(ctx context.Context, streamName, shardID, se
|
||||||
}
|
}
|
||||||
|
|
||||||
res, err := c.client.GetShardIterator(ctx, params)
|
res, err := c.client.GetShardIterator(ctx, params)
|
||||||
return res.ShardIterator, err
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return res.ShardIterator, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func isRetriableError(err error) bool {
|
func isRetriableError(err error) bool {
|
||||||
|
|
|
||||||
|
|
@ -2,9 +2,11 @@ package consumer
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/aws/aws-sdk-go-v2/aws"
|
"github.com/aws/aws-sdk-go-v2/aws"
|
||||||
"github.com/aws/aws-sdk-go-v2/service/kinesis"
|
"github.com/aws/aws-sdk-go-v2/service/kinesis"
|
||||||
|
|
@ -98,6 +100,42 @@ func TestScan(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestScan_GetShardIteratorError(t *testing.T) {
|
||||||
|
mockError := errors.New("mock get shard iterator error")
|
||||||
|
client := &kinesisClientMock{
|
||||||
|
listShardsMock: func(ctx context.Context, params *kinesis.ListShardsInput, optFns ...func(*kinesis.Options)) (*kinesis.ListShardsOutput, error) {
|
||||||
|
return &kinesis.ListShardsOutput{
|
||||||
|
Shards: []types.Shard{
|
||||||
|
{ShardId: aws.String("myShard")},
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
getShardIteratorMock: func(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) {
|
||||||
|
return nil, mockError
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// use cancel func to signal shutdown
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||||
|
|
||||||
|
var res string
|
||||||
|
var fn = func(r *Record) error {
|
||||||
|
res += string(r.Data)
|
||||||
|
cancel() // simulate cancellation while processing first record
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
c, err := New("myStreamName", WithClient(client))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("new consumer error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = c.Scan(ctx, fn)
|
||||||
|
if !errors.Is(err, mockError) {
|
||||||
|
t.Errorf("expected an error from getShardIterator, but instead got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestScanShard(t *testing.T) {
|
func TestScanShard(t *testing.T) {
|
||||||
var client = &kinesisClientMock{
|
var client = &kinesisClientMock{
|
||||||
getShardIteratorMock: func(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) {
|
getShardIteratorMock: func(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) {
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue