bug: fix creating multiple scans for the same shard
parent 6fdb1209b5
commit eaadad72e5
2 changed files with 30 additions and 17 deletions
allgroup.go (13 changes)

@@ -38,7 +38,7 @@ type AllGroup struct {
 
 // Start is a blocking operation which will loop and attempt to find new
 // shards on a regular cadence.
-func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) error {
+func (g *AllGroup) Start(ctx context.Context, shardC chan types.Shard) error {
 	// Note: while ticker is a rather naive approach to this problem,
 	// it actually simplifies a few things. I.e. If we miss a new shard
 	// while AWS is resharding, we'll pick it up max 30 seconds later.
@@ -51,7 +51,7 @@ func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) error {
 	var ticker = time.NewTicker(30 * time.Second)
 
 	for {
-		if err := g.findNewShards(ctx, shardc); err != nil {
+		if err := g.findNewShards(ctx, shardC); err != nil {
 			ticker.Stop()
 			return err
 		}
@@ -94,7 +94,7 @@ func waitForCloseChannel(ctx context.Context, c <-chan struct{}) bool {
 // findNewShards pulls the list of shards from the Kinesis API
 // and uses a local cache to determine if we are already processing
 // a particular shard.
-func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) error {
+func (g *AllGroup) findNewShards(ctx context.Context, shardC chan types.Shard) error {
 	g.shardMu.Lock()
 	defer g.shardMu.Unlock()
 
@@ -110,14 +110,17 @@ func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) error {
 	// channels before we start using any of them. It's highly probable
 	// that Kinesis provides us the shards in dependency order (parents
 	// before children), but it doesn't appear to be a guarantee.
+	newShards := make(map[string]types.Shard)
 	for _, shard := range shards {
 		if _, ok := g.shards[*shard.ShardId]; ok {
 			continue
 		}
 		g.shards[*shard.ShardId] = shard
 		g.shardsClosed[*shard.ShardId] = make(chan struct{})
+		newShards[*shard.ShardId] = shard
 	}
-	for _, shard := range shards {
+	// only new shards need to be checked for parent dependencies
+	for _, shard := range newShards {
 		shard := shard // Shadow shard, since we use it in goroutine
 		var parent1, parent2 <-chan struct{}
 		if shard.ParentShardId != nil {
@@ -133,7 +136,7 @@ func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) error {
 			// but when splits or joins happen, we need to process all parents prior
 			// to processing children or that ordering guarantee is not maintained.
 			if waitForCloseChannel(ctx, parent1) && waitForCloseChannel(ctx, parent2) {
-				shardc <- shard
+				shardC <- shard
 			}
 		}()
 	}
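For context, a minimal standalone sketch of the deduplication pattern findNewShards now relies on: a local set of already-known shard IDs persists across listings, and only IDs not yet in the set are handed out for processing. The dispatcher type and names below are illustrative, not part of the library.

package main

import "fmt"

// dispatcher remembers which shard IDs were already handed out, so a
// periodic re-listing of the stream never dispatches the same shard twice.
type dispatcher struct {
	seen map[string]struct{}
}

// newOnly records any unseen IDs and returns just those, mirroring how
// findNewShards now builds a newShards map before launching goroutines.
func (d *dispatcher) newOnly(listed []string) []string {
	var fresh []string
	for _, id := range listed {
		if _, ok := d.seen[id]; ok {
			continue // already being processed
		}
		d.seen[id] = struct{}{}
		fresh = append(fresh, id)
	}
	return fresh
}

func main() {
	d := &dispatcher{seen: make(map[string]struct{})}
	fmt.Println(d.newOnly([]string{"shardId-0", "shardId-1"}))              // [shardId-0 shardId-1]
	fmt.Println(d.newOnly([]string{"shardId-0", "shardId-1", "shardId-2"})) // [shardId-2]
}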
consumer.go (34 changes)

@@ -4,7 +4,7 @@ import (
 	"context"
 	"errors"
 	"fmt"
-	"io/ioutil"
+	"io"
 	"log"
 	"sync"
 	"time"
@@ -38,7 +38,7 @@ func New(streamName string, opts ...Option) (*Consumer, error) {
 		store:   &noopStore{},
 		counter: &noopCounter{},
 		logger: &noopLogger{
-			logger: log.New(ioutil.Discard, "", log.LstdFlags),
+			logger: log.New(io.Discard, "", log.LstdFlags),
 		},
 		scanInterval: 250 * time.Millisecond,
 		maxRecords:   10000,
@@ -102,25 +102,35 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
 	defer cancel()
 
 	var (
-		errc   = make(chan error, 1)
-		shardc = make(chan types.Shard, 1)
+		errC   = make(chan error, 1)
+		shardC = make(chan types.Shard, 1)
 	)
 
 	go func() {
-		err := c.group.Start(ctx, shardc)
+		err := c.group.Start(ctx, shardC)
 		if err != nil {
-			errc <- fmt.Errorf("error starting scan: %w", err)
+			errC <- fmt.Errorf("error starting scan: %w", err)
 			cancel()
 		}
 		<-ctx.Done()
-		close(shardc)
+		close(shardC)
 	}()
 
 	wg := new(sync.WaitGroup)
 	// process each of the shards
-	for shard := range shardc {
+	shardsInProcess := make(map[string]struct{})
+	for shard := range shardC {
+		shardId := aws.ToString(shard.ShardId)
+		if _, ok := shardsInProcess[shardId]; ok {
+			// safetynet: if shard already in process by another goroutine, just skipping the request
+			continue
+		}
 		wg.Add(1)
 		go func(shardID string) {
+			shardsInProcess[shardID] = struct{}{}
+			defer func() {
+				delete(shardsInProcess, shardID)
+			}()
 			defer wg.Done()
 			var err error
 			if err = c.ScanShard(ctx, shardID, fn); err != nil {
@@ -132,20 +142,20 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
 			}
 			if err != nil {
 				select {
-				case errc <- fmt.Errorf("shard %s error: %w", shardID, err):
+				case errC <- fmt.Errorf("shard %s error: %w", shardID, err):
 					cancel()
 				default:
 				}
 			}
-		}(aws.ToString(shard.ShardId))
+		}(shardId)
 	}
 
 	go func() {
 		wg.Wait()
-		close(errc)
+		close(errC)
 	}()
 
-	return <-errc
+	return <-errC
 }
 
 // ScanShard loops over records on a specific shard, calls the callback func
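A similar guard now protects the per-shard loop in Scan. The sketch below shows the same idea in isolation: shard IDs coming off a channel are skipped while a scan for that ID is still running, and the entry is cleared when the worker exits. The channel contents and the mutex are assumptions made for this standalone example, not code from the library.

package main

import (
	"fmt"
	"sync"
)

func main() {
	// The producer may emit the same shard ID more than once, e.g. when
	// the shard list is re-read on a timer.
	shardC := make(chan string)
	go func() {
		for _, id := range []string{"shardId-0", "shardId-1", "shardId-0"} {
			shardC <- id
		}
		close(shardC)
	}()

	var (
		mu        sync.Mutex
		inProcess = make(map[string]struct{})
		wg        sync.WaitGroup
	)
	for id := range shardC {
		mu.Lock()
		_, busy := inProcess[id]
		if !busy {
			inProcess[id] = struct{}{} // reserve before the worker starts
		}
		mu.Unlock()
		if busy {
			continue // safety net: a scan for this shard is already running
		}
		wg.Add(1)
		go func(id string) {
			defer wg.Done()
			defer func() {
				mu.Lock()
				delete(inProcess, id) // allow a later rescan once this one ends
				mu.Unlock()
			}()
			fmt.Println("scanning", id) // stand-in for the real per-shard scan
		}(id)
	}
	wg.Wait()
}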