From 58ce4ba9f5da66cc185feb2d78a91d400458c507 Mon Sep 17 00:00:00 2001 From: Alex Senger Date: Tue, 10 Sep 2024 17:22:50 +0200 Subject: [PATCH] #185 reverts upstream changes --- allgroup.go | 84 +++----------- consumer.go | 31 ++---- consumer_test.go | 282 +---------------------------------------------- group.go | 9 +- 4 files changed, 24 insertions(+), 382 deletions(-) diff --git a/allgroup.go b/allgroup.go index 6b98c95..328ab0d 100644 --- a/allgroup.go +++ b/allgroup.go @@ -2,7 +2,6 @@ package consumer import ( "context" - "fmt" "log/slog" "sync" "time" @@ -14,12 +13,11 @@ import ( // all shards on a stream func NewAllGroup(kinesis kinesisClient, store Store, streamName string, logger *slog.Logger) *AllGroup { return &AllGroup{ - kinesis: kinesis, - shards: make(map[string]types.Shard), - shardsClosed: make(map[string]chan struct{}), - streamName: streamName, - slog: logger, - Store: store, + kinesis: kinesis, + shards: make(map[string]types.Shard), + streamName: streamName, + slog: logger, + Store: store, } } @@ -32,14 +30,13 @@ type AllGroup struct { slog *slog.Logger Store - shardMu sync.Mutex - shards map[string]types.Shard - shardsClosed map[string]chan struct{} + shardMu sync.Mutex + shards map[string]types.Shard } // Start is a blocking operation which will loop and attempt to find new // shards on a regular cadence. -func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) error { +func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) { // Note: while ticker is a rather naive approach to this problem, // it actually simplifies a few things. i.e. If we miss a new shard // while AWS is re-sharding we'll pick it up max 30 seconds later. @@ -52,51 +49,21 @@ func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) error { var ticker = time.NewTicker(30 * time.Second) for { - err := g.findNewShards(ctx, shardc) - if err != nil { - ticker.Stop() - return err - } + g.findNewShards(ctx, shardc) select { case <-ctx.Done(): ticker.Stop() - return nil + return case <-ticker.C: } } } -func (g *AllGroup) CloseShard(ctx context.Context, shardID string) error { - g.shardMu.Lock() - defer g.shardMu.Unlock() - c, ok := g.shardsClosed[shardID] - if !ok { - return fmt.Errorf("closing unknown shard ID %q", shardID) - } - close(c) - return nil -} - -func waitForCloseChannel(ctx context.Context, c <-chan struct{}) bool { - if c == nil { - // no channel means we haven't seen this shard in listShards, so it - // probably fell off the TRIM_HORIZON, and we can assume it's fully processed. - return true - } - select { - case <-ctx.Done(): - return false - case <-c: - // the channel has been processed and closed by the consumer (CloseShard has been called) - return true - } -} - // findNewShards pulls the list of shards from the Kinesis API // and uses a local cache to determine if we are already processing // a particular shard. -func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) error { +func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) { g.shardMu.Lock() defer g.shardMu.Unlock() @@ -105,39 +72,14 @@ func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) e shards, err := listShards(ctx, g.kinesis, g.streamName) if err != nil { g.slog.ErrorContext(ctx, "list shards", slog.String("error", err.Error())) - return err + return } - // We do two `for` loops, since we have to set up all the `shardClosed` - // channels before we start using any of them. 
It's highly probable - // that Kinesis provides us the shards in dependency order (parents - // before children), but it doesn't appear to be a guarantee. for _, shard := range shards { if _, ok := g.shards[*shard.ShardId]; ok { continue } g.shards[*shard.ShardId] = shard - g.shardsClosed[*shard.ShardId] = make(chan struct{}) + shardc <- shard } - for _, shard := range shards { - shard := shard // Shadow shard, since we use it in goroutine - var parent1, parent2 <-chan struct{} - if shard.ParentShardId != nil { - parent1 = g.shardsClosed[*shard.ParentShardId] - } - if shard.AdjacentParentShardId != nil { - parent2 = g.shardsClosed[*shard.AdjacentParentShardId] - } - go func() { - // Asynchronously wait for all parents of this shard to be processed - // before providing it out to our client. Kinesis guarantees that a - // given partition key's data will be provided to clients in-order, - // but when splits or joins happen, we need to process all parents prior - // to processing children or that ordering guarantee is not maintained. - if waitForCloseChannel(ctx, parent1) && waitForCloseChannel(ctx, parent2) { - shardc <- shard - } - }() - } - return nil } diff --git a/consumer.go b/consumer.go index 9e63c8f..2e3a7b9 100644 --- a/consumer.go +++ b/consumer.go @@ -120,11 +120,7 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error { ) go func() { - err := c.group.Start(ctx, shardc) - if err != nil { - errc <- fmt.Errorf("error starting scan: %w", err) - cancel() - } + c.group.Start(ctx, shardc) <-ctx.Done() close(shardc) }() @@ -135,19 +131,13 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error { wg.Add(1) go func(shardID string) { defer wg.Done() - var err error - if err = c.ScanShard(ctx, shardID, fn); err != nil { - err = fmt.Errorf("shard %s error: %w", shardID, err) - } else if closeable, ok := c.group.(CloseableGroup); !ok { - // group doesn't allow closure, skip calling CloseShard - } else if err = closeable.CloseShard(ctx, shardID); err != nil { - err = fmt.Errorf("shard closed CloseableGroup error: %w", err) - } - if err != nil { + if err := c.ScanShard(ctx, shardID, fn); err != nil { select { case errc <- fmt.Errorf("shard %s error: %w", shardID, err): + // first error to occur cancel() default: + // error has already occurred } } }(aws.ToString(shard.ShardId)) @@ -161,10 +151,6 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error { return <-errc } -func (c *Consumer) scanSingleShard(ctx context.Context, shardID string, fn ScanFunc) error { - return nil -} - // ScanShard loops over records on a specific shard, calls the callback func // for each record and checkpoints the progress of scan. 
func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error { @@ -254,7 +240,9 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e c.logger.DebugContext(ctx, "shard closed", slog.String("shard-id", shardID)) if c.shardClosedHandler != nil { - if err := c.shardClosedHandler(c.streamName, shardID); err != nil { + err := c.shardClosedHandler(c.streamName, shardID) + + if err != nil { return fmt.Errorf("shard closed handler error: %w", err) } } @@ -306,10 +294,7 @@ func (c *Consumer) getShardIterator(ctx context.Context, streamName, shardID, se } res, err := c.client.GetShardIterator(ctx, params) - if err != nil { - return nil, err - } - return res.ShardIterator, nil + return res.ShardIterator, err } func isRetriableError(err error) bool { diff --git a/consumer_test.go b/consumer_test.go index d0a3e13..81e6d48 100644 --- a/consumer_test.go +++ b/consumer_test.go @@ -2,12 +2,9 @@ package consumer import ( "context" - "errors" "fmt" - "math/rand" "sync" "testing" - "time" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/kinesis" @@ -27,15 +24,6 @@ var records = []types.Record{ }, } -// Implement logger to wrap testing.T.Log. -type testLogger struct { - t *testing.T -} - -func (t *testLogger) Log(args ...interface{}) { - t.t.Log(args...) -} - func TestNew(t *testing.T) { if _, err := New("myStreamName"); err != nil { t.Fatalf("new consumer error: %v", err) @@ -72,7 +60,6 @@ func TestScan(t *testing.T) { WithClient(client), WithCounter(ctr), WithStore(cp), - WithLogger(&testLogger{t}), ) if err != nil { t.Fatalf("new consumer error: %v", err) @@ -111,71 +98,6 @@ func TestScan(t *testing.T) { } } -func TestScan_ListShardsError(t *testing.T) { - mockError := errors.New("mock list shards error") - client := &kinesisClientMock{ - listShardsMock: func(ctx context.Context, params *kinesis.ListShardsInput, optFns ...func(*kinesis.Options)) (*kinesis.ListShardsOutput, error) { - return nil, mockError - }, - } - - // use cancel func to signal shutdown - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - - var res string - var fn = func(r *Record) error { - res += string(r.Data) - cancel() // simulate cancellation while processing first record - return nil - } - - c, err := New("myStreamName", WithClient(client)) - if err != nil { - t.Fatalf("new consumer error: %v", err) - } - - err = c.Scan(ctx, fn) - if !errors.Is(err, mockError) { - t.Errorf("expected an error from listShards, but instead got %v", err) - } -} - -func TestScan_GetShardIteratorError(t *testing.T) { - mockError := errors.New("mock get shard iterator error") - client := &kinesisClientMock{ - listShardsMock: func(ctx context.Context, params *kinesis.ListShardsInput, optFns ...func(*kinesis.Options)) (*kinesis.ListShardsOutput, error) { - return &kinesis.ListShardsOutput{ - Shards: []types.Shard{ - {ShardId: aws.String("myShard")}, - }, - }, nil - }, - getShardIteratorMock: func(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) { - return nil, mockError - }, - } - - // use cancel func to signal shutdown - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - - var res string - var fn = func(r *Record) error { - res += string(r.Data) - cancel() // simulate cancellation while processing first record - return nil - } - - c, err := New("myStreamName", WithClient(client)) - if err != nil { - t.Fatalf("new consumer error: %v", err) - } - - err 
= c.Scan(ctx, fn) - if !errors.Is(err, mockError) { - t.Errorf("expected an error from getShardIterator, but instead got %v", err) - } -} - func TestScanShard(t *testing.T) { var client = &kinesisClientMock{ getShardIteratorMock: func(_ context.Context, _ *kinesis.GetShardIteratorInput, _ ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) { @@ -200,7 +122,6 @@ func TestScanShard(t *testing.T) { WithClient(client), WithCounter(ctr), WithStore(cp), - WithLogger(&testLogger{t}), ) if err != nil { t.Fatalf("new consumer error: %v", err) @@ -380,8 +301,7 @@ func TestScanShard_ShardIsClosed_WithShardClosedHandler(t *testing.T) { WithClient(client), WithShardClosedHandler(func(_, _ string) error { return fmt.Errorf("closed shard error") - }), - WithLogger(&testLogger{t})) + })) if err != nil { t.Fatalf("new consumer error: %v", err) } @@ -415,7 +335,7 @@ func TestScanShard_GetRecordsError(t *testing.T) { return nil } - c, err := New("myStreamName", WithClient(client), WithLogger(&testLogger{t})) + c, err := New("myStreamName", WithClient(client)) if err != nil { t.Fatalf("new consumer error: %v", err) } @@ -464,201 +384,3 @@ func (fc *fakeCounter) Add(_ string, count int64) { fc.counter += count } - -func TestScan_PreviousParentsBeforeTrimHorizon(t *testing.T) { - client := &kinesisClientMock{ - getShardIteratorMock: func(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) { - return &kinesis.GetShardIteratorOutput{ - ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"), - }, nil - }, - getRecordsMock: func(ctx context.Context, params *kinesis.GetRecordsInput, optFns ...func(*kinesis.Options)) (*kinesis.GetRecordsOutput, error) { - return &kinesis.GetRecordsOutput{ - NextShardIterator: nil, - Records: records, - }, nil - }, - listShardsMock: func(ctx context.Context, params *kinesis.ListShardsInput, optFns ...func(*kinesis.Options)) (*kinesis.ListShardsOutput, error) { - return &kinesis.ListShardsOutput{ - Shards: []types.Shard{ - { - ShardId: aws.String("myShard"), - ParentShardId: aws.String("myOldParent"), - AdjacentParentShardId: aws.String("myOldAdjacentParent"), - }, - }, - }, nil - }, - } - var ( - cp = store.New() - ctr = &fakeCounter{} - ) - - c, err := New("myStreamName", - WithClient(client), - WithCounter(ctr), - WithStore(cp), - WithLogger(&testLogger{t}), - ) - if err != nil { - t.Fatalf("new consumer error: %v", err) - } - - var ( - ctx, cancel = context.WithCancel(context.Background()) - res string - ) - - var fn = func(r *Record) error { - res += string(r.Data) - - if string(r.Data) == "lastData" { - cancel() - } - - return nil - } - - if err := c.Scan(ctx, fn); err != nil { - t.Errorf("scan returned unexpected error %v", err) - } - - if res != "firstDatalastData" { - t.Errorf("callback error expected %s, got %s", "firstDatalastData", res) - } - - if val := ctr.Get(); val != 2 { - t.Errorf("counter error expected %d, got %d", 2, val) - } - - val, err := cp.GetCheckpoint("myStreamName", "myShard") - if err != nil && val != "lastSeqNum" { - t.Errorf("checkout error expected %s, got %s", "lastSeqNum", val) - } -} - -func TestScan_ParentChildOrdering(t *testing.T) { - // We create a set of shards where shard1 split into (shard2,shard3), then (shard2,shard3) merged into shard4. 
- client := &kinesisClientMock{ - getShardIteratorMock: func(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) { - return &kinesis.GetShardIteratorOutput{ - ShardIterator: aws.String(*params.ShardId + "iter"), - }, nil - }, - getRecordsMock: func(ctx context.Context, params *kinesis.GetRecordsInput, optFns ...func(*kinesis.Options)) (*kinesis.GetRecordsOutput, error) { - switch *params.ShardIterator { - case "shard1iter": - return &kinesis.GetRecordsOutput{ - NextShardIterator: nil, - Records: []types.Record{ - { - Data: []byte("shard1data"), - SequenceNumber: aws.String("shard1num"), - }, - }, - }, nil - case "shard2iter": - return &kinesis.GetRecordsOutput{ - NextShardIterator: nil, - Records: []types.Record{}, - }, nil - case "shard3iter": - return &kinesis.GetRecordsOutput{ - NextShardIterator: nil, - Records: []types.Record{ - { - Data: []byte("shard3data"), - SequenceNumber: aws.String("shard3num"), - }, - }, - }, nil - case "shard4iter": - return &kinesis.GetRecordsOutput{ - NextShardIterator: nil, - Records: []types.Record{ - { - Data: []byte("shard4data"), - SequenceNumber: aws.String("shard4num"), - }, - }, - }, nil - default: - panic("got unexpected iterator") - } - }, - listShardsMock: func(ctx context.Context, params *kinesis.ListShardsInput, optFns ...func(*kinesis.Options)) (*kinesis.ListShardsOutput, error) { - // Intentionally misorder these to test resiliance to ordering issues from ListShards. - return &kinesis.ListShardsOutput{ - Shards: []types.Shard{ - { - ShardId: aws.String("shard3"), - ParentShardId: aws.String("shard1"), - }, - { - ShardId: aws.String("shard1"), - ParentShardId: aws.String("shard0"), // not otherwise referenced, parent ordering should ignore this - }, - { - ShardId: aws.String("shard4"), - ParentShardId: aws.String("shard2"), - AdjacentParentShardId: aws.String("shard3"), - }, - { - ShardId: aws.String("shard2"), - ParentShardId: aws.String("shard1"), - }, - }, - }, nil - }, - } - var ( - cp = store.New() - ctr = &fakeCounter{} - ) - - c, err := New("myStreamName", - WithClient(client), - WithCounter(ctr), - WithStore(cp), - WithLogger(&testLogger{t}), - ) - if err != nil { - t.Fatalf("new consumer error: %v", err) - } - - var ( - ctx, cancel = context.WithCancel(context.Background()) - res string - ) - - rand.Seed(time.Now().UnixNano()) - - var fn = func(r *Record) error { - res += string(r.Data) - time.Sleep(time.Duration(rand.Int()%100) * time.Millisecond) - - if string(r.Data) == "shard4data" { - cancel() - } - - return nil - } - - if err := c.Scan(ctx, fn); err != nil { - t.Errorf("scan returned unexpected error %v", err) - } - - if want := "shard1datashard3datashard4data"; res != want { - t.Errorf("callback error expected %s, got %s", want, res) - } - - if val := ctr.Get(); val != 3 { - t.Errorf("counter error expected %d, got %d", 2, val) - } - - val, err := cp.GetCheckpoint("myStreamName", "shard4data") - if err != nil && val != "shard4num" { - t.Errorf("checkout error expected %s, got %s", "shard4num", val) - } -} diff --git a/group.go b/group.go index 29647d9..a092dc3 100644 --- a/group.go +++ b/group.go @@ -8,14 +8,7 @@ import ( // Group interface used to manage which shard to process type Group interface { - Start(ctx context.Context, shardc chan types.Shard) error + Start(ctx context.Context, shardc chan types.Shard) GetCheckpoint(streamName, shardID string) (string, error) SetCheckpoint(streamName, shardID, sequenceNumber string) error } - -type 
CloseableGroup interface { - Group - // Allows shard processors to tell the group when the shard has been - // fully processed. Should be called only once per shardID. - CloseShard(ctx context.Context, shardID string) error -}
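
Reviewer note: this revert removes the parent/child shard-ordering mechanism from allgroup.go (the shardsClosed channels, CloseShard, and waitForCloseChannel) along with the CloseableGroup interface and its tests. For context, below is a minimal, self-contained Go sketch of the idea that code implemented: every shard returned by ListShards gets a "closed" channel, and a child shard is only handed to the consumer once its parent has been marked done. The names here (shard, orderer, markDone, dispatch) are illustrative stand-ins, not this package's API, and the sketch waits on a single parent only, whereas the reverted code also waited on AdjacentParentShardId for merged shards.

package main

import (
	"context"
	"fmt"
	"sync"
)

// shard is a stand-in for types.Shard with only the fields the ordering
// logic needs.
type shard struct {
	ID       string
	ParentID string // "" when no parent is known
}

// orderer mirrors the reverted shardsClosed approach: one "closed" channel
// per known shard, closed once that shard has been fully processed.
type orderer struct {
	mu     sync.Mutex
	closed map[string]chan struct{}
}

func newOrderer() *orderer {
	return &orderer{closed: make(map[string]chan struct{})}
}

// markDone plays the role of CloseShard: called once per shard after it has
// been fully processed, releasing any children waiting on it.
func (o *orderer) markDone(id string) {
	o.mu.Lock()
	defer o.mu.Unlock()
	if c, ok := o.closed[id]; ok {
		close(c)
	}
}

// dispatch first registers a channel for every listed shard (so parents and
// children all exist before anyone waits), then emits each shard on out once
// its parent, if any, has been marked done. A missing parent channel is
// treated as already processed, like a parent that fell off TRIM_HORIZON.
func (o *orderer) dispatch(ctx context.Context, shards []shard, out chan<- shard) {
	o.mu.Lock()
	defer o.mu.Unlock()
	for _, s := range shards {
		if _, ok := o.closed[s.ID]; !ok {
			o.closed[s.ID] = make(chan struct{})
		}
	}
	for _, s := range shards {
		parent := o.closed[s.ParentID] // nil for unknown or absent parents
		go func(s shard, parent <-chan struct{}) {
			if parent != nil {
				select {
				case <-ctx.Done():
					return
				case <-parent:
				}
			}
			out <- s
		}(s, parent)
	}
}

func main() {
	o := newOrderer()
	out := make(chan shard)

	// shard2 is a child of shard1, so it must not be emitted until shard1
	// has been processed, even though it is listed first.
	o.dispatch(context.Background(), []shard{
		{ID: "shard2", ParentID: "shard1"},
		{ID: "shard1"},
	}, out)

	for i := 0; i < 2; i++ {
		s := <-out
		fmt.Println("processing", s.ID)
		o.markDone(s.ID)
	}
}

After the revert, findNewShards sends shards to shardc as soon as ListShards reports them, so the parent-before-child processing guarantee described in the deleted comment (needed to preserve per-partition-key ordering across splits and merges) no longer holds.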
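
Reviewer note: shard-closure notification itself is unaffected by this revert; it still flows through the pre-existing path in ScanShard, where the optional handler registered via WithShardClosedHandler is called once a shard stops returning a next iterator, and an error from that handler aborts the shard's scan (see the retained hunk in consumer.go). A rough usage sketch follows; the import path github.com/harlow/kinesis-consumer is an assumption taken from the upstream project and should be replaced with this fork's actual module path.

package main

import (
	"context"
	"log"

	consumer "github.com/harlow/kinesis-consumer" // assumed import path; adjust to this fork's module
)

func main() {
	c, err := consumer.New("myStreamName",
		// Invoked by ScanShard when a shard is closed (no further records);
		// returning a non-nil error aborts the scan of that shard.
		consumer.WithShardClosedHandler(func(streamName, shardID string) error {
			log.Printf("shard %s on stream %s is closed", shardID, streamName)
			return nil
		}),
	)
	if err != nil {
		log.Fatalf("new consumer: %v", err)
	}

	// Scan blocks until the context is canceled or an error occurs.
	err = c.Scan(context.Background(), func(r *consumer.Record) error {
		log.Printf("record: %s", r.Data)
		return nil // nil keeps the scan going
	})
	if err != nil {
		log.Fatalf("scan: %v", err)
	}
}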