From 6fdb1209b5bdc6278358ab4255c89e479aa7d23f Mon Sep 17 00:00:00 2001 From: Mikhail Date: Mon, 16 Sep 2024 01:22:29 +1000 Subject: [PATCH] fixes --- allgroup.go | 13 ++++++------- consumer.go | 22 ++++++++-------------- 2 files changed, 14 insertions(+), 21 deletions(-) diff --git a/allgroup.go b/allgroup.go index 1ecb7b2..3374086 100644 --- a/allgroup.go +++ b/allgroup.go @@ -9,7 +9,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/kinesis/types" ) -// NewAllGroup returns an intitialized AllGroup for consuming +// NewAllGroup returns an initialized AllGroup for consuming // all shards on a stream func NewAllGroup(ksis kinesisClient, store Store, streamName string, logger Logger) *AllGroup { return &AllGroup{ @@ -40,10 +40,10 @@ type AllGroup struct { // shards on a regular cadence. func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) error { // Note: while ticker is a rather naive approach to this problem, - // it actually simplifies a few things. i.e. If we miss a new shard - // while AWS is resharding we'll pick it up max 30 seconds later. + // it actually simplifies a few things. I.e. If we miss a new shard + // while AWS is resharding, we'll pick it up max 30 seconds later. - // It might be worth refactoring this flow to allow the consumer to + // It might be worth refactoring this flow to allow the consumer // to notify the broker when a shard is closed. However, shards don't // necessarily close at the same time, so we could potentially get a // thundering heard of notifications from the consumer. @@ -51,8 +51,7 @@ func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) error { var ticker = time.NewTicker(30 * time.Second) for { - err := g.findNewShards(ctx, shardc) - if err != nil { + if err := g.findNewShards(ctx, shardc); err != nil { ticker.Stop() return err } @@ -66,7 +65,7 @@ func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) error { } } -func (g *AllGroup) CloseShard(ctx context.Context, shardID string) error { +func (g *AllGroup) CloseShard(_ context.Context, shardID string) error { g.shardMu.Lock() defer g.shardMu.Unlock() c, ok := g.shardsClosed[shardID] diff --git a/consumer.go b/consumer.go index 80ab45c..385ff46 100644 --- a/consumer.go +++ b/consumer.go @@ -90,7 +90,7 @@ type Consumer struct { type ScanFunc func(*Record) error // ErrSkipCheckpoint is used as a return value from ScanFunc to indicate that -// the current checkpoint should be skipped skipped. It is not returned +// the current checkpoint should be skipped. It is not returned // as an error by any function. var ErrSkipCheckpoint = errors.New("skip checkpoint") @@ -148,10 +148,6 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error { return <-errc } -func (c *Consumer) scanSingleShard(ctx context.Context, shardID string, fn ScanFunc) error { - return nil -} - // ScanShard loops over records on a specific shard, calls the callback func // for each record and checkpoints the progress of scan. func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error { @@ -213,14 +209,12 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e return nil default: err := fn(&Record{r, shardID, resp.MillisBehindLatest}) - if err != nil && err != ErrSkipCheckpoint { + if err != nil && !errors.Is(err, ErrSkipCheckpoint) { return err } - if err != ErrSkipCheckpoint { - if err := c.group.SetCheckpoint(c.streamName, shardID, *r.SequenceNumber); err != nil { - return err - } + if err := c.group.SetCheckpoint(c.streamName, shardID, *r.SequenceNumber); err != nil { + return err } c.counter.Add("records", 1) @@ -284,7 +278,7 @@ func (c *Consumer) getShardIterator(ctx context.Context, streamName, shardID, se params.ShardIteratorType = types.ShardIteratorTypeAtTimestamp params.Timestamp = c.initialTimestamp } else { - params.ShardIteratorType = types.ShardIteratorType(c.initialShardIteratorType) + params.ShardIteratorType = c.initialShardIteratorType } res, err := c.client.GetShardIterator(ctx, params) @@ -295,10 +289,10 @@ func (c *Consumer) getShardIterator(ctx context.Context, streamName, shardID, se } func isRetriableError(err error) bool { - switch err.(type) { - case *types.ExpiredIteratorException: + if oe := (*types.ExpiredIteratorException)(nil); errors.As(err, &oe) { return true - case *types.ProvisionedThroughputExceededException: + } + if oe := (*types.ProvisionedThroughputExceededException)(nil); errors.As(err, &oe) { return true } return false