From 553e2392fdf3f9e8e7859481f915d2cfc60e1502 Mon Sep 17 00:00:00 2001
From: Jarrad <113399675+jwhitaker-swiftnav@users.noreply.github.com>
Date: Fri, 7 Jun 2024 01:38:16 +1000
Subject: [PATCH] fix nil pointer dereference on AWS errors (#148)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* fix nil pointer dereference on AWS errors

* return Start errors to Scan consumer

Before the previous commit (e465b09), client errors panicked the reader,
so consumers would pick up shard-iterator errors by virtue of their
server crashing and burning. Now that client errors are properly
returned, the behaviour of listShards is problematic, because it absorbs
any client errors it gets. The combination of these two changes means
that if you hit an AWS error, your server goes into an endless scan loop
that you can't detect and can't easily recover from. To avoid that,
listShards will now stop if it hits a client error.
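As a sketch of what this means for callers (illustrative only: the
import path and the process handler below are assumptions, not part of
this commit), Scan now surfaces the underlying client error instead of
looping forever:

    package main

    import (
        "context"
        "log"

        consumer "github.com/harlow/kinesis-consumer" // assumed import path
    )

    // process is a hypothetical record handler standing in for real work.
    func process(r *consumer.Record) error { return nil }

    func main() {
        c, err := consumer.New("myStreamName")
        if err != nil {
            log.Fatalf("new consumer: %v", err)
        }

        // Previously a ListShards or GetShardIterator failure left Scan
        // spinning; now the error is returned and the caller can decide
        // whether to retry, back off, or bail out.
        if err := c.Scan(context.Background(), process); err != nil {
            log.Fatalf("scan aborted: %v", err)
        }
    }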
---------

Co-authored-by: Jarrad Whitaker
---
 allgroup.go      | 15 +++++++----
 consumer.go      | 11 ++++++--
 consumer_test.go | 66 ++++++++++++++++++++++++++++++++++++++++++++++++
 group.go         |  2 +-
 4 files changed, 86 insertions(+), 8 deletions(-)

diff --git a/allgroup.go b/allgroup.go
index d107a7a..1ecb7b2 100644
--- a/allgroup.go
+++ b/allgroup.go
@@ -38,7 +38,7 @@ type AllGroup struct {
 
 // Start is a blocking operation which will loop and attempt to find new
 // shards on a regular cadence.
-func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) {
+func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) error {
 	// Note: while ticker is a rather naive approach to this problem,
 	// it actually simplifies a few things. i.e. If we miss a new shard
 	// while AWS is resharding we'll pick it up max 30 seconds later.
@@ -51,12 +51,16 @@ func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) {
 	var ticker = time.NewTicker(30 * time.Second)
 
 	for {
-		g.findNewShards(ctx, shardc)
+		err := g.findNewShards(ctx, shardc)
+		if err != nil {
+			ticker.Stop()
+			return err
+		}
 
 		select {
 		case <-ctx.Done():
 			ticker.Stop()
-			return
+			return nil
 		case <-ticker.C:
 		}
 	}
@@ -91,7 +95,7 @@ func waitForCloseChannel(ctx context.Context, c <-chan struct{}) bool {
 // findNewShards pulls the list of shards from the Kinesis API
 // and uses a local cache to determine if we are already processing
 // a particular shard.
-func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) {
+func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) error {
 	g.shardMu.Lock()
 	defer g.shardMu.Unlock()
 
@@ -100,7 +104,7 @@ func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) {
 	shards, err := listShards(ctx, g.ksis, g.streamName)
 	if err != nil {
 		g.logger.Log("[GROUP] error:", err)
-		return
+		return err
 	}
 
 	// We do two `for` loops, since we have to set up all the `shardClosed`
@@ -134,4 +138,5 @@ func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) {
 			}
 		}()
 	}
+	return nil
 }
diff --git a/consumer.go b/consumer.go
index 8afb270..80ab45c 100644
--- a/consumer.go
+++ b/consumer.go
@@ -107,7 +107,11 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
 	)
 
 	go func() {
-		c.group.Start(ctx, shardc)
+		err := c.group.Start(ctx, shardc)
+		if err != nil {
+			errc <- fmt.Errorf("error starting scan: %w", err)
+			cancel()
+		}
 		<-ctx.Done()
 		close(shardc)
 	}()
@@ -284,7 +288,10 @@ func (c *Consumer) getShardIterator(ctx context.Context, streamName, shardID, se
 	}
 
 	res, err := c.client.GetShardIterator(ctx, params)
-	return res.ShardIterator, err
+	if err != nil {
+		return nil, err
+	}
+	return res.ShardIterator, nil
 }
 
 func isRetriableError(err error) bool {
diff --git a/consumer_test.go b/consumer_test.go
index 13cfb04..d14b11e 100644
--- a/consumer_test.go
+++ b/consumer_test.go
@@ -2,6 +2,7 @@ package consumer
 
 import (
 	"context"
+	"errors"
 	"fmt"
 	"math/rand"
 	"sync"
@@ -110,6 +111,71 @@ func TestScan(t *testing.T) {
 	}
 }
 
+func TestScan_ListShardsError(t *testing.T) {
+	mockError := errors.New("mock list shards error")
+	client := &kinesisClientMock{
+		listShardsMock: func(ctx context.Context, params *kinesis.ListShardsInput, optFns ...func(*kinesis.Options)) (*kinesis.ListShardsOutput, error) {
+			return nil, mockError
+		},
+	}
+
+	// use cancel func to signal shutdown
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
+
+	var res string
+	var fn = func(r *Record) error {
+		res += string(r.Data)
+		cancel() // simulate cancellation while processing first record
+		return nil
+	}
+
+	c, err := New("myStreamName", WithClient(client))
+	if err != nil {
+		t.Fatalf("new consumer error: %v", err)
+	}
+
+	err = c.Scan(ctx, fn)
+	if !errors.Is(err, mockError) {
+		t.Errorf("expected an error from listShards, but instead got %v", err)
+	}
+}
+
+func TestScan_GetShardIteratorError(t *testing.T) {
+	mockError := errors.New("mock get shard iterator error")
+	client := &kinesisClientMock{
+		listShardsMock: func(ctx context.Context, params *kinesis.ListShardsInput, optFns ...func(*kinesis.Options)) (*kinesis.ListShardsOutput, error) {
+			return &kinesis.ListShardsOutput{
+				Shards: []types.Shard{
+					{ShardId: aws.String("myShard")},
+				},
+			}, nil
+		},
+		getShardIteratorMock: func(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) {
+			return nil, mockError
+		},
+	}
+
+	// use cancel func to signal shutdown
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
+
+	var res string
+	var fn = func(r *Record) error {
+		res += string(r.Data)
+		cancel() // simulate cancellation while processing first record
+		return nil
+	}
+
+	c, err := New("myStreamName", WithClient(client))
+	if err != nil {
+		t.Fatalf("new consumer error: %v", err)
+	}
+
+	err = c.Scan(ctx, fn)
+	if !errors.Is(err, mockError) {
+		t.Errorf("expected an error from getShardIterator, but instead got %v", err)
+	}
+}
+
 func TestScanShard(t *testing.T) {
 	var client = &kinesisClientMock{
 		getShardIteratorMock: func(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) {
diff --git a/group.go b/group.go
index 5856f24..29647d9 100644
--- a/group.go
+++ b/group.go
@@ -8,7 +8,7 @@ import (
 
 // Group interface used to manage which shard to process
 type Group interface {
-	Start(ctx context.Context, shardc chan types.Shard)
+	Start(ctx context.Context, shardc chan types.Shard) error
 	GetCheckpoint(streamName, shardID string) (string, error)
 	SetCheckpoint(streamName, shardID, sequenceNumber string) error
 }