From 6720a01733e68abbf0f45585c6c6ded81526c363 Mon Sep 17 00:00:00 2001
From: gram-signal <84339875+gram-signal@users.noreply.github.com>
Date: Thu, 6 Jun 2024 09:37:42 -0600
Subject: [PATCH] Maintain parent/child shard ordering across shard splits/merges. (#155)

Kinesis allows clients to rely on an invariant that, for a given partition
key, the order of records added to the stream is maintained. I.e., given the
input `pkey=x,val=1 pkey=x,val=2 pkey=x,val=3`, the values `1,2,3` will be
seen in that order when processed by clients, so long as clients are careful.
Kinesis does this by putting all records for a single partition key into a
single shard, then maintaining ordering within that shard.

However, shards can be split and merged to better distribute load and to
handle per-shard throughput limits. Kinesis currently does this by (one or
many times) splitting a single shard into two, or by merging two adjacent
shards into one. When this occurs, Kinesis still allows for ordering
consistency by detailing shard parent/child relationships in its `listShards`
output. Splitting a shard A creates children B and C, both with
`ParentShardId=A`. Merging shards A and B into C creates a new shard C with
`ParentShardId=A,AdjacentParentShardId=B`. So long as clients fully process
all records in parents (including adjacent parents) before processing the new
shard, ordering is maintained.

`kinesis-consumer` currently doesn't do this. Instead, upon the initial (and
each subsequent) `listShards` call, all visible shards immediately begin
processing. Consider this case, where shards split, then merge, and each
shard `X` contains a single record `rX`:

```
  time ->
     B
    / \
   A   D
    \ /
     C
```

Record `rD` should be processed only after both `rB` and `rC` have been
processed, and both `rB` and `rC` should wait for `rA`. Because the original
code starts goroutines immediately, any ordering of `{rA,rB,rC,rD}` might
occur.

This PR uses `AllGroup` as a bookkeeper of fully processed shards, with the
`Consumer` calling `CloseShard` once it has finished a shard. `AllGroup`
doesn't release a shard for processing until its parents have been fully
processed, and the consumer simply processes the shards it receives, as it
did before.

This PR adds a new `CloseableGroup` interface, rather than extending the
existing `Group` interface, to maintain backwards compatibility with existing
code that may already implement the `Group` interface elsewhere. Other
`Group` implementations don't get the ordering described above, but the
default `Consumer` does.
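To make the mechanism concrete, here is a minimal, self-contained sketch of
the gating pattern (this is not code from this repository; the `shard` struct
and the shard IDs are illustrative stand-ins): every shard gets a "done"
channel that is closed when that shard finishes, and a child shard is only
released for processing once the channels of its parent and, for merges, its
adjacent parent have been closed. This mirrors what `CloseShard` and
`waitForCloseChannel` do in `allgroup.go` below.

```
package main

import (
	"fmt"
	"sync"
)

// shard is a simplified stand-in for types.Shard: an ID plus its parents.
type shard struct {
	id, parent, adjacentParent string
}

func main() {
	shards := []shard{
		{id: "A"},
		{id: "B", parent: "A"},
		{id: "C", parent: "A"},
		{id: "D", parent: "B", adjacentParent: "C"},
	}

	// One "done" channel per known shard, closed once that shard has been
	// fully processed (the analogue of CloseShard).
	done := make(map[string]chan struct{})
	for _, s := range shards {
		done[s.id] = make(chan struct{})
	}

	shardc := make(chan shard)
	for _, s := range shards {
		s := s
		go func() {
			// Release a shard only after its parent and adjacent parent (if
			// any) are done; a nil channel means the parent is unknown and is
			// treated as already processed.
			if c := done[s.parent]; c != nil {
				<-c
			}
			if c := done[s.adjacentParent]; c != nil {
				<-c
			}
			shardc <- s
		}()
	}

	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		for range shards {
			s := <-shardc
			fmt.Println("processing", s.id) // A first, then B/C in either order, then D
			close(done[s.id])
		}
	}()
	wg.Wait()
}
```

With the split/merge diagram above, this releases A first, then B and C in
either order, and D only after both parents have finished.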
---
 allgroup.go      |  69 +++++++++++++--
 consumer.go      |  20 +++--
 consumer_test.go | 216 ++++++++++++++++++++++++++++++++++++++++++++++-
 group.go         |   7 ++
 4 files changed, 296 insertions(+), 16 deletions(-)

diff --git a/allgroup.go b/allgroup.go
index 749a380..d107a7a 100644
--- a/allgroup.go
+++ b/allgroup.go
@@ -2,6 +2,7 @@ package consumer
 
 import (
 	"context"
+	"fmt"
 	"sync"
 	"time"
 
@@ -12,11 +13,12 @@ import (
 // all shards on a stream
 func NewAllGroup(ksis kinesisClient, store Store, streamName string, logger Logger) *AllGroup {
 	return &AllGroup{
-		ksis:       ksis,
-		shards:     make(map[string]types.Shard),
-		streamName: streamName,
-		logger:     logger,
-		Store:      store,
+		ksis:         ksis,
+		shards:       make(map[string]types.Shard),
+		shardsClosed: make(map[string]chan struct{}),
+		streamName:   streamName,
+		logger:       logger,
+		Store:        store,
 	}
 }
 
@@ -29,8 +31,9 @@ type AllGroup struct {
 	logger     Logger
 	Store
 
-	shardMu sync.Mutex
-	shards  map[string]types.Shard
+	shardMu      sync.Mutex
+	shards       map[string]types.Shard
+	shardsClosed map[string]chan struct{}
 }
 
 // Start is a blocking operation which will loop and attempt to find new
@@ -59,6 +62,32 @@ func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) {
 	}
 }
 
+func (g *AllGroup) CloseShard(ctx context.Context, shardID string) error {
+	g.shardMu.Lock()
+	defer g.shardMu.Unlock()
+	c, ok := g.shardsClosed[shardID]
+	if !ok {
+		return fmt.Errorf("closing unknown shard ID %q", shardID)
+	}
+	close(c)
+	return nil
+}
+
+func waitForCloseChannel(ctx context.Context, c <-chan struct{}) bool {
+	if c == nil {
+		// No channel means we haven't seen this shard in listShards, so it
+		// probably fell off the TRIM_HORIZON, and we can assume it's fully processed.
+		return true
+	}
+	select {
+	case <-ctx.Done():
+		return false
+	case <-c:
+		// The shard has been fully processed and the channel closed by the consumer (CloseShard has been called).
+		return true
+	}
+}
+
 // findNewShards pulls the list of shards from the Kinesis API
 // and uses a local cache to determine if we are already processing
 // a particular shard.
@@ -74,11 +103,35 @@ func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) {
 		return
 	}
 
+	// We do two `for` loops, since we have to set up all the `shardsClosed`
+	// channels before we start using any of them. It's highly probable
+	// that Kinesis provides us the shards in dependency order (parents
+	// before children), but it doesn't appear to be a guarantee.
 	for _, shard := range shards {
 		if _, ok := g.shards[*shard.ShardId]; ok {
 			continue
 		}
 		g.shards[*shard.ShardId] = shard
-		shardc <- shard
+		g.shardsClosed[*shard.ShardId] = make(chan struct{})
+	}
+	for _, shard := range shards {
+		shard := shard // Shadow shard, since we use it in a goroutine.
+		var parent1, parent2 <-chan struct{}
+		if shard.ParentShardId != nil {
+			parent1 = g.shardsClosed[*shard.ParentShardId]
+		}
+		if shard.AdjacentParentShardId != nil {
+			parent2 = g.shardsClosed[*shard.AdjacentParentShardId]
+		}
+		go func() {
+			// Asynchronously wait for all parents of this shard to be processed
+			// before providing it out to our client. Kinesis guarantees that a
+			// given partition key's data will be provided to clients in-order,
+			// but when splits or joins happen, we need to process all parents prior
+			// to processing children or that ordering guarantee is not maintained.
+			if waitForCloseChannel(ctx, parent1) && waitForCloseChannel(ctx, parent2) {
+				shardc <- shard
+			}
+		}()
 	}
 }
diff --git a/consumer.go b/consumer.go
index bff0b21..8afb270 100644
--- a/consumer.go
+++ b/consumer.go
@@ -118,13 +118,19 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
 		wg.Add(1)
 		go func(shardID string) {
 			defer wg.Done()
-			if err := c.ScanShard(ctx, shardID, fn); err != nil {
+			var err error
+			if err = c.ScanShard(ctx, shardID, fn); err != nil {
+				err = fmt.Errorf("shard %s error: %w", shardID, err)
+			} else if closeable, ok := c.group.(CloseableGroup); !ok {
+				// The group doesn't allow closure; skip calling CloseShard.
+			} else if err = closeable.CloseShard(ctx, shardID); err != nil {
+				err = fmt.Errorf("shard closed CloseableGroup error: %w", err)
+			}
+			if err != nil {
 				select {
 				case errc <- fmt.Errorf("shard %s error: %w", shardID, err):
-					// first error to occur
 					cancel()
 				default:
-					// error has already occurred
 				}
 			}
 		}(aws.ToString(shard.ShardId))
@@ -138,6 +144,10 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
 	return <-errc
 }
 
+func (c *Consumer) scanSingleShard(ctx context.Context, shardID string, fn ScanFunc) error {
+	return nil
+}
+
 // ScanShard loops over records on a specific shard, calls the callback func
 // for each record and checkpoints the progress of scan.
 func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error {
@@ -218,9 +228,7 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
 			c.logger.Log("[CONSUMER] shard closed:", shardID)
 
 			if c.shardClosedHandler != nil {
-				err := c.shardClosedHandler(c.streamName, shardID)
-
-				if err != nil {
+				if err := c.shardClosedHandler(c.streamName, shardID); err != nil {
 					return fmt.Errorf("shard closed handler error: %w", err)
 				}
 			}
diff --git a/consumer_test.go b/consumer_test.go
index 3330d32..13cfb04 100644
--- a/consumer_test.go
+++ b/consumer_test.go
@@ -3,8 +3,10 @@ package consumer
 import (
 	"context"
 	"fmt"
+	"math/rand"
 	"sync"
 	"testing"
+	"time"
 
 	"github.com/aws/aws-sdk-go-v2/aws"
 	"github.com/aws/aws-sdk-go-v2/service/kinesis"
@@ -24,6 +26,15 @@ var records = []types.Record{
 	},
 }
 
+// testLogger implements Logger by wrapping testing.T.Log.
+type testLogger struct {
+	t *testing.T
+}
+
+func (t *testLogger) Log(args ...interface{}) {
+	t.t.Log(args...)
+}
+
 func TestNew(t *testing.T) {
 	if _, err := New("myStreamName"); err != nil {
 		t.Fatalf("new consumer error: %v", err)
@@ -60,6 +71,7 @@ func TestScan(t *testing.T) {
 		WithClient(client),
 		WithCounter(ctr),
 		WithStore(cp),
+		WithLogger(&testLogger{t}),
 	)
 	if err != nil {
 		t.Fatalf("new consumer error: %v", err)
@@ -122,6 +134,7 @@ func TestScanShard(t *testing.T) {
 		WithClient(client),
 		WithCounter(ctr),
 		WithStore(cp),
+		WithLogger(&testLogger{t}),
 	)
 	if err != nil {
 		t.Fatalf("new consumer error: %v", err)
@@ -301,7 +314,8 @@ func TestScanShard_ShardIsClosed_WithShardClosedHandler(t *testing.T) {
 		WithClient(client),
 		WithShardClosedHandler(func(streamName, shardID string) error {
 			return fmt.Errorf("closed shard error")
-		}))
+		}),
+		WithLogger(&testLogger{t}))
 	if err != nil {
 		t.Fatalf("new consumer error: %v", err)
 	}
@@ -335,7 +349,7 @@ func TestScanShard_GetRecordsError(t *testing.T) {
 		return nil
 	}
 
-	c, err := New("myStreamName", WithClient(client))
+	c, err := New("myStreamName", WithClient(client), WithLogger(&testLogger{t}))
 	if err != nil {
 		t.Fatalf("new consumer error: %v", err)
 	}
@@ -384,3 +398,201 @@ func (fc *fakeCounter) Add(streamName string, count int64) {
 
 	fc.counter += count
 }
+
+func TestScan_PreviousParentsBeforeTrimHorizon(t *testing.T) {
+	client := &kinesisClientMock{
+		getShardIteratorMock: func(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) {
+			return &kinesis.GetShardIteratorOutput{
+				ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
+			}, nil
+		},
+		getRecordsMock: func(ctx context.Context, params *kinesis.GetRecordsInput, optFns ...func(*kinesis.Options)) (*kinesis.GetRecordsOutput, error) {
+			return &kinesis.GetRecordsOutput{
+				NextShardIterator: nil,
+				Records:           records,
+			}, nil
+		},
+		listShardsMock: func(ctx context.Context, params *kinesis.ListShardsInput, optFns ...func(*kinesis.Options)) (*kinesis.ListShardsOutput, error) {
+			return &kinesis.ListShardsOutput{
+				Shards: []types.Shard{
+					{
+						ShardId:               aws.String("myShard"),
+						ParentShardId:         aws.String("myOldParent"),
+						AdjacentParentShardId: aws.String("myOldAdjacentParent"),
+					},
+				},
+			}, nil
+		},
+	}
+	var (
+		cp  = store.New()
+		ctr = &fakeCounter{}
+	)
+
+	c, err := New("myStreamName",
+		WithClient(client),
+		WithCounter(ctr),
+		WithStore(cp),
+		WithLogger(&testLogger{t}),
+	)
+	if err != nil {
+		t.Fatalf("new consumer error: %v", err)
+	}
+
+	var (
+		ctx, cancel = context.WithCancel(context.Background())
+		res         string
+	)
+
+	var fn = func(r *Record) error {
+		res += string(r.Data)
+
+		if string(r.Data) == "lastData" {
+			cancel()
+		}
+
+		return nil
+	}
+
+	if err := c.Scan(ctx, fn); err != nil {
+		t.Errorf("scan returned unexpected error %v", err)
+	}
+
+	if res != "firstDatalastData" {
+		t.Errorf("callback error expected %s, got %s", "firstDatalastData", res)
+	}
+
+	if val := ctr.Get(); val != 2 {
+		t.Errorf("counter error expected %d, got %d", 2, val)
+	}
+
+	val, err := cp.GetCheckpoint("myStreamName", "myShard")
+	if err != nil && val != "lastSeqNum" {
+		t.Errorf("checkpoint error expected %s, got %s", "lastSeqNum", val)
+	}
+}
+
+func TestScan_ParentChildOrdering(t *testing.T) {
+	// We create a set of shards where shard1 split into (shard2,shard3), then (shard2,shard3) merged into shard4.
+	client := &kinesisClientMock{
+		getShardIteratorMock: func(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) {
+			return &kinesis.GetShardIteratorOutput{
+				ShardIterator: aws.String(*params.ShardId + "iter"),
+			}, nil
+		},
+		getRecordsMock: func(ctx context.Context, params *kinesis.GetRecordsInput, optFns ...func(*kinesis.Options)) (*kinesis.GetRecordsOutput, error) {
+			switch *params.ShardIterator {
+			case "shard1iter":
+				return &kinesis.GetRecordsOutput{
+					NextShardIterator: nil,
+					Records: []types.Record{
+						{
+							Data:           []byte("shard1data"),
+							SequenceNumber: aws.String("shard1num"),
+						},
+					},
+				}, nil
+			case "shard2iter":
+				return &kinesis.GetRecordsOutput{
+					NextShardIterator: nil,
+					Records:           []types.Record{},
+				}, nil
+			case "shard3iter":
+				return &kinesis.GetRecordsOutput{
+					NextShardIterator: nil,
+					Records: []types.Record{
+						{
+							Data:           []byte("shard3data"),
+							SequenceNumber: aws.String("shard3num"),
+						},
+					},
+				}, nil
+			case "shard4iter":
+				return &kinesis.GetRecordsOutput{
+					NextShardIterator: nil,
+					Records: []types.Record{
+						{
+							Data:           []byte("shard4data"),
+							SequenceNumber: aws.String("shard4num"),
+						},
+					},
+				}, nil
+			default:
+				panic("got unexpected iterator")
+			}
+		},
+		listShardsMock: func(ctx context.Context, params *kinesis.ListShardsInput, optFns ...func(*kinesis.Options)) (*kinesis.ListShardsOutput, error) {
+			// Intentionally misorder these to test resilience to ordering issues from ListShards.
+			return &kinesis.ListShardsOutput{
+				Shards: []types.Shard{
+					{
+						ShardId:       aws.String("shard3"),
+						ParentShardId: aws.String("shard1"),
+					},
+					{
+						ShardId:       aws.String("shard1"),
+						ParentShardId: aws.String("shard0"), // not otherwise referenced; parent ordering should ignore this
+					},
+					{
+						ShardId:               aws.String("shard4"),
+						ParentShardId:         aws.String("shard2"),
+						AdjacentParentShardId: aws.String("shard3"),
+					},
+					{
+						ShardId:       aws.String("shard2"),
+						ParentShardId: aws.String("shard1"),
+					},
+				},
+			}, nil
+		},
+	}
+	var (
+		cp  = store.New()
+		ctr = &fakeCounter{}
+	)
+
+	c, err := New("myStreamName",
+		WithClient(client),
+		WithCounter(ctr),
+		WithStore(cp),
+		WithLogger(&testLogger{t}),
+	)
+	if err != nil {
+		t.Fatalf("new consumer error: %v", err)
+	}
+
+	var (
+		ctx, cancel = context.WithCancel(context.Background())
+		res         string
+	)
+
+	rand.Seed(time.Now().UnixNano())
+
+	var fn = func(r *Record) error {
+		res += string(r.Data)
+		time.Sleep(time.Duration(rand.Int()%100) * time.Millisecond)
+
+		if string(r.Data) == "shard4data" {
+			cancel()
+		}
+
+		return nil
+	}
+
+	if err := c.Scan(ctx, fn); err != nil {
+		t.Errorf("scan returned unexpected error %v", err)
+	}
+
+	if want := "shard1datashard3datashard4data"; res != want {
+		t.Errorf("callback error expected %s, got %s", want, res)
+	}
+
+	if val := ctr.Get(); val != 3 {
+		t.Errorf("counter error expected %d, got %d", 3, val)
+	}
+
+	val, err := cp.GetCheckpoint("myStreamName", "shard4")
+	if err != nil && val != "shard4num" {
+		t.Errorf("checkpoint error expected %s, got %s", "shard4num", val)
+	}
+}
diff --git a/group.go b/group.go
index a092dc3..5856f24 100644
--- a/group.go
+++ b/group.go
@@ -12,3 +12,10 @@ type Group interface {
 	GetCheckpoint(streamName, shardID string) (string, error)
 	SetCheckpoint(streamName, shardID, sequenceNumber string) error
 }
+
+type CloseableGroup interface {
+	Group
+	// CloseShard allows shard processors to tell the group when the shard has
+	// been fully processed. It should be called only once per shardID.
+	CloseShard(ctx context.Context, shardID string) error
+}