From 6720a01733e68abbf0f45585c6c6ded81526c363 Mon Sep 17 00:00:00 2001
From: gram-signal <84339875+gram-signal@users.noreply.github.com>
Date: Thu, 6 Jun 2024 09:37:42 -0600
Subject: [PATCH 1/2] Maintain parent/child shard ordering across shard splits/merges. (#155)

Kinesis allows clients to rely on an invariant that, for a given partition key, the order of records added to the stream will be maintained. I.e. given an input `pkey=x,val=1 pkey=x,val=2 pkey=x,val=3`, the values `1,2,3` will be seen in that order when processed by clients, so long as clients are careful. Kinesis does this by putting all records for a single partition key into a single shard, then maintaining ordering within that shard.

However, shards can be split and merged to distribute load better and to handle per-shard throughput limits. Kinesis currently does this by (one or many times) splitting a single shard into two, or by merging two adjacent shards into one. When this occurs, Kinesis still allows for ordering consistency by detailing shard parent/child relationships in its `listShards` output. Splitting shard A creates children B and C, both with `ParentShardId=A`. Merging shards A and B into C creates a new shard C with `ParentShardId=A,AdjacentParentShardId=B`. So long as clients fully process all records in parents (including adjacent parents) before processing the new shard, ordering is maintained.

`kinesis-consumer` currently doesn't do this. Instead, upon the initial (and each subsequent) `listShards` call, all visible shards immediately begin processing. Consider this case, where shards split, then merge, and each shard `X` contains a single record `rX`:

```
time ->
  B
 / \
A   D
 \ /
  C
```

Record `rD` should be processed after both `rB` and `rC`, and both `rB` and `rC` should wait for `rA`. By starting goroutines immediately, any ordering of `{rA,rB,rC,rD}` might occur with the original code.

This PR uses the `AllGroup` as a book-keeper of fully processed shards, with the `Consumer` calling `CloseShard` once it has finished a shard. `AllGroup` doesn't release a shard for processing until its parents have been fully processed, and the consumer just processes the shards it receives as it used to.

This PR creates a new `CloseableGroup` interface rather than appending to the existing `Group` interface, to maintain backwards compatibility with existing code that may already implement the `Group` interface elsewhere. Other `Group` implementations don't get the ordering described above, but the default `Consumer` does.
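For intuition, here is a minimal, self-contained sketch of the gating idea (illustrative only — the `shard`/`gate` names below are not the library's API; the real bookkeeping lives in `AllGroup` via its per-shard `shardsClosed` channels):

```go
package main

import (
	"context"
	"fmt"
	"sync"
)

// shard is a stand-in for types.Shard with just the fields the ordering
// logic cares about (illustrative, not the library's types).
type shard struct {
	ID      string
	Parents []string // ParentShardId and AdjacentParentShardId, when present
}

// gate keeps one "closed" channel per shard: a shard is released for
// processing only after every parent's channel has been closed.
type gate struct {
	mu     sync.Mutex
	closed map[string]chan struct{}
}

func newGate(ids []string) *gate {
	g := &gate{closed: make(map[string]chan struct{})}
	for _, id := range ids {
		g.closed[id] = make(chan struct{})
	}
	return g
}

// CloseShard marks a shard as fully processed, unblocking its children.
func (g *gate) CloseShard(id string) {
	g.mu.Lock()
	defer g.mu.Unlock()
	if c, ok := g.closed[id]; ok {
		close(c)
		delete(g.closed, id) // guard against double-close
	}
}

// waitFor blocks until parentID has been closed. An unknown parent is
// treated as already processed (it likely aged out past TRIM_HORIZON).
func (g *gate) waitFor(ctx context.Context, parentID string) bool {
	g.mu.Lock()
	c, ok := g.closed[parentID]
	g.mu.Unlock()
	if !ok {
		return true
	}
	select {
	case <-ctx.Done():
		return false
	case <-c:
		return true
	}
}

func main() {
	// The split/merge topology from the diagram above: A -> (B, C) -> D.
	shards := []shard{
		{ID: "D", Parents: []string{"B", "C"}},
		{ID: "B", Parents: []string{"A"}},
		{ID: "C", Parents: []string{"A"}},
		{ID: "A"},
	}

	ids := make([]string, 0, len(shards))
	for _, s := range shards {
		ids = append(ids, s.ID)
	}
	g := newGate(ids)

	ctx := context.Background()
	var wg sync.WaitGroup
	for _, s := range shards {
		wg.Add(1)
		go func(s shard) {
			defer wg.Done()
			for _, p := range s.Parents {
				if !g.waitFor(ctx, p) {
					return // context cancelled before parents finished
				}
			}
			fmt.Println("processing", s.ID) // A first, then B/C in either order, then D
			g.CloseShard(s.ID)
		}(s)
	}
	wg.Wait()
}
```

Setting up every channel before releasing any shard is why `findNewShards` runs two loops: `listShards` does not guarantee that parents are listed before children.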
--- allgroup.go | 69 +++++++++++++-- consumer.go | 20 +++-- consumer_test.go | 216 ++++++++++++++++++++++++++++++++++++++++++++++- group.go | 7 ++ 4 files changed, 296 insertions(+), 16 deletions(-) diff --git a/allgroup.go b/allgroup.go index 749a380..d107a7a 100644 --- a/allgroup.go +++ b/allgroup.go @@ -2,6 +2,7 @@ package consumer import ( "context" + "fmt" "sync" "time" @@ -12,11 +13,12 @@ import ( // all shards on a stream func NewAllGroup(ksis kinesisClient, store Store, streamName string, logger Logger) *AllGroup { return &AllGroup{ - ksis: ksis, - shards: make(map[string]types.Shard), - streamName: streamName, - logger: logger, - Store: store, + ksis: ksis, + shards: make(map[string]types.Shard), + shardsClosed: make(map[string]chan struct{}), + streamName: streamName, + logger: logger, + Store: store, } } @@ -29,8 +31,9 @@ type AllGroup struct { logger Logger Store - shardMu sync.Mutex - shards map[string]types.Shard + shardMu sync.Mutex + shards map[string]types.Shard + shardsClosed map[string]chan struct{} } // Start is a blocking operation which will loop and attempt to find new @@ -59,6 +62,32 @@ func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) { } } +func (g *AllGroup) CloseShard(ctx context.Context, shardID string) error { + g.shardMu.Lock() + defer g.shardMu.Unlock() + c, ok := g.shardsClosed[shardID] + if !ok { + return fmt.Errorf("closing unknown shard ID %q", shardID) + } + close(c) + return nil +} + +func waitForCloseChannel(ctx context.Context, c <-chan struct{}) bool { + if c == nil { + // no channel means we haven't seen this shard in listShards, so it + // probably fell off the TRIM_HORIZON, and we can assume it's fully processed. + return true + } + select { + case <-ctx.Done(): + return false + case <-c: + // the channel has been processed and closed by the consumer (CloseShard has been called) + return true + } +} + // findNewShards pulls the list of shards from the Kinesis API // and uses a local cache to determine if we are already processing // a particular shard. @@ -74,11 +103,35 @@ func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) { return } + // We do two `for` loops, since we have to set up all the `shardClosed` + // channels before we start using any of them. It's highly probable + // that Kinesis provides us the shards in dependency order (parents + // before children), but it doesn't appear to be a guarantee. for _, shard := range shards { if _, ok := g.shards[*shard.ShardId]; ok { continue } g.shards[*shard.ShardId] = shard - shardc <- shard + g.shardsClosed[*shard.ShardId] = make(chan struct{}) + } + for _, shard := range shards { + shard := shard // Shadow shard, since we use it in goroutine + var parent1, parent2 <-chan struct{} + if shard.ParentShardId != nil { + parent1 = g.shardsClosed[*shard.ParentShardId] + } + if shard.AdjacentParentShardId != nil { + parent2 = g.shardsClosed[*shard.AdjacentParentShardId] + } + go func() { + // Asynchronously wait for all parents of this shard to be processed + // before providing it out to our client. Kinesis guarantees that a + // given partition key's data will be provided to clients in-order, + // but when splits or joins happen, we need to process all parents prior + // to processing children or that ordering guarantee is not maintained. 
+ if waitForCloseChannel(ctx, parent1) && waitForCloseChannel(ctx, parent2) { + shardc <- shard + } + }() } } diff --git a/consumer.go b/consumer.go index bff0b21..8afb270 100644 --- a/consumer.go +++ b/consumer.go @@ -118,13 +118,19 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error { wg.Add(1) go func(shardID string) { defer wg.Done() - if err := c.ScanShard(ctx, shardID, fn); err != nil { + var err error + if err = c.ScanShard(ctx, shardID, fn); err != nil { + err = fmt.Errorf("shard %s error: %w", shardID, err) + } else if closeable, ok := c.group.(CloseableGroup); !ok { + // group doesn't allow closure, skip calling CloseShard + } else if err = closeable.CloseShard(ctx, shardID); err != nil { + err = fmt.Errorf("shard closed CloseableGroup error: %w", err) + } + if err != nil { select { case errc <- fmt.Errorf("shard %s error: %w", shardID, err): - // first error to occur cancel() default: - // error has already occurred } } }(aws.ToString(shard.ShardId)) @@ -138,6 +144,10 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error { return <-errc } +func (c *Consumer) scanSingleShard(ctx context.Context, shardID string, fn ScanFunc) error { + return nil +} + // ScanShard loops over records on a specific shard, calls the callback func // for each record and checkpoints the progress of scan. func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error { @@ -218,9 +228,7 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e c.logger.Log("[CONSUMER] shard closed:", shardID) if c.shardClosedHandler != nil { - err := c.shardClosedHandler(c.streamName, shardID) - - if err != nil { + if err := c.shardClosedHandler(c.streamName, shardID); err != nil { return fmt.Errorf("shard closed handler error: %w", err) } } diff --git a/consumer_test.go b/consumer_test.go index 3330d32..13cfb04 100644 --- a/consumer_test.go +++ b/consumer_test.go @@ -3,8 +3,10 @@ package consumer import ( "context" "fmt" + "math/rand" "sync" "testing" + "time" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/kinesis" @@ -24,6 +26,15 @@ var records = []types.Record{ }, } +// Implement logger to wrap testing.T.Log. +type testLogger struct { + t *testing.T +} + +func (t *testLogger) Log(args ...interface{}) { + t.t.Log(args...) 
+} + func TestNew(t *testing.T) { if _, err := New("myStreamName"); err != nil { t.Fatalf("new consumer error: %v", err) @@ -60,6 +71,7 @@ func TestScan(t *testing.T) { WithClient(client), WithCounter(ctr), WithStore(cp), + WithLogger(&testLogger{t}), ) if err != nil { t.Fatalf("new consumer error: %v", err) @@ -122,6 +134,7 @@ func TestScanShard(t *testing.T) { WithClient(client), WithCounter(ctr), WithStore(cp), + WithLogger(&testLogger{t}), ) if err != nil { t.Fatalf("new consumer error: %v", err) @@ -301,7 +314,8 @@ func TestScanShard_ShardIsClosed_WithShardClosedHandler(t *testing.T) { WithClient(client), WithShardClosedHandler(func(streamName, shardID string) error { return fmt.Errorf("closed shard error") - })) + }), + WithLogger(&testLogger{t})) if err != nil { t.Fatalf("new consumer error: %v", err) } @@ -335,7 +349,7 @@ func TestScanShard_GetRecordsError(t *testing.T) { return nil } - c, err := New("myStreamName", WithClient(client)) + c, err := New("myStreamName", WithClient(client), WithLogger(&testLogger{t})) if err != nil { t.Fatalf("new consumer error: %v", err) } @@ -384,3 +398,201 @@ func (fc *fakeCounter) Add(streamName string, count int64) { fc.counter += count } + +func TestScan_PreviousParentsBeforeTrimHorizon(t *testing.T) { + client := &kinesisClientMock{ + getShardIteratorMock: func(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) { + return &kinesis.GetShardIteratorOutput{ + ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"), + }, nil + }, + getRecordsMock: func(ctx context.Context, params *kinesis.GetRecordsInput, optFns ...func(*kinesis.Options)) (*kinesis.GetRecordsOutput, error) { + return &kinesis.GetRecordsOutput{ + NextShardIterator: nil, + Records: records, + }, nil + }, + listShardsMock: func(ctx context.Context, params *kinesis.ListShardsInput, optFns ...func(*kinesis.Options)) (*kinesis.ListShardsOutput, error) { + return &kinesis.ListShardsOutput{ + Shards: []types.Shard{ + { + ShardId: aws.String("myShard"), + ParentShardId: aws.String("myOldParent"), + AdjacentParentShardId: aws.String("myOldAdjacentParent"), + }, + }, + }, nil + }, + } + var ( + cp = store.New() + ctr = &fakeCounter{} + ) + + c, err := New("myStreamName", + WithClient(client), + WithCounter(ctr), + WithStore(cp), + WithLogger(&testLogger{t}), + ) + if err != nil { + t.Fatalf("new consumer error: %v", err) + } + + var ( + ctx, cancel = context.WithCancel(context.Background()) + res string + ) + + var fn = func(r *Record) error { + res += string(r.Data) + + if string(r.Data) == "lastData" { + cancel() + } + + return nil + } + + if err := c.Scan(ctx, fn); err != nil { + t.Errorf("scan returned unexpected error %v", err) + } + + if res != "firstDatalastData" { + t.Errorf("callback error expected %s, got %s", "firstDatalastData", res) + } + + if val := ctr.Get(); val != 2 { + t.Errorf("counter error expected %d, got %d", 2, val) + } + + val, err := cp.GetCheckpoint("myStreamName", "myShard") + if err != nil && val != "lastSeqNum" { + t.Errorf("checkout error expected %s, got %s", "lastSeqNum", val) + } +} + +func TestScan_ParentChildOrdering(t *testing.T) { + // We create a set of shards where shard1 split into (shard2,shard3), then (shard2,shard3) merged into shard4. 
+ client := &kinesisClientMock{ + getShardIteratorMock: func(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) { + return &kinesis.GetShardIteratorOutput{ + ShardIterator: aws.String(*params.ShardId + "iter"), + }, nil + }, + getRecordsMock: func(ctx context.Context, params *kinesis.GetRecordsInput, optFns ...func(*kinesis.Options)) (*kinesis.GetRecordsOutput, error) { + switch *params.ShardIterator { + case "shard1iter": + return &kinesis.GetRecordsOutput{ + NextShardIterator: nil, + Records: []types.Record{ + { + Data: []byte("shard1data"), + SequenceNumber: aws.String("shard1num"), + }, + }, + }, nil + case "shard2iter": + return &kinesis.GetRecordsOutput{ + NextShardIterator: nil, + Records: []types.Record{}, + }, nil + case "shard3iter": + return &kinesis.GetRecordsOutput{ + NextShardIterator: nil, + Records: []types.Record{ + { + Data: []byte("shard3data"), + SequenceNumber: aws.String("shard3num"), + }, + }, + }, nil + case "shard4iter": + return &kinesis.GetRecordsOutput{ + NextShardIterator: nil, + Records: []types.Record{ + { + Data: []byte("shard4data"), + SequenceNumber: aws.String("shard4num"), + }, + }, + }, nil + default: + panic("got unexpected iterator") + } + }, + listShardsMock: func(ctx context.Context, params *kinesis.ListShardsInput, optFns ...func(*kinesis.Options)) (*kinesis.ListShardsOutput, error) { + // Intentionally misorder these to test resiliance to ordering issues from ListShards. + return &kinesis.ListShardsOutput{ + Shards: []types.Shard{ + { + ShardId: aws.String("shard3"), + ParentShardId: aws.String("shard1"), + }, + { + ShardId: aws.String("shard1"), + ParentShardId: aws.String("shard0"), // not otherwise referenced, parent ordering should ignore this + }, + { + ShardId: aws.String("shard4"), + ParentShardId: aws.String("shard2"), + AdjacentParentShardId: aws.String("shard3"), + }, + { + ShardId: aws.String("shard2"), + ParentShardId: aws.String("shard1"), + }, + }, + }, nil + }, + } + var ( + cp = store.New() + ctr = &fakeCounter{} + ) + + c, err := New("myStreamName", + WithClient(client), + WithCounter(ctr), + WithStore(cp), + WithLogger(&testLogger{t}), + ) + if err != nil { + t.Fatalf("new consumer error: %v", err) + } + + var ( + ctx, cancel = context.WithCancel(context.Background()) + res string + ) + + rand.Seed(time.Now().UnixNano()) + + var fn = func(r *Record) error { + res += string(r.Data) + time.Sleep(time.Duration(rand.Int()%100) * time.Millisecond) + + if string(r.Data) == "shard4data" { + cancel() + } + + return nil + } + + if err := c.Scan(ctx, fn); err != nil { + t.Errorf("scan returned unexpected error %v", err) + } + + if want := "shard1datashard3datashard4data"; res != want { + t.Errorf("callback error expected %s, got %s", want, res) + } + + if val := ctr.Get(); val != 3 { + t.Errorf("counter error expected %d, got %d", 2, val) + } + + val, err := cp.GetCheckpoint("myStreamName", "shard4data") + if err != nil && val != "shard4num" { + t.Errorf("checkout error expected %s, got %s", "shard4num", val) + } +} diff --git a/group.go b/group.go index a092dc3..5856f24 100644 --- a/group.go +++ b/group.go @@ -12,3 +12,10 @@ type Group interface { GetCheckpoint(streamName, shardID string) (string, error) SetCheckpoint(streamName, shardID, sequenceNumber string) error } + +type CloseableGroup interface { + Group + // Allows shard processors to tell the group when the shard has been + // fully processed. Should be called only once per shardID. 
+ CloseShard(ctx context.Context, shardID string) error +} From 553e2392fdf3f9e8e7859481f915d2cfc60e1502 Mon Sep 17 00:00:00 2001 From: Jarrad <113399675+jwhitaker-swiftnav@users.noreply.github.com> Date: Fri, 7 Jun 2024 01:38:16 +1000 Subject: [PATCH 2/2] fix nil pointer dereference on AWS errors (#148) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix nil pointer dereference on AWS errors * return Start errors to Scan consumer before the previous commit e465b09, client errors panicked the reader, so consumers would pick up sharditerator errors by virtue of their server crashing and burning. Now that client errors are properly returned, the behaviour of listShards is problematic because it absorbs any client errors it gets. The result of these two things now is that if you hit an aws error, your server will go into an endless scan loop you can't detect and can't easily recover from. To avoid that, listShards will now stop if it hits a client error. --------- Co-authored-by: Jarrad Whitaker --- allgroup.go | 15 +++++++---- consumer.go | 11 ++++++-- consumer_test.go | 66 ++++++++++++++++++++++++++++++++++++++++++++++++ group.go | 2 +- 4 files changed, 86 insertions(+), 8 deletions(-) diff --git a/allgroup.go b/allgroup.go index d107a7a..1ecb7b2 100644 --- a/allgroup.go +++ b/allgroup.go @@ -38,7 +38,7 @@ type AllGroup struct { // Start is a blocking operation which will loop and attempt to find new // shards on a regular cadence. -func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) { +func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) error { // Note: while ticker is a rather naive approach to this problem, // it actually simplifies a few things. i.e. If we miss a new shard // while AWS is resharding we'll pick it up max 30 seconds later. @@ -51,12 +51,16 @@ func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) { var ticker = time.NewTicker(30 * time.Second) for { - g.findNewShards(ctx, shardc) + err := g.findNewShards(ctx, shardc) + if err != nil { + ticker.Stop() + return err + } select { case <-ctx.Done(): ticker.Stop() - return + return nil case <-ticker.C: } } @@ -91,7 +95,7 @@ func waitForCloseChannel(ctx context.Context, c <-chan struct{}) bool { // findNewShards pulls the list of shards from the Kinesis API // and uses a local cache to determine if we are already processing // a particular shard. 
-func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) { +func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) error { g.shardMu.Lock() defer g.shardMu.Unlock() @@ -100,7 +104,7 @@ func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) { shards, err := listShards(ctx, g.ksis, g.streamName) if err != nil { g.logger.Log("[GROUP] error:", err) - return + return err } // We do two `for` loops, since we have to set up all the `shardClosed` @@ -134,4 +138,5 @@ func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) { } }() } + return nil } diff --git a/consumer.go b/consumer.go index 8afb270..80ab45c 100644 --- a/consumer.go +++ b/consumer.go @@ -107,7 +107,11 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error { ) go func() { - c.group.Start(ctx, shardc) + err := c.group.Start(ctx, shardc) + if err != nil { + errc <- fmt.Errorf("error starting scan: %w", err) + cancel() + } <-ctx.Done() close(shardc) }() @@ -284,7 +288,10 @@ func (c *Consumer) getShardIterator(ctx context.Context, streamName, shardID, se } res, err := c.client.GetShardIterator(ctx, params) - return res.ShardIterator, err + if err != nil { + return nil, err + } + return res.ShardIterator, nil } func isRetriableError(err error) bool { diff --git a/consumer_test.go b/consumer_test.go index 13cfb04..d14b11e 100644 --- a/consumer_test.go +++ b/consumer_test.go @@ -2,6 +2,7 @@ package consumer import ( "context" + "errors" "fmt" "math/rand" "sync" @@ -110,6 +111,71 @@ func TestScan(t *testing.T) { } } +func TestScan_ListShardsError(t *testing.T) { + mockError := errors.New("mock list shards error") + client := &kinesisClientMock{ + listShardsMock: func(ctx context.Context, params *kinesis.ListShardsInput, optFns ...func(*kinesis.Options)) (*kinesis.ListShardsOutput, error) { + return nil, mockError + }, + } + + // use cancel func to signal shutdown + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + + var res string + var fn = func(r *Record) error { + res += string(r.Data) + cancel() // simulate cancellation while processing first record + return nil + } + + c, err := New("myStreamName", WithClient(client)) + if err != nil { + t.Fatalf("new consumer error: %v", err) + } + + err = c.Scan(ctx, fn) + if !errors.Is(err, mockError) { + t.Errorf("expected an error from listShards, but instead got %v", err) + } +} + +func TestScan_GetShardIteratorError(t *testing.T) { + mockError := errors.New("mock get shard iterator error") + client := &kinesisClientMock{ + listShardsMock: func(ctx context.Context, params *kinesis.ListShardsInput, optFns ...func(*kinesis.Options)) (*kinesis.ListShardsOutput, error) { + return &kinesis.ListShardsOutput{ + Shards: []types.Shard{ + {ShardId: aws.String("myShard")}, + }, + }, nil + }, + getShardIteratorMock: func(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) { + return nil, mockError + }, + } + + // use cancel func to signal shutdown + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + + var res string + var fn = func(r *Record) error { + res += string(r.Data) + cancel() // simulate cancellation while processing first record + return nil + } + + c, err := New("myStreamName", WithClient(client)) + if err != nil { + t.Fatalf("new consumer error: %v", err) + } + + err = c.Scan(ctx, fn) + if !errors.Is(err, mockError) { + t.Errorf("expected an error from 
getShardIterator, but instead got %v", err) + } +} + func TestScanShard(t *testing.T) { var client = &kinesisClientMock{ getShardIteratorMock: func(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) { diff --git a/group.go b/group.go index 5856f24..29647d9 100644 --- a/group.go +++ b/group.go @@ -8,7 +8,7 @@ import ( // Group interface used to manage which shard to process type Group interface { - Start(ctx context.Context, shardc chan types.Shard) + Start(ctx context.Context, shardc chan types.Shard) error GetCheckpoint(streamName, shardID string) (string, error) SetCheckpoint(streamName, shardID, sequenceNumber string) error }
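
A caller-side sketch (not part of the patch) of what these changes enable: now that `Start` propagates client errors through `Scan`, the caller can apply its own retry policy instead of the process panicking or looping silently. The import path `github.com/harlow/kinesis-consumer` and the bare `consumer.New(streamName)` construction are assumptions for illustration:

```go
package main

import (
	"context"
	"log"
	"time"

	consumer "github.com/harlow/kinesis-consumer" // assumed upstream import path
)

func main() {
	c, err := consumer.New("myStreamName")
	if err != nil {
		log.Fatalf("new consumer: %v", err)
	}

	ctx := context.Background()
	for {
		err := c.Scan(ctx, func(r *consumer.Record) error {
			log.Printf("record: %s", r.Data)
			return nil
		})
		if err == nil {
			return // context cancelled or scan finished cleanly
		}
		// With this patch, client errors (e.g. ListShards or GetShardIterator
		// failures) are returned here rather than being swallowed, so the
		// caller decides whether to back off and retry or give up.
		log.Printf("scan error, retrying in 5s: %v", err)
		time.Sleep(5 * time.Second)
	}
}
```

Previously the same failures either panicked the reader or left `Scan` in an undetectable endless loop; returning them makes the failure mode explicit.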