Sync with Upstream (#241)
* Fix ProvisionedThroughputExceededException error (#161). Fixes #158; the bug seems to have been introduced in #155 (see #155 comment).
* Fix isRetriableError (#159). Fixes #158.
* Fix concurrent map read/write panic for the shardsInProgress map (#163).

Co-authored-by: Mikhail Konovalov <4463812+mskonovalov@users.noreply.github.com>
Co-authored-by: lrs <82623629@qq.com>
Co-authored-by: Sanket Deshpande <ssd20072@gmail.com>
Co-authored-by: Sanket Deshpande <sanket@clearblade.com>
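Context for the third fix: Go's runtime faults ("fatal error: concurrent map read and map write") when a plain map is touched from several goroutines without synchronization, which is what the shardsInProgress bookkeeping was hitting. The diff below guards the map with a sync.RWMutex; here is a minimal standalone sketch of that pattern, with illustrative names that are not the library's own:

package main

import (
    "fmt"
    "sync"
)

// guardedSet is an illustrative stand-in for the mutex-protected shard set
// introduced by #163; same pattern, not the library's actual type.
type guardedSet struct {
    mu  sync.RWMutex
    ids map[string]struct{}
}

func (g *guardedSet) add(id string) {
    g.mu.Lock()
    defer g.mu.Unlock()
    g.ids[id] = struct{}{}
}

func (g *guardedSet) has(id string) bool {
    g.mu.RLock()
    defer g.mu.RUnlock()
    _, ok := g.ids[id]
    return ok
}

func main() {
    s := &guardedSet{ids: make(map[string]struct{})}
    var wg sync.WaitGroup
    for i := 0; i < 8; i++ {
        wg.Add(1)
        go func(n int) { // without the mutex, concurrent writes here could fault the runtime
            defer wg.Done()
            s.add(fmt.Sprintf("shardId-%03d", n))
        }(i)
    }
    wg.Wait()
    fmt.Println(s.has("shardId-003")) // true
}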
Parent: f0b82db6ac
Commit: 0ea3954331
2 changed files with 142 additions and 26 deletions
allgroup.go (96 changed lines)
@@ -2,6 +2,7 @@ package consumer

 import (
     "context"
+    "fmt"
     "log/slog"
     "sync"
     "time"
@@ -13,11 +14,12 @@ import (
 // all shards on a stream
 func NewAllGroup(kinesis kinesisClient, store Store, streamName string, logger *slog.Logger) *AllGroup {
     return &AllGroup{
         kinesis:      kinesis,
         shards:       make(map[string]types.Shard),
+        shardsClosed: make(map[string]chan struct{}),
         streamName:   streamName,
         slog:         logger,
         Store:        store,
     }
 }
@@ -30,56 +32,114 @@ type AllGroup struct {
     slog *slog.Logger
     Store

     shardMu      sync.Mutex
     shards       map[string]types.Shard
+    shardsClosed map[string]chan struct{}
 }

 // Start is a blocking operation which will loop and attempt to find new
 // shards on a regular cadence.
-func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) {
+func (g *AllGroup) Start(ctx context.Context, shardC chan types.Shard) error {
     // Note: while ticker is a rather naive approach to this problem,
-    // it actually simplifies a few things. i.e. If we miss a new shard
-    // while AWS is re-sharding we'll pick it up max 30 seconds later.
+    // it actually simplifies a few things. I.e. If we miss a new shard
+    // while AWS is re-sharding, we'll pick it up max 30 seconds later.

-    // It might be worth refactoring this flow to allow the consumer to
-    // notify the broker when a shard is closed. However, shards don't
+    // It might be worth refactoring this flow to allow the consumer
+    // to notify the broker when a shard is closed. However, shards don't
     // necessarily close at the same time, so we could potentially get a
     // thundering heard of notifications from the consumer.

     var ticker = time.NewTicker(30 * time.Second)

     for {
-        g.findNewShards(ctx, shardc)
+        if err := g.findNewShards(ctx, shardC); err != nil {
+            ticker.Stop()
+            return err
+        }

         select {
         case <-ctx.Done():
             ticker.Stop()
-            return
+            return nil
         case <-ticker.C:
         }
     }
 }

+func (g *AllGroup) CloseShard(_ context.Context, shardID string) error {
+    g.shardMu.Lock()
+    defer g.shardMu.Unlock()
+    c, ok := g.shardsClosed[shardID]
+    if !ok {
+        return fmt.Errorf("closing unknown shard ID %q", shardID)
+    }
+    close(c)
+    return nil
+}
+
+func waitForCloseChannel(ctx context.Context, c <-chan struct{}) bool {
+    if c == nil {
+        // no channel means we haven't seen this shard in listShards, so it
+        // probably fell off the TRIM_HORIZON, and we can assume it's fully processed.
+        return true
+    }
+    select {
+    case <-ctx.Done():
+        return false
+    case <-c:
+        // the channel has been processed and closed by the consumer (CloseShard has been called)
+        return true
+    }
+}
+
 // findNewShards pulls the list of shards from the Kinesis API
 // and uses a local cache to determine if we are already processing
 // a particular shard.
-func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) {
+func (g *AllGroup) findNewShards(ctx context.Context, shardC chan types.Shard) error {
     g.shardMu.Lock()
     defer g.shardMu.Unlock()

-    g.slog.DebugContext(ctx, "fetch shards")
+    g.slog.DebugContext(ctx, "fetching shards")

     shards, err := listShards(ctx, g.kinesis, g.streamName)
     if err != nil {
         g.slog.ErrorContext(ctx, "list shards", slog.String("error", err.Error()))
-        return
+        return err
     }

+    // We do two `for` loops, since we have to set up all the `shardClosed`
+    // channels before we start using any of them. It's highly probable
+    // that Kinesis provides us the shards in dependency order (parents
+    // before children), but it doesn't appear to be a guarantee.
+    newShards := make(map[string]types.Shard)
     for _, shard := range shards {
         if _, ok := g.shards[*shard.ShardId]; ok {
             continue
         }
         g.shards[*shard.ShardId] = shard
-        shardc <- shard
+        g.shardsClosed[*shard.ShardId] = make(chan struct{})
+        newShards[*shard.ShardId] = shard
     }
+    // only new shards need to be checked for parent dependencies
+    for _, shard := range newShards {
+        shard := shard // Shadow shard, since we use it in goroutine
+        var parent1, parent2 <-chan struct{}
+        if shard.ParentShardId != nil {
+            parent1 = g.shardsClosed[*shard.ParentShardId]
+        }
+        if shard.AdjacentParentShardId != nil {
+            parent2 = g.shardsClosed[*shard.AdjacentParentShardId]
+        }
+        go func() {
+            // Asynchronously wait for all parents of this shard to be processed
+            // before providing it out to our client. Kinesis guarantees that a
+            // given partition key's data will be provided to clients in-order,
+            // but when splits or joins happen, we need to process all parents prior
+            // to processing children or that ordering guarantee is not maintained.
+            if waitForCloseChannel(ctx, parent1) && waitForCloseChannel(ctx, parent2) {
+                shardC <- shard
+            }
+        }()
+    }
+    return nil
 }
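The new Start/CloseShard/waitForCloseChannel plumbing above orders shard hand-off across a resharding event: a child shard is only sent on shardC once the close channels of its parent and adjacent parent have been closed. Closing a channel is the broadcast primitive here, since every receiver of a closed channel unblocks immediately. A minimal standalone sketch of that gating pattern, with illustrative names that are not the library's own:

package main

import (
    "context"
    "fmt"
    "time"
)

// waitClosed mirrors waitForCloseChannel above: a nil channel means
// "no known parent", which counts as already processed.
func waitClosed(ctx context.Context, c <-chan struct{}) bool {
    if c == nil {
        return true
    }
    select {
    case <-ctx.Done():
        return false
    case <-c:
        return true
    }
}

func main() {
    ctx := context.Background()
    parentDone := make(chan struct{}) // closed once the parent shard is fully consumed
    childReady := make(chan string)

    // Gate the child shard on its parent, as findNewShards does per new shard.
    go func() {
        if waitClosed(ctx, parentDone) && waitClosed(ctx, nil /* no adjacent parent */) {
            childReady <- "shardId-000000000002"
        }
    }()

    go func() {
        time.Sleep(50 * time.Millisecond) // simulate draining the parent shard
        close(parentDone)                 // the equivalent of CloseShard on the parent
    }()

    fmt.Println("child released:", <-childReady)
}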
consumer.go (72 changed lines)
@@ -46,6 +46,11 @@ func New(streamName string, opts ...Option) (*Consumer, error) {
         maxRecords:     10000,
         metricRegistry: nil,
         numWorkers:     1,
+        logger: &noopLogger{
+            logger: log.New(io.Discard, "", log.LstdFlags),
+        },
+        scanInterval: 250 * time.Millisecond,
+        maxRecords:   10000,
     }

     // override defaults
@@ -128,18 +133,40 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
     )

     go func() {
-        c.group.Start(ctx, shardC)
+        err := c.group.Start(ctx, shardC)
+        if err != nil {
+            errC <- fmt.Errorf("error starting scan: %w", err)
+            cancel()
+        }
         <-ctx.Done()
         close(shardC)
     }()

     wg := new(sync.WaitGroup)
     // process each of the shards
+    s := newShardsInProcess()
     for shard := range shardC {
+        shardId := aws.ToString(shard.ShardId)
+        if s.doesShardExist(shardId) {
+            // safetynet: if shard already in process by another goroutine, just skipping the request
+            continue
+        }
         wg.Add(1)
         go func(shardID string) {
+            s.addShard(shardID)
+            defer func() {
+                s.deleteShard(shardID)
+            }()
             defer wg.Done()
-            if err := c.ScanShard(ctx, shardID, fn); err != nil {
+            var err error
+            if err = c.ScanShard(ctx, shardID, fn); err != nil {
+                err = fmt.Errorf("shard %s error: %w", shardID, err)
+            } else if closeable, ok := c.group.(CloseableGroup); !ok {
+                // group doesn't allow closure, skip calling CloseShard
+            } else if err = closeable.CloseShard(ctx, shardID); err != nil {
+                err = fmt.Errorf("shard closed CloseableGroup error: %w", err)
+            }
+            if err != nil {
                 select {
                 case errC <- fmt.Errorf("shard %s error: %w", shardID, err):
                     // first error to occur
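The Scan changes above type-assert the group against CloseableGroup and call CloseShard once ScanShard returns cleanly, which is what releases the child shards gated in allgroup.go. The interface declaration itself is not part of this diff; judging only from the calls made here, it presumably looks something like this sketch, and the real declaration in the repository may differ:

package consumer

import "context"

// CloseableGroup is an assumed shape, inferred from the c.group.(CloseableGroup)
// assertion and the closeable.CloseShard(ctx, shardID) call in the hunk above.
type CloseableGroup interface {
    // CloseShard tells the group that the given shard ID has been fully
    // processed, which lets AllGroup release any child shards gated on it.
    CloseShard(ctx context.Context, shardID string) error
}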
@@ -148,7 +175,7 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
                     // error has already occurred
                 }
             }
-        }(aws.ToString(shard.ShardId))
+        }(shardId)
     }

     go func() {
@@ -353,12 +380,10 @@ func (c *Consumer) getShardIterator(ctx context.Context, streamName, shardID, se
 }

 func isRetriableError(err error) bool {
-    var expiredIteratorException *types.ExpiredIteratorException
-    var provisionedThroughputExceededException *types.ProvisionedThroughputExceededException
-    switch {
-    case errors.As(err, &expiredIteratorException):
+    if oe := (*types.ExpiredIteratorException)(nil); errors.As(err, &oe) {
         return true
-    case errors.As(err, &provisionedThroughputExceededException):
+    }
+    if oe := (*types.ProvisionedThroughputExceededException)(nil); errors.As(err, &oe) {
         return true
     }
     return false
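The isRetriableError rewrite above is what closes #158/#159: both checks now go through errors.As, so an ExpiredIteratorException or ProvisionedThroughputExceededException is still recognized when it arrives wrapped in additional context. A minimal self-contained sketch of why errors.As matters for wrapped errors, using a local stand-in type rather than the AWS SDK exception types from the diff:

package main

import (
    "errors"
    "fmt"
)

// throttleErr stands in for types.ProvisionedThroughputExceededException.
type throttleErr struct{ msg string }

func (e *throttleErr) Error() string { return e.msg }

func main() {
    cause := &throttleErr{msg: "rate exceeded"}
    wrapped := fmt.Errorf("GetRecords failed: %w", cause) // extra context added by a caller

    // A direct type assertion misses the wrapped cause...
    _, direct := wrapped.(*throttleErr)

    // ...but errors.As walks the %w chain, as isRetriableError now does.
    var target *throttleErr
    viaAs := errors.As(wrapped, &target)

    fmt.Println(direct, viaAs) // false true
}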
@@ -367,3 +392,34 @@ func isRetriableError(err error) bool {
 func isShardClosed(nextShardIterator, currentShardIterator *string) bool {
     return nextShardIterator == nil || currentShardIterator == nextShardIterator
 }
+
+type shards struct {
+    *sync.RWMutex
+    shardsInProcess map[string]struct{}
+}
+
+func newShardsInProcess() *shards {
+    return &shards{
+        RWMutex:         &sync.RWMutex{},
+        shardsInProcess: make(map[string]struct{}),
+    }
+}
+
+func (s *shards) addShard(shardId string) {
+    s.Lock()
+    defer s.Unlock()
+    s.shardsInProcess[shardId] = struct{}{}
+}
+
+func (s *shards) doesShardExist(shardId string) bool {
+    s.RLock()
+    defer s.RUnlock()
+    _, ok := s.shardsInProcess[shardId]
+    return ok
+}
+
+func (s *shards) deleteShard(shardId string) {
+    s.Lock()
+    defer s.Unlock()
+    delete(s.shardsInProcess, shardId)
+}