From eaadad72e5442cfaf668d021f2ce69fdf5e0c9f0 Mon Sep 17 00:00:00 2001
From: Mikhail
Date: Mon, 16 Sep 2024 15:23:47 +1000
Subject: [PATCH] bug: fix creating multiple scans for the same shard

---
 allgroup.go | 13 ++++++++-----
 consumer.go | 34 ++++++++++++++++++++++------------
 2 files changed, 30 insertions(+), 17 deletions(-)

diff --git a/allgroup.go b/allgroup.go
index 3374086..0823536 100644
--- a/allgroup.go
+++ b/allgroup.go
@@ -38,7 +38,7 @@ type AllGroup struct {
 
 // Start is a blocking operation which will loop and attempt to find new
 // shards on a regular cadence.
-func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) error {
+func (g *AllGroup) Start(ctx context.Context, shardC chan types.Shard) error {
     // Note: while ticker is a rather naive approach to this problem,
     // it actually simplifies a few things. I.e. If we miss a new shard
     // while AWS is resharding, we'll pick it up max 30 seconds later.
@@ -51,7 +51,7 @@ func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) error {
     var ticker = time.NewTicker(30 * time.Second)
 
     for {
-        if err := g.findNewShards(ctx, shardc); err != nil {
+        if err := g.findNewShards(ctx, shardC); err != nil {
             ticker.Stop()
             return err
         }
@@ -94,7 +94,7 @@ func waitForCloseChannel(ctx context.Context, c <-chan struct{}) bool {
 // findNewShards pulls the list of shards from the Kinesis API
 // and uses a local cache to determine if we are already processing
 // a particular shard.
-func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) error {
+func (g *AllGroup) findNewShards(ctx context.Context, shardC chan types.Shard) error {
     g.shardMu.Lock()
     defer g.shardMu.Unlock()
 
@@ -110,14 +110,17 @@ func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) e
     // channels before we start using any of them. It's highly probable
     // that Kinesis provides us the shards in dependency order (parents
     // before children), but it doesn't appear to be a guarantee.
+    newShards := make(map[string]types.Shard)
     for _, shard := range shards {
         if _, ok := g.shards[*shard.ShardId]; ok {
             continue
         }
         g.shards[*shard.ShardId] = shard
         g.shardsClosed[*shard.ShardId] = make(chan struct{})
+        newShards[*shard.ShardId] = shard
     }
-    for _, shard := range shards {
+    // only new shards need to be checked for parent dependencies
+    for _, shard := range newShards {
         shard := shard // Shadow shard, since we use it in goroutine
         var parent1, parent2 <-chan struct{}
         if shard.ParentShardId != nil {
@@ -133,7 +136,7 @@ func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) e
             // but when splits or joins happen, we need to process all parents prior
             // to processing children or that ordering guarantee is not maintained.
             if waitForCloseChannel(ctx, parent1) && waitForCloseChannel(ctx, parent2) {
-                shardc <- shard
+                shardC <- shard
             }
         }()
     }
diff --git a/consumer.go b/consumer.go
index 385ff46..e28fe67 100644
--- a/consumer.go
+++ b/consumer.go
@@ -4,7 +4,7 @@ import (
     "context"
     "errors"
     "fmt"
-    "io/ioutil"
+    "io"
     "log"
     "sync"
     "time"
@@ -38,7 +38,7 @@ func New(streamName string, opts ...Option) (*Consumer, error) {
         store:   &noopStore{},
         counter: &noopCounter{},
         logger: &noopLogger{
-            logger: log.New(ioutil.Discard, "", log.LstdFlags),
+            logger: log.New(io.Discard, "", log.LstdFlags),
         },
         scanInterval: 250 * time.Millisecond,
         maxRecords:   10000,
@@ -102,25 +102,35 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
     defer cancel()
 
     var (
-        errc   = make(chan error, 1)
-        shardc = make(chan types.Shard, 1)
+        errC   = make(chan error, 1)
+        shardC = make(chan types.Shard, 1)
     )
 
     go func() {
-        err := c.group.Start(ctx, shardc)
+        err := c.group.Start(ctx, shardC)
         if err != nil {
-            errc <- fmt.Errorf("error starting scan: %w", err)
+            errC <- fmt.Errorf("error starting scan: %w", err)
             cancel()
         }
         <-ctx.Done()
-        close(shardc)
+        close(shardC)
     }()
 
     wg := new(sync.WaitGroup)
     // process each of the shards
-    for shard := range shardc {
+    shardsInProcess := make(map[string]struct{})
+    for shard := range shardC {
+        shardId := aws.ToString(shard.ShardId)
+        if _, ok := shardsInProcess[shardId]; ok {
+            // safety net: this shard is already being processed by another goroutine, so skip it
+            continue
+        }
         wg.Add(1)
         go func(shardID string) {
+            shardsInProcess[shardID] = struct{}{}
+            defer func() {
+                delete(shardsInProcess, shardID)
+            }()
             defer wg.Done()
             var err error
             if err = c.ScanShard(ctx, shardID, fn); err != nil {
@@ -132,20 +142,20 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
             }
             if err != nil {
                 select {
-                case errc <- fmt.Errorf("shard %s error: %w", shardID, err):
+                case errC <- fmt.Errorf("shard %s error: %w", shardID, err):
                     cancel()
                 default:
                 }
             }
-        }(aws.ToString(shard.ShardId))
+        }(shardId)
     }
 
     go func() {
         wg.Wait()
-        close(errc)
+        close(errC)
     }()
 
-    return <-errc
+    return <-errC
 }
 
 // ScanShard loops over records on a specific shard, calls the callback func
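
For context, the allgroup.go hunks boil down to a dedup-on-poll pattern: every ticker run lists all shards, but only shards not already tracked are handed to the goroutines that feed the shard channel, so a shard can no longer be emitted once per poll. Below is a minimal standalone sketch of that pattern; pollOnce, the string shard IDs, and the seen map are illustrative stand-ins for ListShards, g.shards, and the new newShards bookkeeping, not part of the library.

package main

import (
    "context"
    "fmt"
    "time"
)

// pollOnce emits only shards that have not been seen on a previous poll.
// The seen map plays the role of g.shards; shardC plays the role of the
// shard channel fed by findNewShards.
func pollOnce(ctx context.Context, shardIDs []string, seen map[string]struct{}, shardC chan<- string) {
    for _, id := range shardIDs {
        if _, ok := seen[id]; ok {
            continue // already tracked by an earlier poll; do not emit a second scan
        }
        seen[id] = struct{}{}
        select {
        case shardC <- id:
        case <-ctx.Done():
            return
        }
    }
}

func main() {
    ctx, cancel := context.WithTimeout(context.Background(), time.Second)
    defer cancel()

    seen := make(map[string]struct{})
    shardC := make(chan string, 8)

    // Two consecutive polls, as the 30-second ticker would produce: the
    // second poll repeats shard-1 and shard-2 but must not re-emit them.
    pollOnce(ctx, []string{"shard-1", "shard-2"}, seen, shardC)
    pollOnce(ctx, []string{"shard-1", "shard-2", "shard-3"}, seen, shardC)
    close(shardC)

    for id := range shardC {
        fmt.Println("scan", id) // each shard appears exactly once
    }
}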
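
The consumer.go hunk adds a second guard inside Scan itself: a map of shard IDs that already have a scan goroutine, consulted before wg.Add. The sketch below shows the same in-flight guard as a standalone program, with the map insert done before the worker starts and a mutex around the bookkeeping so the sketch stays race-free; the inFlight type and its tryAcquire/release methods are hypothetical names, not the library's API.

package main

import (
    "fmt"
    "sync"
)

// inFlight tracks shard IDs that currently have a scan goroutine running.
// It mirrors the shardsInProcess map added to Scan, with the bookkeeping
// kept behind a mutex so reads and deletes from different goroutines are safe.
type inFlight struct {
    mu  sync.Mutex
    ids map[string]struct{}
}

// tryAcquire marks id as in flight and reports false if it already was.
func (f *inFlight) tryAcquire(id string) bool {
    f.mu.Lock()
    defer f.mu.Unlock()
    if _, ok := f.ids[id]; ok {
        return false
    }
    f.ids[id] = struct{}{}
    return true
}

// release clears id once its scan goroutine has finished.
func (f *inFlight) release(id string) {
    f.mu.Lock()
    defer f.mu.Unlock()
    delete(f.ids, id)
}

func main() {
    guard := &inFlight{ids: make(map[string]struct{})}

    shardC := make(chan string, 4)
    for _, id := range []string{"shard-1", "shard-1", "shard-2"} {
        shardC <- id // deliver shard-1 twice to exercise the guard
    }
    close(shardC)

    var wg sync.WaitGroup
    for id := range shardC {
        if !guard.tryAcquire(id) {
            continue // a scan for this shard is already running; skip the duplicate
        }
        wg.Add(1)
        go func(shardID string) {
            defer wg.Done()
            defer guard.release(shardID)
            fmt.Println("scanning", shardID) // stand-in for c.ScanShard
        }(id)
    }
    wg.Wait()
}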