This commit is contained in:
Mikhail 2024-09-16 01:22:29 +10:00
parent 553e2392fd
commit 6fdb1209b5
No known key found for this signature in database
GPG key ID: 6FFFEA01DBC79BFC
2 changed files with 14 additions and 21 deletions

View file

@ -9,7 +9,7 @@ import (
"github.com/aws/aws-sdk-go-v2/service/kinesis/types" "github.com/aws/aws-sdk-go-v2/service/kinesis/types"
) )
// NewAllGroup returns an intitialized AllGroup for consuming // NewAllGroup returns an initialized AllGroup for consuming
// all shards on a stream // all shards on a stream
func NewAllGroup(ksis kinesisClient, store Store, streamName string, logger Logger) *AllGroup { func NewAllGroup(ksis kinesisClient, store Store, streamName string, logger Logger) *AllGroup {
return &AllGroup{ return &AllGroup{
@ -40,10 +40,10 @@ type AllGroup struct {
// shards on a regular cadence. // shards on a regular cadence.
func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) error { func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) error {
// Note: while ticker is a rather naive approach to this problem, // Note: while ticker is a rather naive approach to this problem,
// it actually simplifies a few things. i.e. If we miss a new shard // it actually simplifies a few things. I.e. If we miss a new shard
// while AWS is resharding we'll pick it up max 30 seconds later. // while AWS is resharding, we'll pick it up max 30 seconds later.
// It might be worth refactoring this flow to allow the consumer to // It might be worth refactoring this flow to allow the consumer
// to notify the broker when a shard is closed. However, shards don't // to notify the broker when a shard is closed. However, shards don't
// necessarily close at the same time, so we could potentially get a // necessarily close at the same time, so we could potentially get a
// thundering heard of notifications from the consumer. // thundering heard of notifications from the consumer.
@ -51,8 +51,7 @@ func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) error {
var ticker = time.NewTicker(30 * time.Second) var ticker = time.NewTicker(30 * time.Second)
for { for {
err := g.findNewShards(ctx, shardc) if err := g.findNewShards(ctx, shardc); err != nil {
if err != nil {
ticker.Stop() ticker.Stop()
return err return err
} }
@ -66,7 +65,7 @@ func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) error {
} }
} }
func (g *AllGroup) CloseShard(ctx context.Context, shardID string) error { func (g *AllGroup) CloseShard(_ context.Context, shardID string) error {
g.shardMu.Lock() g.shardMu.Lock()
defer g.shardMu.Unlock() defer g.shardMu.Unlock()
c, ok := g.shardsClosed[shardID] c, ok := g.shardsClosed[shardID]

View file

@ -90,7 +90,7 @@ type Consumer struct {
type ScanFunc func(*Record) error type ScanFunc func(*Record) error
// ErrSkipCheckpoint is used as a return value from ScanFunc to indicate that // ErrSkipCheckpoint is used as a return value from ScanFunc to indicate that
// the current checkpoint should be skipped skipped. It is not returned // the current checkpoint should be skipped. It is not returned
// as an error by any function. // as an error by any function.
var ErrSkipCheckpoint = errors.New("skip checkpoint") var ErrSkipCheckpoint = errors.New("skip checkpoint")
@ -148,10 +148,6 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
return <-errc return <-errc
} }
func (c *Consumer) scanSingleShard(ctx context.Context, shardID string, fn ScanFunc) error {
return nil
}
// ScanShard loops over records on a specific shard, calls the callback func // ScanShard loops over records on a specific shard, calls the callback func
// for each record and checkpoints the progress of scan. // for each record and checkpoints the progress of scan.
func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error { func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error {
@ -213,14 +209,12 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
return nil return nil
default: default:
err := fn(&Record{r, shardID, resp.MillisBehindLatest}) err := fn(&Record{r, shardID, resp.MillisBehindLatest})
if err != nil && err != ErrSkipCheckpoint { if err != nil && !errors.Is(err, ErrSkipCheckpoint) {
return err return err
} }
if err != ErrSkipCheckpoint { if err := c.group.SetCheckpoint(c.streamName, shardID, *r.SequenceNumber); err != nil {
if err := c.group.SetCheckpoint(c.streamName, shardID, *r.SequenceNumber); err != nil { return err
return err
}
} }
c.counter.Add("records", 1) c.counter.Add("records", 1)
@ -284,7 +278,7 @@ func (c *Consumer) getShardIterator(ctx context.Context, streamName, shardID, se
params.ShardIteratorType = types.ShardIteratorTypeAtTimestamp params.ShardIteratorType = types.ShardIteratorTypeAtTimestamp
params.Timestamp = c.initialTimestamp params.Timestamp = c.initialTimestamp
} else { } else {
params.ShardIteratorType = types.ShardIteratorType(c.initialShardIteratorType) params.ShardIteratorType = c.initialShardIteratorType
} }
res, err := c.client.GetShardIterator(ctx, params) res, err := c.client.GetShardIterator(ctx, params)
@ -295,10 +289,10 @@ func (c *Consumer) getShardIterator(ctx context.Context, streamName, shardID, se
} }
func isRetriableError(err error) bool { func isRetriableError(err error) bool {
switch err.(type) { if oe := (*types.ExpiredIteratorException)(nil); errors.As(err, &oe) {
case *types.ExpiredIteratorException:
return true return true
case *types.ProvisionedThroughputExceededException: }
if oe := (*types.ProvisionedThroughputExceededException)(nil); errors.As(err, &oe) {
return true return true
} }
return false return false