Fix ProvisionedThroughputExceededException error (#161)

Fixes #158. The bug appears to have been introduced in #155; see #155 (comment)
Mikhail Konovalov 2024-09-17 05:25:49 +10:00 committed by GitHub
parent 553e2392fd
commit 8d10ac8dac
2 changed files with 40 additions and 34 deletions
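
For context: each Kinesis shard supports only a limited read rate (at most five GetRecords calls per second and 2 MB/s per shard), so starting more than one scanner on the same shard quickly produces ProvisionedThroughputExceededException. The diff below addresses this in two places: findNewShards now announces only newly discovered shards, and Scan keeps a map of shards already being processed as a safety net. A minimal standalone sketch of that dedup idea (hypothetical names, not the library's API):

package main

import "fmt"

// shardTracker remembers which shards already have a scanner so that a
// periodic re-listing of the stream never starts a second GetRecords loop
// on the same shard.
type shardTracker struct {
	active map[string]struct{}
}

// claim reports whether the caller should start scanning the shard.
// It returns false if the shard is already being scanned.
func (t *shardTracker) claim(shardID string) bool {
	if _, ok := t.active[shardID]; ok {
		return false
	}
	t.active[shardID] = struct{}{}
	return true
}

func main() {
	t := &shardTracker{active: map[string]struct{}{}}
	// The same listing arrives on every tick; only the first claim succeeds.
	for _, id := range []string{"shardId-000", "shardId-001", "shardId-000"} {
		fmt.Printf("claim(%s) = %v\n", id, t.claim(id))
	}
}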


@@ -9,7 +9,7 @@ import (
 	"github.com/aws/aws-sdk-go-v2/service/kinesis/types"
 )
 
-// NewAllGroup returns an intitialized AllGroup for consuming
+// NewAllGroup returns an initialized AllGroup for consuming
 // all shards on a stream
 func NewAllGroup(ksis kinesisClient, store Store, streamName string, logger Logger) *AllGroup {
 	return &AllGroup{
@@ -38,12 +38,12 @@ type AllGroup struct {
 // Start is a blocking operation which will loop and attempt to find new
 // shards on a regular cadence.
-func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) error {
+func (g *AllGroup) Start(ctx context.Context, shardC chan types.Shard) error {
 	// Note: while ticker is a rather naive approach to this problem,
-	// it actually simplifies a few things. i.e. If we miss a new shard
-	// while AWS is resharding we'll pick it up max 30 seconds later.
-	// It might be worth refactoring this flow to allow the consumer to
+	// it actually simplifies a few things. I.e. If we miss a new shard
+	// while AWS is resharding, we'll pick it up max 30 seconds later.
+	// It might be worth refactoring this flow to allow the consumer
 	// to notify the broker when a shard is closed. However, shards don't
 	// necessarily close at the same time, so we could potentially get a
 	// thundering heard of notifications from the consumer.
@@ -51,8 +51,7 @@ func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) error {
 	var ticker = time.NewTicker(30 * time.Second)
 	for {
-		err := g.findNewShards(ctx, shardc)
-		if err != nil {
+		if err := g.findNewShards(ctx, shardC); err != nil {
 			ticker.Stop()
 			return err
 		}
 
@@ -66,7 +65,7 @@ func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) error {
 	}
 }
 
-func (g *AllGroup) CloseShard(ctx context.Context, shardID string) error {
+func (g *AllGroup) CloseShard(_ context.Context, shardID string) error {
 	g.shardMu.Lock()
 	defer g.shardMu.Unlock()
 	c, ok := g.shardsClosed[shardID]
@@ -95,7 +94,7 @@ func waitForCloseChannel(ctx context.Context, c <-chan struct{}) bool {
 // findNewShards pulls the list of shards from the Kinesis API
 // and uses a local cache to determine if we are already processing
 // a particular shard.
-func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) error {
+func (g *AllGroup) findNewShards(ctx context.Context, shardC chan types.Shard) error {
 	g.shardMu.Lock()
 	defer g.shardMu.Unlock()
@@ -111,14 +110,17 @@ func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) error {
 	// channels before we start using any of them. It's highly probable
 	// that Kinesis provides us the shards in dependency order (parents
 	// before children), but it doesn't appear to be a guarantee.
+	newShards := make(map[string]types.Shard)
 	for _, shard := range shards {
 		if _, ok := g.shards[*shard.ShardId]; ok {
 			continue
 		}
 		g.shards[*shard.ShardId] = shard
 		g.shardsClosed[*shard.ShardId] = make(chan struct{})
+		newShards[*shard.ShardId] = shard
 	}
-	for _, shard := range shards {
+	// only new shards need to be checked for parent dependencies
+	for _, shard := range newShards {
 		shard := shard // Shadow shard, since we use it in goroutine
 		var parent1, parent2 <-chan struct{}
 		if shard.ParentShardId != nil {
@@ -134,7 +136,7 @@ func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) error {
 			// but when splits or joins happen, we need to process all parents prior
 			// to processing children or that ordering guarantee is not maintained.
 			if waitForCloseChannel(ctx, parent1) && waitForCloseChannel(ctx, parent2) {
-				shardc <- shard
+				shardC <- shard
 			}
 		}()
 	}
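
The second loop above feeds only newly discovered shards into the parent-wait logic: each goroutine blocks on the close channels of the shard's parents (via waitForCloseChannel) before sending the shard downstream, so after a split or merge a child shard is never consumed ahead of its parents. A rough standalone sketch of that wait pattern, with illustrative names (the nil-channel check stands in for "no parent"):

package main

import (
	"context"
	"fmt"
	"time"
)

// waitClosed returns true once c is closed (or if c is nil, meaning no parent),
// and false if ctx ends first.
func waitClosed(ctx context.Context, c <-chan struct{}) bool {
	if c == nil {
		return true
	}
	select {
	case <-c:
		return true
	case <-ctx.Done():
		return false
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()

	parent := make(chan struct{})
	go func() {
		time.Sleep(100 * time.Millisecond)
		close(parent) // parent shard finished; child may start
	}()

	// Only emit the child shard once every parent (nil = absent) is done.
	if waitClosed(ctx, parent) && waitClosed(ctx, nil) {
		fmt.Println("child shard ready to be consumed")
	}
}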


@@ -4,7 +4,7 @@ import (
 	"context"
 	"errors"
 	"fmt"
-	"io/ioutil"
+	"io"
 	"log"
 	"sync"
 	"time"
@@ -38,7 +38,7 @@ func New(streamName string, opts ...Option) (*Consumer, error) {
 		store: &noopStore{},
 		counter: &noopCounter{},
 		logger: &noopLogger{
-			logger: log.New(ioutil.Discard, "", log.LstdFlags),
+			logger: log.New(io.Discard, "", log.LstdFlags),
 		},
 		scanInterval: 250 * time.Millisecond,
 		maxRecords: 10000,
@@ -90,7 +90,7 @@ type Consumer struct {
 type ScanFunc func(*Record) error
 
 // ErrSkipCheckpoint is used as a return value from ScanFunc to indicate that
-// the current checkpoint should be skipped skipped. It is not returned
+// the current checkpoint should be skipped. It is not returned
 // as an error by any function.
 var ErrSkipCheckpoint = errors.New("skip checkpoint")
@@ -102,25 +102,35 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
 	defer cancel()
 
 	var (
-		errc   = make(chan error, 1)
-		shardc = make(chan types.Shard, 1)
+		errC   = make(chan error, 1)
+		shardC = make(chan types.Shard, 1)
 	)
 
 	go func() {
-		err := c.group.Start(ctx, shardc)
+		err := c.group.Start(ctx, shardC)
 		if err != nil {
-			errc <- fmt.Errorf("error starting scan: %w", err)
+			errC <- fmt.Errorf("error starting scan: %w", err)
 			cancel()
 		}
 		<-ctx.Done()
-		close(shardc)
+		close(shardC)
 	}()
 
 	wg := new(sync.WaitGroup)
 	// process each of the shards
-	for shard := range shardc {
+	shardsInProcess := make(map[string]struct{})
+	for shard := range shardC {
+		shardId := aws.ToString(shard.ShardId)
+		if _, ok := shardsInProcess[shardId]; ok {
+			// safetynet: if shard already in process by another goroutine, just skipping the request
+			continue
+		}
 		wg.Add(1)
 		go func(shardID string) {
+			shardsInProcess[shardID] = struct{}{}
+			defer func() {
+				delete(shardsInProcess, shardID)
+			}()
 			defer wg.Done()
 			var err error
 			if err = c.ScanShard(ctx, shardID, fn); err != nil {
@@ -132,24 +142,20 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
 			}
 			if err != nil {
 				select {
-				case errc <- fmt.Errorf("shard %s error: %w", shardID, err):
+				case errC <- fmt.Errorf("shard %s error: %w", shardID, err):
					cancel()
 				default:
 				}
 			}
-		}(aws.ToString(shard.ShardId))
+		}(shardId)
 	}
 
 	go func() {
 		wg.Wait()
-		close(errc)
+		close(errC)
 	}()
 
-	return <-errc
-}
-
-func (c *Consumer) scanSingleShard(ctx context.Context, shardID string, fn ScanFunc) error {
-	return nil
+	return <-errC
 }
 
 // ScanShard loops over records on a specific shard, calls the callback func
@@ -213,15 +219,13 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error {
 				return nil
 			default:
 				err := fn(&Record{r, shardID, resp.MillisBehindLatest})
-				if err != nil && err != ErrSkipCheckpoint {
+				if err != nil && !errors.Is(err, ErrSkipCheckpoint) {
 					return err
 				}
 
-				if err != ErrSkipCheckpoint {
-					if err := c.group.SetCheckpoint(c.streamName, shardID, *r.SequenceNumber); err != nil {
-						return err
-					}
-				}
+				if err := c.group.SetCheckpoint(c.streamName, shardID, *r.SequenceNumber); err != nil {
+					return err
+				}
 
 				c.counter.Add("records", 1)
 				lastSeqNum = *r.SequenceNumber
@@ -284,7 +288,7 @@ func (c *Consumer) getShardIterator(ctx context.Context, streamName, shardID, se
 		params.ShardIteratorType = types.ShardIteratorTypeAtTimestamp
 		params.Timestamp = c.initialTimestamp
 	} else {
-		params.ShardIteratorType = types.ShardIteratorType(c.initialShardIteratorType)
+		params.ShardIteratorType = c.initialShardIteratorType
 	}
 
 	res, err := c.client.GetShardIterator(ctx, params)
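
One smaller change above replaces the direct comparison err != ErrSkipCheckpoint with errors.Is(err, ErrSkipCheckpoint), which also matches the sentinel when a ScanFunc wraps it with %w. A standalone illustration (not the consumer's code):

package main

import (
	"errors"
	"fmt"
)

var ErrSkipCheckpoint = errors.New("skip checkpoint")

func main() {
	// A handler that wraps the sentinel instead of returning it directly.
	wrapped := fmt.Errorf("handler: %w", ErrSkipCheckpoint)

	fmt.Println(wrapped == ErrSkipCheckpoint)          // false: the wrapper is a different error value
	fmt.Println(errors.Is(wrapped, ErrSkipCheckpoint)) // true: errors.Is walks the %w chain
}

With errors.Is, a callback that returns fmt.Errorf("...: %w", ErrSkipCheckpoint) is treated the same as one that returns the sentinel directly.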