fix nil pointer dereference on AWS errors (#148)

* fix nil pointer dereference on AWS errors

* return Start errors to Scan consumer

Before the previous commit (e465b09), client errors panicked the
reader, so consumers would only pick up shard iterator errors by virtue
of their server crashing and burning.

Now that client errors are properly returned, the behaviour of
listShards is problematic because it absorbs any client errors it gets.

The combined result is that if you hit an AWS error, your server goes into an
endless scan loop that you can't detect and can't easily recover from.

To avoid that, the shard-discovery loop now stops and returns the error when listShards hits a client error.
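For consumers, that means an AWS failure now surfaces as Scan's return value. A minimal sketch of the consumer-side effect, using the package API exercised in the tests below (New, Scan, Record); the import path is a placeholder:

```go
package main

import (
	"context"
	"fmt"
	"log"

	consumer "github.com/harlow/kinesis-consumer" // placeholder import path; adjust to your module
)

func main() {
	c, err := consumer.New("myStreamName")
	if err != nil {
		log.Fatalf("new consumer: %v", err)
	}

	// Scan now returns client errors (e.g. a ListShards failure) instead
	// of spinning in an undetectable retry loop.
	err = c.Scan(context.Background(), func(r *consumer.Record) error {
		fmt.Println(string(r.Data))
		return nil
	})
	if err != nil {
		log.Fatalf("scan stopped: %v", err)
	}
}
```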

---------

Co-authored-by: Jarrad Whitaker <jwhitaker@swift-nav.com>
Jarrad authored 2024-06-07 01:38:16 +10:00, committed by GitHub
parent 6720a01733
commit 553e2392fd
4 changed files with 86 additions and 8 deletions


@@ -38,7 +38,7 @@ type AllGroup struct {
 
 // Start is a blocking operation which will loop and attempt to find new
 // shards on a regular cadence.
-func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) {
+func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) error {
 	// Note: while ticker is a rather naive approach to this problem,
 	// it actually simplifies a few things. i.e. If we miss a new shard
 	// while AWS is resharding we'll pick it up max 30 seconds later.
@@ -51,12 +51,16 @@ func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) {
 	var ticker = time.NewTicker(30 * time.Second)
 
 	for {
-		g.findNewShards(ctx, shardc)
+		err := g.findNewShards(ctx, shardc)
+		if err != nil {
+			ticker.Stop()
+			return err
+		}
 
 		select {
 		case <-ctx.Done():
 			ticker.Stop()
-			return
+			return nil
 		case <-ticker.C:
 		}
 	}
@@ -91,7 +95,7 @@ func waitForCloseChannel(ctx context.Context, c <-chan struct{}) bool {
 // findNewShards pulls the list of shards from the Kinesis API
 // and uses a local cache to determine if we are already processing
 // a particular shard.
-func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) {
+func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) error {
 	g.shardMu.Lock()
 	defer g.shardMu.Unlock()
 
@@ -100,7 +104,7 @@ func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) {
 	shards, err := listShards(ctx, g.ksis, g.streamName)
 	if err != nil {
 		g.logger.Log("[GROUP] error:", err)
-		return
+		return err
 	}
 
 	// We do two `for` loops, since we have to set up all the `shardClosed`
@@ -134,4 +138,5 @@ func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) {
 			}
 		}()
 	}
+	return nil
 }


@@ -107,7 +107,11 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
 	)
 
 	go func() {
-		c.group.Start(ctx, shardc)
+		err := c.group.Start(ctx, shardc)
+		if err != nil {
+			errc <- fmt.Errorf("error starting scan: %w", err)
+			cancel()
+		}
 		<-ctx.Done()
 		close(shardc)
 	}()
@@ -284,7 +288,10 @@ func (c *Consumer) getShardIterator(ctx context.Context, streamName, shardID, se
 	}
 
 	res, err := c.client.GetShardIterator(ctx, params)
-	return res.ShardIterator, err
+	if err != nil {
+		return nil, err
+	}
+	return res.ShardIterator, nil
 }
 
 func isRetriableError(err error) bool {


@@ -2,6 +2,7 @@ package consumer
 
 import (
 	"context"
+	"errors"
 	"fmt"
 	"math/rand"
 	"sync"
@@ -110,6 +111,71 @@ func TestScan(t *testing.T) {
 	}
 }
 
+func TestScan_ListShardsError(t *testing.T) {
+	mockError := errors.New("mock list shards error")
+
+	client := &kinesisClientMock{
+		listShardsMock: func(ctx context.Context, params *kinesis.ListShardsInput, optFns ...func(*kinesis.Options)) (*kinesis.ListShardsOutput, error) {
+			return nil, mockError
+		},
+	}
+
+	// bound the test with a timeout; cancel stops the scan early
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
+
+	var res string
+	var fn = func(r *Record) error {
+		res += string(r.Data)
+		cancel() // no record should ever reach fn in this test
+		return nil
+	}
+
+	c, err := New("myStreamName", WithClient(client))
+	if err != nil {
+		t.Fatalf("new consumer error: %v", err)
+	}
+
+	err = c.Scan(ctx, fn)
+	if !errors.Is(err, mockError) {
+		t.Errorf("expected an error from listShards, but instead got %v", err)
+	}
+}
+
+func TestScan_GetShardIteratorError(t *testing.T) {
+	mockError := errors.New("mock get shard iterator error")
+
+	client := &kinesisClientMock{
+		listShardsMock: func(ctx context.Context, params *kinesis.ListShardsInput, optFns ...func(*kinesis.Options)) (*kinesis.ListShardsOutput, error) {
+			return &kinesis.ListShardsOutput{
+				Shards: []types.Shard{
+					{ShardId: aws.String("myShard")},
+				},
+			}, nil
+		},
+		getShardIteratorMock: func(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) {
+			return nil, mockError
+		},
+	}
+
+	// bound the test with a timeout; cancel stops the scan early
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
+
+	var res string
+	var fn = func(r *Record) error {
+		res += string(r.Data)
+		cancel() // no record should ever reach fn in this test
+		return nil
+	}
+
+	c, err := New("myStreamName", WithClient(client))
+	if err != nil {
+		t.Fatalf("new consumer error: %v", err)
+	}
+
+	err = c.Scan(ctx, fn)
+	if !errors.Is(err, mockError) {
+		t.Errorf("expected an error from getShardIterator, but instead got %v", err)
+	}
+}
+
 func TestScanShard(t *testing.T) {
 	var client = &kinesisClientMock{
 		getShardIteratorMock: func(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) {

@@ -8,7 +8,7 @@ import (
 
 // Group interface used to manage which shard to process
 type Group interface {
-	Start(ctx context.Context, shardc chan types.Shard)
+	Start(ctx context.Context, shardc chan types.Shard) error
 	GetCheckpoint(streamName, shardID string) (string, error)
 	SetCheckpoint(streamName, shardID, sequenceNumber string) error
 }
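Since Start's signature changed, any custom Group implementation must be updated as well. A hypothetical minimal implementation against the new interface (the staticGroup name, its fixed shard list, and the no-op checkpointing are illustrative only):

```go
package consumer

import (
	"context"

	"github.com/aws/aws-sdk-go-v2/service/kinesis/types"
)

// staticGroup is a hypothetical Group that feeds a fixed shard list to
// the consumer and never discovers new shards.
type staticGroup struct {
	shards []types.Shard
}

// Start now returns an error; a discovery failure here would propagate
// up through Consumer.Scan instead of being silently swallowed.
func (g *staticGroup) Start(ctx context.Context, shardc chan types.Shard) error {
	for _, s := range g.shards {
		select {
		case shardc <- s:
		case <-ctx.Done():
			return nil
		}
	}
	<-ctx.Done() // block until the consumer shuts down
	return nil
}

func (g *staticGroup) GetCheckpoint(streamName, shardID string) (string, error) {
	return "", nil // no checkpoint store in this sketch
}

func (g *staticGroup) SetCheckpoint(streamName, shardID, sequenceNumber string) error {
	return nil
}
```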