diff --git a/allgroup.go b/allgroup.go index d107a7a..1ecb7b2 100644 --- a/allgroup.go +++ b/allgroup.go @@ -38,7 +38,7 @@ type AllGroup struct { // Start is a blocking operation which will loop and attempt to find new // shards on a regular cadence. -func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) { +func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) error { // Note: while ticker is a rather naive approach to this problem, // it actually simplifies a few things. i.e. If we miss a new shard // while AWS is resharding we'll pick it up max 30 seconds later. @@ -51,12 +51,16 @@ func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) { var ticker = time.NewTicker(30 * time.Second) for { - g.findNewShards(ctx, shardc) + err := g.findNewShards(ctx, shardc) + if err != nil { + ticker.Stop() + return err + } select { case <-ctx.Done(): ticker.Stop() - return + return nil case <-ticker.C: } } @@ -91,7 +95,7 @@ func waitForCloseChannel(ctx context.Context, c <-chan struct{}) bool { // findNewShards pulls the list of shards from the Kinesis API // and uses a local cache to determine if we are already processing // a particular shard. -func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) { +func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) error { g.shardMu.Lock() defer g.shardMu.Unlock() @@ -100,7 +104,7 @@ func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) { shards, err := listShards(ctx, g.ksis, g.streamName) if err != nil { g.logger.Log("[GROUP] error:", err) - return + return err } // We do two `for` loops, since we have to set up all the `shardClosed` @@ -134,4 +138,5 @@ func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) { } }() } + return nil } diff --git a/consumer.go b/consumer.go index 8afb270..80ab45c 100644 --- a/consumer.go +++ b/consumer.go @@ -107,7 +107,11 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error { ) go func() { - c.group.Start(ctx, shardc) + err := c.group.Start(ctx, shardc) + if err != nil { + errc <- fmt.Errorf("error starting scan: %w", err) + cancel() + } <-ctx.Done() close(shardc) }() @@ -284,7 +288,10 @@ func (c *Consumer) getShardIterator(ctx context.Context, streamName, shardID, se } res, err := c.client.GetShardIterator(ctx, params) - return res.ShardIterator, err + if err != nil { + return nil, err + } + return res.ShardIterator, nil } func isRetriableError(err error) bool { diff --git a/consumer_test.go b/consumer_test.go index 13cfb04..d14b11e 100644 --- a/consumer_test.go +++ b/consumer_test.go @@ -2,6 +2,7 @@ package consumer import ( "context" + "errors" "fmt" "math/rand" "sync" @@ -110,6 +111,71 @@ func TestScan(t *testing.T) { } } +func TestScan_ListShardsError(t *testing.T) { + mockError := errors.New("mock list shards error") + client := &kinesisClientMock{ + listShardsMock: func(ctx context.Context, params *kinesis.ListShardsInput, optFns ...func(*kinesis.Options)) (*kinesis.ListShardsOutput, error) { + return nil, mockError + }, + } + + // use cancel func to signal shutdown + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + + var res string + var fn = func(r *Record) error { + res += string(r.Data) + cancel() // simulate cancellation while processing first record + return nil + } + + c, err := New("myStreamName", WithClient(client)) + if err != nil { + t.Fatalf("new consumer error: %v", err) + } + + err = c.Scan(ctx, fn) + if !errors.Is(err, mockError) { + t.Errorf("expected an error from listShards, but instead got %v", err) + } +} + +func TestScan_GetShardIteratorError(t *testing.T) { + mockError := errors.New("mock get shard iterator error") + client := &kinesisClientMock{ + listShardsMock: func(ctx context.Context, params *kinesis.ListShardsInput, optFns ...func(*kinesis.Options)) (*kinesis.ListShardsOutput, error) { + return &kinesis.ListShardsOutput{ + Shards: []types.Shard{ + {ShardId: aws.String("myShard")}, + }, + }, nil + }, + getShardIteratorMock: func(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) { + return nil, mockError + }, + } + + // use cancel func to signal shutdown + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + + var res string + var fn = func(r *Record) error { + res += string(r.Data) + cancel() // simulate cancellation while processing first record + return nil + } + + c, err := New("myStreamName", WithClient(client)) + if err != nil { + t.Fatalf("new consumer error: %v", err) + } + + err = c.Scan(ctx, fn) + if !errors.Is(err, mockError) { + t.Errorf("expected an error from getShardIterator, but instead got %v", err) + } +} + func TestScanShard(t *testing.T) { var client = &kinesisClientMock{ getShardIteratorMock: func(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) { diff --git a/group.go b/group.go index 5856f24..29647d9 100644 --- a/group.go +++ b/group.go @@ -8,7 +8,7 @@ import ( // Group interface used to manage which shard to process type Group interface { - Start(ctx context.Context, shardc chan types.Shard) + Start(ctx context.Context, shardc chan types.Shard) error GetCheckpoint(streamName, shardID string) (string, error) SetCheckpoint(streamName, shardID, sequenceNumber string) error }