bug: fix creating multiple scans for the same shard

Mikhail 2024-09-16 15:23:47 +10:00
parent 6fdb1209b5
commit eaadad72e5
2 changed files with 30 additions and 17 deletions
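In short: the shard group polls the Kinesis shard list on a ticker, and before this change every poll pushed every listed shard onto the shard channel again, so Scan would launch another goroutine for shards it was already reading. The two diffs below fix that in both places: findNewShards now only forwards shards it has not seen before, and Scan keeps a set of shard IDs that already have an active scan as a safety net (the io/ioutil -> io change is an unrelated cleanup of the deprecated package). A tiny, self-contained illustration of the original failure mode, with made-up shard IDs and no Kinesis calls:

package main

import "fmt"

func main() {
	// Each element stands in for the shard list returned by one poll of the API.
	polls := [][]string{
		{"shard-0001"},
		{"shard-0001", "shard-0002"}, // the second poll lists shard-0001 again
	}

	shardC := make(chan string, 8)

	// Pre-fix behaviour: every poll forwards every listed shard.
	for _, shards := range polls {
		for _, id := range shards {
			shardC <- id
		}
	}
	close(shardC)

	// The consumer starts one scan per message it receives, so shard-0001
	// ends up being scanned twice.
	scansStarted := map[string]int{}
	for id := range shardC {
		scansStarted[id]++
	}
	fmt.Println(scansStarted) // map[shard-0001:2 shard-0002:1]
}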


@@ -38,7 +38,7 @@ type AllGroup struct {
 
 // Start is a blocking operation which will loop and attempt to find new
 // shards on a regular cadence.
-func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) error {
+func (g *AllGroup) Start(ctx context.Context, shardC chan types.Shard) error {
 	// Note: while ticker is a rather naive approach to this problem,
 	// it actually simplifies a few things. I.e. If we miss a new shard
 	// while AWS is resharding, we'll pick it up max 30 seconds later.
@@ -51,7 +51,7 @@ func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) error {
 	var ticker = time.NewTicker(30 * time.Second)
 
 	for {
-		if err := g.findNewShards(ctx, shardc); err != nil {
+		if err := g.findNewShards(ctx, shardC); err != nil {
 			ticker.Stop()
 			return err
 		}
@@ -94,7 +94,7 @@ func waitForCloseChannel(ctx context.Context, c <-chan struct{}) bool {
 // findNewShards pulls the list of shards from the Kinesis API
 // and uses a local cache to determine if we are already processing
 // a particular shard.
-func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) error {
+func (g *AllGroup) findNewShards(ctx context.Context, shardC chan types.Shard) error {
 	g.shardMu.Lock()
 	defer g.shardMu.Unlock()
 
@@ -110,14 +110,17 @@ func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) error {
 	// channels before we start using any of them. It's highly probable
 	// that Kinesis provides us the shards in dependency order (parents
 	// before children), but it doesn't appear to be a guarantee.
+	newShards := make(map[string]types.Shard)
 	for _, shard := range shards {
 		if _, ok := g.shards[*shard.ShardId]; ok {
 			continue
 		}
 		g.shards[*shard.ShardId] = shard
 		g.shardsClosed[*shard.ShardId] = make(chan struct{})
+		newShards[*shard.ShardId] = shard
 	}
-	for _, shard := range shards {
+	// only new shards need to be checked for parent dependencies
+	for _, shard := range newShards {
 		shard := shard // Shadow shard, since we use it in goroutine
 		var parent1, parent2 <-chan struct{}
 		if shard.ParentShardId != nil {
@@ -133,7 +136,7 @@ func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) error {
 			// but when splits or joins happen, we need to process all parents prior
 			// to processing children or that ordering guarantee is not maintained.
 			if waitForCloseChannel(ctx, parent1) && waitForCloseChannel(ctx, parent2) {
-				shardc <- shard
+				shardC <- shard
 			}
 		}()
 	}

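The comments in the hunks above describe the ordering contract that the new newShards map has to preserve: every shard's close channel is registered before anything is emitted, and a child shard is only sent downstream once its parents' channels have been closed. A condensed sketch of that gating pattern follows; waitDone and closedByID are illustrative names, not the library's API, and the real code deals in types.Shard values rather than strings.

package main

import (
	"context"
	"fmt"
)

// waitDone blocks until ch is closed or ctx is cancelled; a nil channel
// means "no parent to wait for", so it returns true immediately.
func waitDone(ctx context.Context, ch <-chan struct{}) bool {
	if ch == nil {
		return true
	}
	select {
	case <-ctx.Done():
		return false
	case <-ch:
		return true
	}
}

func main() {
	ctx := context.Background()

	// Register a close channel for every shard before emitting any of them,
	// so a child can always find its parent's channel.
	closedByID := map[string]chan struct{}{
		"parent-shard": make(chan struct{}),
		"child-shard":  make(chan struct{}),
	}

	out := make(chan string)

	// The child is only emitted once the parent's channel has been closed.
	go func() {
		if waitDone(ctx, closedByID["parent-shard"]) {
			out <- "child-shard"
		}
	}()

	// The parent has no dependency and is emitted right away.
	go func() {
		if waitDone(ctx, nil) {
			out <- "parent-shard"
		}
	}()

	// Receiver: finish the parent first, close its channel, then the child arrives.
	first := <-out
	fmt.Println("processing", first) // always "parent-shard"
	close(closedByID[first])
	fmt.Println("processing", <-out) // "child-shard"
}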

@@ -4,7 +4,7 @@ import (
 	"context"
 	"errors"
 	"fmt"
-	"io/ioutil"
+	"io"
 	"log"
 	"sync"
 	"time"
@@ -38,7 +38,7 @@ func New(streamName string, opts ...Option) (*Consumer, error) {
 		store:   &noopStore{},
 		counter: &noopCounter{},
 		logger: &noopLogger{
-			logger: log.New(ioutil.Discard, "", log.LstdFlags),
+			logger: log.New(io.Discard, "", log.LstdFlags),
 		},
 		scanInterval: 250 * time.Millisecond,
 		maxRecords:   10000,
@@ -102,25 +102,35 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
 	defer cancel()
 
 	var (
-		errc   = make(chan error, 1)
-		shardc = make(chan types.Shard, 1)
+		errC   = make(chan error, 1)
+		shardC = make(chan types.Shard, 1)
 	)
 
 	go func() {
-		err := c.group.Start(ctx, shardc)
+		err := c.group.Start(ctx, shardC)
 		if err != nil {
-			errc <- fmt.Errorf("error starting scan: %w", err)
+			errC <- fmt.Errorf("error starting scan: %w", err)
 			cancel()
 		}
 		<-ctx.Done()
-		close(shardc)
+		close(shardC)
 	}()
 
 	wg := new(sync.WaitGroup)
 	// process each of the shards
-	for shard := range shardc {
+	shardsInProcess := make(map[string]struct{})
+	for shard := range shardC {
+		shardId := aws.ToString(shard.ShardId)
+		if _, ok := shardsInProcess[shardId]; ok {
+			// safety net: if the shard is already being processed by another goroutine, skip the request
+			continue
+		}
 		wg.Add(1)
 		go func(shardID string) {
+			shardsInProcess[shardID] = struct{}{}
+			defer func() {
+				delete(shardsInProcess, shardID)
+			}()
 			defer wg.Done()
 			var err error
 			if err = c.ScanShard(ctx, shardID, fn); err != nil {
@@ -132,20 +142,20 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
 			}
 			if err != nil {
 				select {
-				case errc <- fmt.Errorf("shard %s error: %w", shardID, err):
+				case errC <- fmt.Errorf("shard %s error: %w", shardID, err):
					cancel()
 				default:
 				}
 			}
-		}(aws.ToString(shard.ShardId))
+		}(shardId)
 	}
 
 	go func() {
 		wg.Wait()
-		close(errc)
+		close(errC)
 	}()
 
-	return <-errc
+	return <-errC
 }
 
 // ScanShard loops over records on a specific shard, calls the callback func
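Taken together, the group now only announces shards it has not seen before, and Scan skips any shard ID that already has a scan goroutine running. A minimal standalone sketch of that guard under assumed names (inProgress, scan) rather than the library's; the map is protected by a mutex here so the sketch is race-free on its own.

package main

import (
	"fmt"
	"sync"
	"time"
)

func main() {
	shardC := make(chan string)

	// Simulate the shard channel announcing the same shard twice, as the
	// ticker-driven rediscovery can.
	go func() {
		for _, id := range []string{"shard-0001", "shard-0001", "shard-0002"} {
			shardC <- id
		}
		close(shardC)
	}()

	var (
		mu         sync.Mutex
		inProgress = map[string]struct{}{}
		wg         sync.WaitGroup
	)

	scan := func(shardID string) {
		defer wg.Done()
		defer func() {
			mu.Lock()
			delete(inProgress, shardID) // allow a future re-scan once this one ends
			mu.Unlock()
		}()
		fmt.Println("scanning", shardID) // stand-in for ScanShard
		time.Sleep(10 * time.Millisecond)
	}

	for id := range shardC {
		mu.Lock()
		if _, ok := inProgress[id]; ok {
			mu.Unlock()
			continue // a scan for this shard is already running
		}
		inProgress[id] = struct{}{}
		mu.Unlock()

		wg.Add(1)
		go scan(id)
	}
	wg.Wait()
}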