Sync with Upstream (#241)

* Fix ProvisionedThroughputExceededException error (#161)

Fixes #158. The bug seems to have been introduced in #155; see #155 (comment).

* fix isRetriableError (#159)

Fixes issue #158 (see the sketch after this change list).

* fixed concurrent map read/write panic for the shardsInProgress map (#163)

Co-authored-by: Sanket Deshpande <sanket@clearblade.com>
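
For reference, a minimal, self-contained sketch of the retry check the two fixes above target. The aws-sdk-go-v2 import path and the main() demo are illustrative assumptions; the authoritative change is in the diff below.

package main

import (
	"errors"
	"fmt"

	"github.com/aws/aws-sdk-go-v2/service/kinesis/types"
)

// isRetriableError reports whether err (possibly wrapped) is one of the
// transient Kinesis errors worth retrying: an expired shard iterator or a
// provisioned-throughput (rate limit) exceedance.
func isRetriableError(err error) bool {
	var expiredIterator *types.ExpiredIteratorException
	var throughputExceeded *types.ProvisionedThroughputExceededException
	return errors.As(err, &expiredIterator) || errors.As(err, &throughputExceeded)
}

func main() {
	// errors.As still matches when the SDK error is wrapped with %w.
	err := fmt.Errorf("get records: %w", &types.ProvisionedThroughputExceededException{})
	fmt.Println(isRetriableError(err)) // true
}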

---------

Co-authored-by: Mikhail Konovalov <4463812+mskonovalov@users.noreply.github.com>
Co-authored-by: lrs <82623629@qq.com>
Co-authored-by: Sanket Deshpande <ssd20072@gmail.com>
Co-authored-by: Sanket Deshpande <sanket@clearblade.com>
Alex 2024-10-02 10:03:03 +02:00 committed by GitHub
parent f0b82db6ac
commit 0ea3954331
2 changed files with 142 additions and 26 deletions


@@ -2,6 +2,7 @@ package consumer
import (
"context"
"fmt"
"log/slog"
"sync"
"time"
@@ -13,11 +14,12 @@ import (
// all shards on a stream
func NewAllGroup(kinesis kinesisClient, store Store, streamName string, logger *slog.Logger) *AllGroup {
return &AllGroup{
kinesis: kinesis,
shards: make(map[string]types.Shard),
streamName: streamName,
slog: logger,
Store: store,
kinesis: kinesis,
shards: make(map[string]types.Shard),
shardsClosed: make(map[string]chan struct{}),
streamName: streamName,
slog: logger,
Store: store,
}
}
@@ -30,56 +32,114 @@ type AllGroup struct {
slog *slog.Logger
Store
shardMu sync.Mutex
shards map[string]types.Shard
shardMu sync.Mutex
shards map[string]types.Shard
shardsClosed map[string]chan struct{}
}
// Start is a blocking operation which will loop and attempt to find new
// shards on a regular cadence.
func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) {
func (g *AllGroup) Start(ctx context.Context, shardC chan types.Shard) error {
// Note: while ticker is a rather naive approach to this problem,
// it actually simplifies a few things. i.e. If we miss a new shard
// while AWS is re-sharding we'll pick it up max 30 seconds later.
// it actually simplifies a few things. I.e. If we miss a new shard
// while AWS is re-sharding, we'll pick it up max 30 seconds later.
// It might be worth refactoring this flow to allow the consumer to
// notify the broker when a shard is closed. However, shards don't
// It might be worth refactoring this flow to allow the consumer
// to notify the broker when a shard is closed. However, shards don't
// necessarily close at the same time, so we could potentially get a
// thundering herd of notifications from the consumer.
var ticker = time.NewTicker(30 * time.Second)
for {
g.findNewShards(ctx, shardc)
if err := g.findNewShards(ctx, shardC); err != nil {
ticker.Stop()
return err
}
select {
case <-ctx.Done():
ticker.Stop()
return
return nil
case <-ticker.C:
}
}
}
func (g *AllGroup) CloseShard(_ context.Context, shardID string) error {
g.shardMu.Lock()
defer g.shardMu.Unlock()
c, ok := g.shardsClosed[shardID]
if !ok {
return fmt.Errorf("closing unknown shard ID %q", shardID)
}
close(c)
return nil
}
func waitForCloseChannel(ctx context.Context, c <-chan struct{}) bool {
if c == nil {
// no channel means we haven't seen this shard in listShards, so it
// probably fell off the TRIM_HORIZON, and we can assume it's fully processed.
return true
}
select {
case <-ctx.Done():
return false
case <-c:
// the channel has been processed and closed by the consumer (CloseShard has been called)
return true
}
}
// findNewShards pulls the list of shards from the Kinesis API
// and uses a local cache to determine if we are already processing
// a particular shard.
func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) {
func (g *AllGroup) findNewShards(ctx context.Context, shardC chan types.Shard) error {
g.shardMu.Lock()
defer g.shardMu.Unlock()
g.slog.DebugContext(ctx, "fetch shards")
g.slog.DebugContext(ctx, "fetching shards")
shards, err := listShards(ctx, g.kinesis, g.streamName)
if err != nil {
g.slog.ErrorContext(ctx, "list shards", slog.String("error", err.Error()))
return
return err
}
// We do two `for` loops, since we have to set up all the `shardsClosed`
// channels before we start using any of them. It's highly probable
// that Kinesis provides us the shards in dependency order (parents
// before children), but it doesn't appear to be a guarantee.
newShards := make(map[string]types.Shard)
for _, shard := range shards {
if _, ok := g.shards[*shard.ShardId]; ok {
continue
}
g.shards[*shard.ShardId] = shard
shardc <- shard
g.shardsClosed[*shard.ShardId] = make(chan struct{})
newShards[*shard.ShardId] = shard
}
// only new shards need to be checked for parent dependencies
for _, shard := range newShards {
shard := shard // Shadow shard, since we use it in the goroutine below
var parent1, parent2 <-chan struct{}
if shard.ParentShardId != nil {
parent1 = g.shardsClosed[*shard.ParentShardId]
}
if shard.AdjacentParentShardId != nil {
parent2 = g.shardsClosed[*shard.AdjacentParentShardId]
}
go func() {
// Asynchronously wait for all parents of this shard to be processed
// before providing it out to our client. Kinesis guarantees that a
// given partition key's data will be provided to clients in-order,
// but when splits or joins happen, we need to process all parents prior
// to processing children or that ordering guarantee is not maintained.
if waitForCloseChannel(ctx, parent1) && waitForCloseChannel(ctx, parent2) {
shardC <- shard
}
}()
}
return nil
}


@@ -46,6 +46,11 @@ func New(streamName string, opts ...Option) (*Consumer, error) {
maxRecords: 10000,
metricRegistry: nil,
numWorkers: 1,
logger: &noopLogger{
logger: log.New(io.Discard, "", log.LstdFlags),
},
scanInterval: 250 * time.Millisecond,
maxRecords: 10000,
}
// override defaults
@@ -128,18 +133,40 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
)
go func() {
c.group.Start(ctx, shardC)
err := c.group.Start(ctx, shardC)
if err != nil {
errC <- fmt.Errorf("error starting scan: %w", err)
cancel()
}
<-ctx.Done()
close(shardC)
}()
wg := new(sync.WaitGroup)
// process each of the shards
s := newShardsInProcess()
for shard := range shardC {
shardId := aws.ToString(shard.ShardId)
if s.doesShardExist(shardId) {
// safety net: if the shard is already being processed by another goroutine, skip it
continue
}
wg.Add(1)
go func(shardID string) {
s.addShard(shardID)
defer func() {
s.deleteShard(shardID)
}()
defer wg.Done()
if err := c.ScanShard(ctx, shardID, fn); err != nil {
var err error
if err = c.ScanShard(ctx, shardID, fn); err != nil {
err = fmt.Errorf("shard %s error: %w", shardID, err)
} else if closeable, ok := c.group.(CloseableGroup); !ok {
// group doesn't allow closure, skip calling CloseShard
} else if err = closeable.CloseShard(ctx, shardID); err != nil {
err = fmt.Errorf("shard closed CloseableGroup error: %w", err)
}
if err != nil {
select {
case errC <- fmt.Errorf("shard %s error: %w", shardID, err):
// first error to occur
@@ -148,7 +175,7 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
// error has already occurred
}
}
}(aws.ToString(shard.ShardId))
}(shardId)
}
go func() {
@@ -353,12 +380,10 @@ func (c *Consumer) getShardIterator(ctx context.Context, streamName, shardID, se
}
func isRetriableError(err error) bool {
var expiredIteratorException *types.ExpiredIteratorException
var provisionedThroughputExceededException *types.ProvisionedThroughputExceededException
switch {
case errors.As(err, &expiredIteratorException):
if oe := (*types.ExpiredIteratorException)(nil); errors.As(err, &oe) {
return true
case errors.As(err, &provisionedThroughputExceededException):
}
if oe := (*types.ProvisionedThroughputExceededException)(nil); errors.As(err, &oe) {
return true
}
return false
@@ -367,3 +392,34 @@ func isRetriableError(err error) bool {
func isShardClosed(nextShardIterator, currentShardIterator *string) bool {
return nextShardIterator == nil || currentShardIterator == nextShardIterator
}
type shards struct {
*sync.RWMutex
shardsInProcess map[string]struct{}
}
func newShardsInProcess() *shards {
return &shards{
RWMutex: &sync.RWMutex{},
shardsInProcess: make(map[string]struct{}),
}
}
func (s *shards) addShard(shardId string) {
s.Lock()
defer s.Unlock()
s.shardsInProcess[shardId] = struct{}{}
}
func (s *shards) doesShardExist(shardId string) bool {
s.RLock()
defer s.RUnlock()
_, ok := s.shardsInProcess[shardId]
return ok
}
func (s *shards) deleteShard(shardId string) {
s.Lock()
defer s.Unlock()
delete(s.shardsInProcess, shardId)
}
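
For context, a minimal caller-side sketch of how Scan is typically driven after this change. consumer.New and Scan match the signatures in the diff above; the import path, the ScanFunc callback shape, and the Record.Data field are assumed from the upstream harlow/kinesis-consumer API rather than taken from this commit.

package main

import (
	"context"
	"fmt"
	"log"

	consumer "github.com/harlow/kinesis-consumer"
)

func main() {
	// "my-stream" is a placeholder stream name.
	c, err := consumer.New("my-stream")
	if err != nil {
		log.Fatalf("create consumer: %v", err)
	}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// With this change, failures in the shard broker's Start loop (and in
	// CloseShard on a CloseableGroup) surface through Scan's error return
	// instead of being dropped inside a goroutine.
	if err := c.Scan(ctx, func(r *consumer.Record) error {
		fmt.Println(string(r.Data))
		return nil // a non-nil error stops scanning this shard (upstream behavior)
	}); err != nil {
		log.Fatalf("scan: %v", err)
	}
}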