From 0ea39543310d6516aac207de269ad20ed450418c Mon Sep 17 00:00:00 2001
From: Alex <116081750+alexgridx@users.noreply.github.com>
Date: Wed, 2 Oct 2024 10:03:03 +0200
Subject: [PATCH] Sync with Upstream (#241)

* Fix ProvisionedThroughputExceededException error (#161)

  Fixes #158. Seems the bug was introduced in #155. See #155 (comment)

* fix isRetriableError (#159)

  fix issues-158

* fixed concurrent map rw panic for shardsInProgress map (#163)

  Co-authored-by: Sanket Deshpande

---------

Co-authored-by: Mikhail Konovalov <4463812+mskonovalov@users.noreply.github.com>
Co-authored-by: lrs <82623629@qq.com>
Co-authored-by: Sanket Deshpande
Co-authored-by: Sanket Deshpande
---
 allgroup.go | 96 +++++++++++++++++++++++++++++++++++++++++++----------
 consumer.go | 69 +++++++++++++++++++++++++++++++++++-----
 2 files changed, 139 insertions(+), 26 deletions(-)

diff --git a/allgroup.go b/allgroup.go
index 328ab0d..80f3e93 100644
--- a/allgroup.go
+++ b/allgroup.go
@@ -2,6 +2,7 @@ package consumer
 
 import (
 	"context"
+	"fmt"
 	"log/slog"
 	"sync"
 	"time"
@@ -13,11 +14,12 @@ import (
 // all shards on a stream
 func NewAllGroup(kinesis kinesisClient, store Store, streamName string, logger *slog.Logger) *AllGroup {
 	return &AllGroup{
-		kinesis:    kinesis,
-		shards:     make(map[string]types.Shard),
-		streamName: streamName,
-		slog:       logger,
-		Store:      store,
+		kinesis:      kinesis,
+		shards:       make(map[string]types.Shard),
+		shardsClosed: make(map[string]chan struct{}),
+		streamName:   streamName,
+		slog:         logger,
+		Store:        store,
 	}
 }
 
@@ -30,56 +32,114 @@ type AllGroup struct {
 	slog *slog.Logger
 	Store
 
-	shardMu sync.Mutex
-	shards  map[string]types.Shard
+	shardMu      sync.Mutex
+	shards       map[string]types.Shard
+	shardsClosed map[string]chan struct{}
 }
 
 // Start is a blocking operation which will loop and attempt to find new
 // shards on a regular cadence.
-func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) {
+func (g *AllGroup) Start(ctx context.Context, shardC chan types.Shard) error {
 	// Note: while ticker is a rather naive approach to this problem,
-	// it actually simplifies a few things. i.e. If we miss a new shard
-	// while AWS is re-sharding we'll pick it up max 30 seconds later.
+	// it actually simplifies a few things; i.e., if we miss a new shard
+	// while AWS is re-sharding, we'll pick it up at most 30 seconds later.
 
-	// It might be worth refactoring this flow to allow the consumer to
-	// notify the broker when a shard is closed. However, shards don't
+	// It might be worth refactoring this flow to allow the consumer
+	// to notify the broker when a shard is closed. However, shards don't
 	// necessarily close at the same time, so we could potentially get a
-	// thundering heard of notifications from the consumer.
+	// thundering herd of notifications from the consumer.
 	var ticker = time.NewTicker(30 * time.Second)
 
 	for {
-		g.findNewShards(ctx, shardc)
+		if err := g.findNewShards(ctx, shardC); err != nil {
+			ticker.Stop()
+			return err
+		}
 
 		select {
 		case <-ctx.Done():
 			ticker.Stop()
-			return
+			return nil
 		case <-ticker.C:
 		}
 	}
 }
 
+func (g *AllGroup) CloseShard(_ context.Context, shardID string) error {
+	g.shardMu.Lock()
+	defer g.shardMu.Unlock()
+	c, ok := g.shardsClosed[shardID]
+	if !ok {
+		return fmt.Errorf("closing unknown shard ID %q", shardID)
+	}
+	close(c)
+	return nil
+}
+
+func waitForCloseChannel(ctx context.Context, c <-chan struct{}) bool {
+	if c == nil {
+		// No channel means we haven't seen this shard in listShards, so it
+		// probably fell off the TRIM_HORIZON, and we can assume it's fully processed.
+		return true
+	}
+	select {
+	case <-ctx.Done():
+		return false
+	case <-c:
+		// The channel has been closed by the consumer (CloseShard has been
+		// called), so the shard is fully processed.
+		return true
+	}
+}
+
 // findNewShards pulls the list of shards from the Kinesis API
 // and uses a local cache to determine if we are already processing
 // a particular shard.
-func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) {
+func (g *AllGroup) findNewShards(ctx context.Context, shardC chan types.Shard) error {
 	g.shardMu.Lock()
 	defer g.shardMu.Unlock()
 
-	g.slog.DebugContext(ctx, "fetch shards")
+	g.slog.DebugContext(ctx, "fetching shards")
 
 	shards, err := listShards(ctx, g.kinesis, g.streamName)
 	if err != nil {
 		g.slog.ErrorContext(ctx, "list shards", slog.String("error", err.Error()))
-		return
+		return err
 	}
 
+	// We do two `for` loops, since we have to set up all the `shardsClosed`
+	// channels before we start using any of them. It's highly probable
+	// that Kinesis provides us the shards in dependency order (parents
+	// before children), but it doesn't appear to be a guarantee.
+	newShards := make(map[string]types.Shard)
 	for _, shard := range shards {
 		if _, ok := g.shards[*shard.ShardId]; ok {
 			continue
 		}
 		g.shards[*shard.ShardId] = shard
-		shardc <- shard
+		g.shardsClosed[*shard.ShardId] = make(chan struct{})
+		newShards[*shard.ShardId] = shard
 	}
+
+	// Only new shards need to be checked for parent dependencies.
+	for _, shard := range newShards {
+		shard := shard // shadow shard, since we use it in a goroutine
+		var parent1, parent2 <-chan struct{}
+		if shard.ParentShardId != nil {
+			parent1 = g.shardsClosed[*shard.ParentShardId]
+		}
+		if shard.AdjacentParentShardId != nil {
+			parent2 = g.shardsClosed[*shard.AdjacentParentShardId]
+		}
+		go func() {
+			// Asynchronously wait for all parents of this shard to be processed
+			// before handing it out to our client. Kinesis guarantees that a
+			// given partition key's data will be provided to clients in order,
+			// but when splits or joins happen, we need to process all parents
+			// prior to processing children, or that ordering guarantee is lost.
+			if waitForCloseChannel(ctx, parent1) && waitForCloseChannel(ctx, parent2) {
+				shardC <- shard
+			}
+		}()
+	}
+	return nil
 }
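Reviewer note: the close-channel pattern above is easiest to see in isolation. Below is a minimal, self-contained sketch (illustrative only, not part of the patch; all names in it are made up) of the same mechanism: a child shard is handed out only after every parent channel is closed, a nil channel counts as already processed, and context cancellation unblocks any waiter.

```go
package main

import (
	"context"
	"fmt"
	"time"
)

// waitForClose mirrors waitForCloseChannel above: a nil channel counts as
// "already processed"; context cancellation unblocks the waiter.
func waitForClose(ctx context.Context, c <-chan struct{}) bool {
	if c == nil {
		return true
	}
	select {
	case <-ctx.Done():
		return false
	case <-c:
		return true
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()

	parent := make(chan struct{}) // closed when the parent shard is fully processed
	out := make(chan string, 1)

	// The child is only emitted after its parent closes; the nil adjacent
	// parent is treated as already processed.
	go func() {
		if waitForClose(ctx, parent) && waitForClose(ctx, nil) {
			out <- "child-shard"
		}
	}()

	close(parent) // simulate CloseShard on the parent
	fmt.Println(<-out)
}
```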
diff --git a/consumer.go b/consumer.go
index 8c3836a..d0be2dc 100644
--- a/consumer.go
+++ b/consumer.go
@@ -46,6 +46,10 @@ func New(streamName string, opts ...Option) (*Consumer, error) {
 		maxRecords:     10000,
 		metricRegistry: nil,
 		numWorkers:     1,
+		logger: &noopLogger{
+			logger: log.New(io.Discard, "", log.LstdFlags),
+		},
+		scanInterval: 250 * time.Millisecond,
 	}
 
 	// override defaults
@@ -128,18 +132,38 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
 	)
 
 	go func() {
-		c.group.Start(ctx, shardC)
+		err := c.group.Start(ctx, shardC)
+		if err != nil {
+			errC <- fmt.Errorf("error starting scan: %w", err)
+			cancel()
+		}
 		<-ctx.Done()
 		close(shardC)
 	}()
 
 	wg := new(sync.WaitGroup)
 	// process each of the shards
+	s := newShardsInProcess()
 	for shard := range shardC {
+		shardID := aws.ToString(shard.ShardId)
+		if s.doesShardExist(shardID) {
+			// Safety net: the shard is already being processed; skip it.
+			continue
+		}
+		s.addShard(shardID)
 		wg.Add(1)
 		go func(shardID string) {
+			defer s.deleteShard(shardID)
 			defer wg.Done()
-			if err := c.ScanShard(ctx, shardID, fn); err != nil {
+			var err error
+			if err = c.ScanShard(ctx, shardID, fn); err != nil {
+				// fall through to the error handling below
+			} else if closeable, ok := c.group.(CloseableGroup); !ok {
+				// The group doesn't support closing shards; skip CloseShard.
+			} else if err = closeable.CloseShard(ctx, shardID); err != nil {
+				err = fmt.Errorf("close shard: %w", err)
+			}
+			if err != nil {
 				select {
 				case errC <- fmt.Errorf("shard %s error: %w", shardID, err):
 					// first error to occur
@@ -148,7 +172,7 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
 					// error has already occurred
 				}
 			}
-		}(aws.ToString(shard.ShardId))
+		}(shardID)
 	}
 
 	go func() {
@@ -353,12 +377,10 @@ func (c *Consumer) getShardIterator(ctx context.Context, streamName, shardID, se
 }
 
 func isRetriableError(err error) bool {
-	var expiredIteratorException *types.ExpiredIteratorException
-	var provisionedThroughputExceededException *types.ProvisionedThroughputExceededException
-	switch {
-	case errors.As(err, &expiredIteratorException):
+	if oe := (*types.ExpiredIteratorException)(nil); errors.As(err, &oe) {
 		return true
-	case errors.As(err, &provisionedThroughputExceededException):
+	}
+	if oe := (*types.ProvisionedThroughputExceededException)(nil); errors.As(err, &oe) {
 		return true
 	}
 	return false
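The rewritten errors.As pattern is easy to sanity-check. A quick test sketch (hypothetical test, not part of the patch; it assumes the aws-sdk-go-v2 kinesis types package this file already imports) that wraps the SDK exceptions the way callers typically see them:

```go
package consumer

import (
	"errors"
	"fmt"
	"testing"

	"github.com/aws/aws-sdk-go-v2/service/kinesis/types"
)

func TestIsRetriableError(t *testing.T) {
	// Wrapped the way callers usually see it: the SDK exception nested
	// inside extra context added with fmt.Errorf("...: %w", err).
	wrapped := fmt.Errorf("get records: %w", &types.ExpiredIteratorException{})
	if !isRetriableError(wrapped) {
		t.Error("expected ExpiredIteratorException to be retriable")
	}

	throttled := fmt.Errorf("get records: %w", &types.ProvisionedThroughputExceededException{})
	if !isRetriableError(throttled) {
		t.Error("expected ProvisionedThroughputExceededException to be retriable")
	}

	if isRetriableError(errors.New("boom")) {
		t.Error("expected a plain error to be non-retriable")
	}
}
```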
@@ -367,3 +389,34 @@ func isRetriableError(err error) bool {
 func isShardClosed(nextShardIterator, currentShardIterator *string) bool {
 	return nextShardIterator == nil || currentShardIterator == nextShardIterator
 }
+
+type shards struct {
+	*sync.RWMutex
+	shardsInProcess map[string]struct{}
+}
+
+func newShardsInProcess() *shards {
+	return &shards{
+		RWMutex:         &sync.RWMutex{},
+		shardsInProcess: make(map[string]struct{}),
+	}
+}
+
+func (s *shards) addShard(shardID string) {
+	s.Lock()
+	defer s.Unlock()
+	s.shardsInProcess[shardID] = struct{}{}
+}
+
+func (s *shards) doesShardExist(shardID string) bool {
+	s.RLock()
+	defer s.RUnlock()
+	_, ok := s.shardsInProcess[shardID]
+	return ok
+}
+
+func (s *shards) deleteShard(shardID string) {
+	s.Lock()
+	defer s.Unlock()
+	delete(s.shardsInProcess, shardID)
+}
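For context on #163: a plain map written from multiple goroutines panics with "concurrent map writes", which is exactly what the RWMutex in the shards helper prevents. A minimal, self-contained illustration of the guarded pattern (names are illustrative, not part of the patch):

```go
package main

import (
	"fmt"
	"sync"
)

// guardedSet mirrors the shards helper above: every map access goes
// through the mutex, so concurrent use is safe.
type guardedSet struct {
	mu  sync.RWMutex
	ids map[string]struct{}
}

func (g *guardedSet) add(id string) {
	g.mu.Lock()
	defer g.mu.Unlock()
	g.ids[id] = struct{}{}
}

func (g *guardedSet) has(id string) bool {
	g.mu.RLock()
	defer g.mu.RUnlock()
	_, ok := g.ids[id]
	return ok
}

func main() {
	s := &guardedSet{ids: make(map[string]struct{})}

	var wg sync.WaitGroup
	for i := 0; i < 100; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			// With an unguarded map, this concurrent write would eventually
			// panic; the mutex serializes access instead.
			s.add(fmt.Sprintf("shard-%04d", i))
		}(i)
	}
	wg.Wait()

	fmt.Println(s.has("shard-0042")) // true
}
```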