From 8d10ac8dac5fdc0275c0d69d443f9bbd58d9c325 Mon Sep 17 00:00:00 2001
From: Mikhail Konovalov <4463812+mskonovalov@users.noreply.github.com>
Date: Tue, 17 Sep 2024 05:25:49 +1000
Subject: [PATCH] Fix ProvisionedThroughputExceededException error (#161)

Fixes #158.

Seems the bug was introduced in #155. See #155 (comment)
---
 allgroup.go | 24 +++++++++++++-----------
 consumer.go | 50 +++++++++++++++++++++++++++-----------------------
 2 files changed, 40 insertions(+), 34 deletions(-)

diff --git a/allgroup.go b/allgroup.go
index 1ecb7b2..0823536 100644
--- a/allgroup.go
+++ b/allgroup.go
@@ -9,7 +9,7 @@ import (
 	"github.com/aws/aws-sdk-go-v2/service/kinesis/types"
 )
 
-// NewAllGroup returns an intitialized AllGroup for consuming
+// NewAllGroup returns an initialized AllGroup for consuming
 // all shards on a stream
 func NewAllGroup(ksis kinesisClient, store Store, streamName string, logger Logger) *AllGroup {
 	return &AllGroup{
@@ -38,12 +38,12 @@ type AllGroup struct {
 
 // Start is a blocking operation which will loop and attempt to find new
 // shards on a regular cadence.
-func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) error {
+func (g *AllGroup) Start(ctx context.Context, shardC chan types.Shard) error {
 	// Note: while ticker is a rather naive approach to this problem,
-	// it actually simplifies a few things. i.e. If we miss a new shard
-	// while AWS is resharding we'll pick it up max 30 seconds later.
+	// it actually simplifies a few things. I.e. If we miss a new shard
+	// while AWS is resharding, we'll pick it up max 30 seconds later.
 
-	// It might be worth refactoring this flow to allow the consumer to
+	// It might be worth refactoring this flow to allow the consumer
 	// to notify the broker when a shard is closed. However, shards don't
 	// necessarily close at the same time, so we could potentially get a
 	// thundering heard of notifications from the consumer.
@@ -51,8 +51,7 @@ func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) error {
 	var ticker = time.NewTicker(30 * time.Second)
 
 	for {
-		err := g.findNewShards(ctx, shardc)
-		if err != nil {
+		if err := g.findNewShards(ctx, shardC); err != nil {
 			ticker.Stop()
 			return err
 		}
@@ -66,7 +65,7 @@ func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) error {
 	}
 }
 
-func (g *AllGroup) CloseShard(ctx context.Context, shardID string) error {
+func (g *AllGroup) CloseShard(_ context.Context, shardID string) error {
 	g.shardMu.Lock()
 	defer g.shardMu.Unlock()
 	c, ok := g.shardsClosed[shardID]
@@ -95,7 +94,7 @@ func waitForCloseChannel(ctx context.Context, c <-chan struct{}) bool {
 // findNewShards pulls the list of shards from the Kinesis API
 // and uses a local cache to determine if we are already processing
 // a particular shard.
-func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) error {
+func (g *AllGroup) findNewShards(ctx context.Context, shardC chan types.Shard) error {
 	g.shardMu.Lock()
 	defer g.shardMu.Unlock()
 
@@ -111,14 +110,17 @@ func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) e
 	// channels before we start using any of them. It's highly probable
 	// that Kinesis provides us the shards in dependency order (parents
 	// before children), but it doesn't appear to be a guarantee.
+	newShards := make(map[string]types.Shard)
 	for _, shard := range shards {
 		if _, ok := g.shards[*shard.ShardId]; ok {
 			continue
 		}
 		g.shards[*shard.ShardId] = shard
 		g.shardsClosed[*shard.ShardId] = make(chan struct{})
+		newShards[*shard.ShardId] = shard
 	}
-	for _, shard := range shards {
+	// only new shards need to be checked for parent dependencies
+	for _, shard := range newShards {
 		shard := shard // Shadow shard, since we use it in goroutine
 		var parent1, parent2 <-chan struct{}
 		if shard.ParentShardId != nil {
@@ -134,7 +136,7 @@ func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) e
 			// but when splits or joins happen, we need to process all parents prior
 			// to processing children or that ordering guarantee is not maintained.
 			if waitForCloseChannel(ctx, parent1) && waitForCloseChannel(ctx, parent2) {
-				shardc <- shard
+				shardC <- shard
 			}
 		}()
 	}
diff --git a/consumer.go b/consumer.go
index 80ab45c..4777356 100644
--- a/consumer.go
+++ b/consumer.go
@@ -4,7 +4,7 @@ import (
 	"context"
 	"errors"
 	"fmt"
-	"io/ioutil"
+	"io"
 	"log"
 	"sync"
 	"time"
@@ -38,7 +38,7 @@ func New(streamName string, opts ...Option) (*Consumer, error) {
 		store:   &noopStore{},
 		counter: &noopCounter{},
 		logger: &noopLogger{
-			logger: log.New(ioutil.Discard, "", log.LstdFlags),
+			logger: log.New(io.Discard, "", log.LstdFlags),
 		},
 		scanInterval: 250 * time.Millisecond,
 		maxRecords:   10000,
@@ -90,7 +90,7 @@ type Consumer struct {
 type ScanFunc func(*Record) error
 
 // ErrSkipCheckpoint is used as a return value from ScanFunc to indicate that
-// the current checkpoint should be skipped skipped. It is not returned
+// the current checkpoint should be skipped. It is not returned
 // as an error by any function.
 var ErrSkipCheckpoint = errors.New("skip checkpoint")
 
@@ -102,25 +102,35 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
 	defer cancel()
 
 	var (
-		errc   = make(chan error, 1)
-		shardc = make(chan types.Shard, 1)
+		errC   = make(chan error, 1)
+		shardC = make(chan types.Shard, 1)
 	)
 
 	go func() {
-		err := c.group.Start(ctx, shardc)
+		err := c.group.Start(ctx, shardC)
 		if err != nil {
-			errc <- fmt.Errorf("error starting scan: %w", err)
+			errC <- fmt.Errorf("error starting scan: %w", err)
 			cancel()
 		}
 		<-ctx.Done()
-		close(shardc)
+		close(shardC)
 	}()
 
 	wg := new(sync.WaitGroup)
 	// process each of the shards
-	for shard := range shardc {
+	shardsInProcess := make(map[string]struct{})
+	for shard := range shardC {
+		shardId := aws.ToString(shard.ShardId)
+		if _, ok := shardsInProcess[shardId]; ok {
+			// safetynet: if shard already in process by another goroutine, just skipping the request
+			continue
+		}
 		wg.Add(1)
 		go func(shardID string) {
+			shardsInProcess[shardID] = struct{}{}
+			defer func() {
+				delete(shardsInProcess, shardID)
+			}()
 			defer wg.Done()
 			var err error
 			if err = c.ScanShard(ctx, shardID, fn); err != nil {
@@ -132,24 +142,20 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
 			}
 			if err != nil {
 				select {
-				case errc <- fmt.Errorf("shard %s error: %w", shardID, err):
+				case errC <- fmt.Errorf("shard %s error: %w", shardID, err):
 					cancel()
 				default:
 				}
 			}
-		}(aws.ToString(shard.ShardId))
+		}(shardId)
 	}
 
 	go func() {
 		wg.Wait()
-		close(errc)
+		close(errC)
 	}()
 
-	return <-errc
-}
-
-func (c *Consumer) scanSingleShard(ctx context.Context, shardID string, fn ScanFunc) error {
-	return nil
+	return <-errC
 }
 
 // ScanShard loops over records on a specific shard, calls the callback func
@@ -213,14 +219,12 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
 				return nil
 			default:
 				err := fn(&Record{r, shardID, resp.MillisBehindLatest})
-				if err != nil && err != ErrSkipCheckpoint {
+				if err != nil && !errors.Is(err, ErrSkipCheckpoint) {
 					return err
 				}
 
-				if err != ErrSkipCheckpoint {
-					if err := c.group.SetCheckpoint(c.streamName, shardID, *r.SequenceNumber); err != nil {
-						return err
-					}
+				if err := c.group.SetCheckpoint(c.streamName, shardID, *r.SequenceNumber); err != nil {
+					return err
 				}
 
 				c.counter.Add("records", 1)
@@ -284,7 +288,7 @@ func (c *Consumer) getShardIterator(ctx context.Context, streamName, shardID, se
 		params.ShardIteratorType = types.ShardIteratorTypeAtTimestamp
 		params.Timestamp = c.initialTimestamp
 	} else {
-		params.ShardIteratorType = types.ShardIteratorType(c.initialShardIteratorType)
+		params.ShardIteratorType = c.initialShardIteratorType
 	}
 
 	res, err := c.client.GetShardIterator(ctx, params)
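
Note for reviewers: the heart of this change is making sure a shard is only ever read by one goroutine at a time. Duplicate readers issuing GetRecords against the same shard are what push it past its read throughput and surface as ProvisionedThroughputExceededException, so allgroup.go now dispatches only newly discovered shards and Scan keeps a guard map of shards already in flight. Below is a minimal, self-contained sketch of that guard pattern, assuming a simplified broker that just sends shard ID strings; the names and fake shard IDs are illustrative only and are not taken from the library.

package main

import (
	"fmt"
	"sync"
)

func main() {
	// Stand-in for the broker: it may announce the same shard more than once,
	// e.g. on each periodic listing, so the consumer must deduplicate.
	shardC := make(chan string)
	go func() {
		for _, id := range []string{"shardId-000", "shardId-001", "shardId-000"} {
			shardC <- id
		}
		close(shardC)
	}()

	var (
		wg        sync.WaitGroup
		inProcess = make(map[string]struct{}) // shard IDs that already have a reader
	)
	for id := range shardC {
		if _, ok := inProcess[id]; ok {
			continue // already being read; a second reader would double the GetRecords load
		}
		inProcess[id] = struct{}{} // mark before launching the reader
		wg.Add(1)
		go func(id string) {
			defer wg.Done()
			fmt.Println("reading", id) // stand-in for ScanShard
		}(id)
	}
	wg.Wait()
}

In this sketch the map is updated in the dispatch loop itself, so no extra synchronization is needed; it only illustrates the deduplication idea, not the exact bookkeeping the patch performs inside the shard goroutines.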