fix nil pointer dereference on AWS errors (#148)
* fix nil pointer dereference on AWS errors
* return Start errors to Scan consumer
Before the previous commit (e465b09), client errors panicked the
reader, so consumers would pick up shard-iterator errors by virtue of
their server crashing and burning.
Now that client errors are properly returned, the behaviour of
listShards is problematic because it absorbs any client errors it gets.
The combined result is that if you hit an AWS error, your server goes into an
endless scan loop that you can't detect and can't easily recover from.
To avoid that, listShards will now stop if it hits a client error.
---------
Co-authored-by: Jarrad Whitaker <jwhitaker@swift-nav.com>
commit 553e2392fd
parent 6720a01733
4 changed files with 86 additions and 8 deletions
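For orientation before the hunks: with this change an AWS client error now surfaces through Scan's return value instead of an endless retry loop. A minimal, hedged usage sketch (the stream name, module path, and handler body are illustrative, not taken from this diff):

package main

import (
	"context"
	"fmt"
	"log"

	consumer "github.com/harlow/kinesis-consumer"
)

func main() {
	// Hypothetical stream name; New and Scan are the entry points
	// exercised by the tests in this commit.
	c, err := consumer.New("my-stream")
	if err != nil {
		log.Fatalf("consumer error: %v", err)
	}

	// Scan now returns listShards / shard-iterator failures to the
	// caller, so an AWS outage ends the scan with an error instead
	// of spinning silently.
	err = c.Scan(context.Background(), func(r *consumer.Record) error {
		fmt.Println(string(r.Data))
		return nil // return nil to keep scanning
	})
	if err != nil {
		log.Fatalf("scan error: %v", err)
	}
}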
15 allgroup.go

@@ -38,7 +38,7 @@ type AllGroup struct {
 // Start is a blocking operation which will loop and attempt to find new
 // shards on a regular cadence.
-func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) {
+func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) error {
 	// Note: while ticker is a rather naive approach to this problem,
 	// it actually simplifies a few things. i.e. If we miss a new shard
 	// while AWS is resharding we'll pick it up max 30 seconds later.

@@ -51,12 +51,16 @@ func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) {
 	var ticker = time.NewTicker(30 * time.Second)

 	for {
-		g.findNewShards(ctx, shardc)
+		err := g.findNewShards(ctx, shardc)
+		if err != nil {
+			ticker.Stop()
+			return err
+		}

 		select {
 		case <-ctx.Done():
 			ticker.Stop()
-			return
+			return nil
 		case <-ticker.C:
 		}
 	}

@@ -91,7 +95,7 @@ func waitForCloseChannel(ctx context.Context, c <-chan struct{}) bool {
 // findNewShards pulls the list of shards from the Kinesis API
 // and uses a local cache to determine if we are already processing
 // a particular shard.
-func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) {
+func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) error {
 	g.shardMu.Lock()
 	defer g.shardMu.Unlock()

@@ -100,7 +104,7 @@ func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) {
 	shards, err := listShards(ctx, g.ksis, g.streamName)
 	if err != nil {
 		g.logger.Log("[GROUP] error:", err)
-		return
+		return err
 	}

 	// We do two `for` loops, since we have to set up all the `shardClosed`

@@ -134,4 +138,5 @@ func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) {
 			}
 		}()
 	}
+	return nil
 }
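Read together, the Start hunks produce a poll loop that exits on the first client error. A sketch of the resulting shape as it would read in allgroup.go, with a single defer standing in for the diff's two explicit ticker.Stop calls (the behaviour is the same):

func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) error {
	ticker := time.NewTicker(30 * time.Second)
	defer ticker.Stop() // equivalent to the two explicit Stop calls above

	for {
		// A listShards failure now propagates out of Start instead
		// of being swallowed and retried forever.
		if err := g.findNewShards(ctx, shardc); err != nil {
			return err
		}
		select {
		case <-ctx.Done():
			return nil
		case <-ticker.C:
		}
	}
}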
11 consumer.go

@@ -107,7 +107,11 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
 	)

 	go func() {
-		c.group.Start(ctx, shardc)
+		err := c.group.Start(ctx, shardc)
+		if err != nil {
+			errc <- fmt.Errorf("error starting scan: %w", err)
+			cancel()
+		}
 		<-ctx.Done()
 		close(shardc)
 	}()

@@ -284,7 +288,10 @@ func (c *Consumer) getShardIterator(ctx context.Context, streamName, shardID, se
 	}

 	res, err := c.client.GetShardIterator(ctx, params)
-	return res.ShardIterator, err
+	if err != nil {
+		return nil, err
+	}
+	return res.ShardIterator, nil
 }

 func isRetriableError(err error) bool {
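The getShardIterator hunk is the nil pointer dereference from the title: AWS SDK v2 calls return a nil output struct alongside a non-nil error, so the old return res.ShardIterator, err read a field of a nil pointer before the caller ever saw the error. A self-contained illustration of the pattern (all names here are invented for the example):

package main

import (
	"errors"
	"fmt"
)

type output struct{ ShardIterator *string }

// fetch stands in for an SDK call that fails: it returns (nil, err).
func fetch() (*output, error) { return nil, errors.New("boom") }

// buggy mirrors the old code: it reads res.ShardIterator even when
// err != nil, so a failed call panics instead of returning the error.
func buggy() (*string, error) {
	res, err := fetch()
	return res.ShardIterator, err // panics: res is nil
}

// fixed mirrors the new code: check err before touching res.
func fixed() (*string, error) {
	res, err := fetch()
	if err != nil {
		return nil, err
	}
	return res.ShardIterator, nil
}

func main() {
	if _, err := fixed(); err != nil {
		fmt.Println("fixed returns the error:", err)
	}
	defer func() { fmt.Println("buggy panicked:", recover()) }()
	_, _ = buggy() // nil pointer dereference, caught by the recover above
}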
66 consumer_test.go

@@ -2,6 +2,7 @@ package consumer

 import (
 	"context"
+	"errors"
 	"fmt"
 	"math/rand"
 	"sync"

@@ -110,6 +111,71 @@ func TestScan(t *testing.T) {
 	}
 }

+func TestScan_ListShardsError(t *testing.T) {
+	mockError := errors.New("mock list shards error")
+	client := &kinesisClientMock{
+		listShardsMock: func(ctx context.Context, params *kinesis.ListShardsInput, optFns ...func(*kinesis.Options)) (*kinesis.ListShardsOutput, error) {
+			return nil, mockError
+		},
+	}
+
+	// use cancel func to signal shutdown
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
+
+	var res string
+	var fn = func(r *Record) error {
+		res += string(r.Data)
+		cancel() // simulate cancellation while processing first record
+		return nil
+	}
+
+	c, err := New("myStreamName", WithClient(client))
+	if err != nil {
+		t.Fatalf("new consumer error: %v", err)
+	}
+
+	err = c.Scan(ctx, fn)
+	if !errors.Is(err, mockError) {
+		t.Errorf("expected an error from listShards, but instead got %v", err)
+	}
+}
+
+func TestScan_GetShardIteratorError(t *testing.T) {
+	mockError := errors.New("mock get shard iterator error")
+	client := &kinesisClientMock{
+		listShardsMock: func(ctx context.Context, params *kinesis.ListShardsInput, optFns ...func(*kinesis.Options)) (*kinesis.ListShardsOutput, error) {
+			return &kinesis.ListShardsOutput{
+				Shards: []types.Shard{
+					{ShardId: aws.String("myShard")},
+				},
+			}, nil
+		},
+		getShardIteratorMock: func(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) {
+			return nil, mockError
+		},
+	}
+
+	// use cancel func to signal shutdown
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
+
+	var res string
+	var fn = func(r *Record) error {
+		res += string(r.Data)
+		cancel() // simulate cancellation while processing first record
+		return nil
+	}
+
+	c, err := New("myStreamName", WithClient(client))
+	if err != nil {
+		t.Fatalf("new consumer error: %v", err)
+	}
+
+	err = c.Scan(ctx, fn)
+	if !errors.Is(err, mockError) {
+		t.Errorf("expected an error from getShardIterator, but instead got %v", err)
+	}
+}
+
 func TestScanShard(t *testing.T) {
 	var client = &kinesisClientMock{
 		getShardIteratorMock: func(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) {
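The kinesisClientMock used above is defined elsewhere in the test file. Assuming the usual function-field pattern that the listShardsMock / getShardIteratorMock fields imply, it presumably looks roughly like this (a sketch, not the repository's actual code, which may carry more methods):

package consumer

import (
	"context"

	"github.com/aws/aws-sdk-go-v2/service/kinesis"
)

// Sketch: a function-field mock consistent with how the tests above
// assign listShardsMock and getShardIteratorMock.
type kinesisClientMock struct {
	listShardsMock       func(context.Context, *kinesis.ListShardsInput, ...func(*kinesis.Options)) (*kinesis.ListShardsOutput, error)
	getShardIteratorMock func(context.Context, *kinesis.GetShardIteratorInput, ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error)
}

func (m *kinesisClientMock) ListShards(ctx context.Context, in *kinesis.ListShardsInput, opts ...func(*kinesis.Options)) (*kinesis.ListShardsOutput, error) {
	return m.listShardsMock(ctx, in, opts...)
}

func (m *kinesisClientMock) GetShardIterator(ctx context.Context, in *kinesis.GetShardIteratorInput, opts ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) {
	return m.getShardIteratorMock(ctx, in, opts...)
}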
2 group.go

@@ -8,7 +8,7 @@ import (

 // Group interface used to manage which shard to process
 type Group interface {
-	Start(ctx context.Context, shardc chan types.Shard)
+	Start(ctx context.Context, shardc chan types.Shard) error
 	GetCheckpoint(streamName, shardID string) (string, error)
 	SetCheckpoint(streamName, shardID, sequenceNumber string) error
 }
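Since Start now returns error, this interface change is breaking: any custom Group implementation must grow the same return value. A minimal conforming stub (the type and its behaviour are hypothetical, shown only to illustrate the new signature):

package consumer

import (
	"context"

	"github.com/aws/aws-sdk-go-v2/service/kinesis/types"
)

// staticGroup is a hypothetical Group implementation that serves a
// fixed shard list.
type staticGroup struct{ shards []types.Shard }

func (s *staticGroup) Start(ctx context.Context, shardc chan types.Shard) error {
	for _, sh := range s.shards {
		select {
		case shardc <- sh:
		case <-ctx.Done():
			return nil
		}
	}
	<-ctx.Done() // block until shutdown, mirroring AllGroup's behaviour
	return nil
}

func (s *staticGroup) GetCheckpoint(streamName, shardID string) (string, error) {
	return "", nil // no checkpoint stored in this stub
}

func (s *staticGroup) SetCheckpoint(streamName, shardID, sequenceNumber string) error {
	return nil // checkpointing is a no-op in this stub
}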