diff --git a/consumer.go b/consumer.go index 3f44ba1..7d155b6 100644 --- a/consumer.go +++ b/consumer.go @@ -80,6 +80,7 @@ type Consumer struct { scanInterval time.Duration maxRecords int64 isAggregated bool + shardClosedHandler ShardClosedHandler } // ScanFunc is the type of the function called for each message read @@ -215,6 +216,12 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e if isShardClosed(resp.NextShardIterator, shardIterator) { c.logger.Log("[CONSUMER] shard closed:", shardID) + if c.shardClosedHandler != nil { + err := c.shardClosedHandler(c.streamName, shardID) + if err != nil { + return fmt.Errorf("shard closed handler error: %w", err) + } + } return nil } diff --git a/consumer_test.go b/consumer_test.go index c070005..b7e5c1d 100644 --- a/consumer_test.go +++ b/consumer_test.go @@ -280,6 +280,43 @@ func TestScanShard_ShardIsClosed(t *testing.T) { } } +func TestScanShard_ShardIsClosed_WithShardClosedHandler(t *testing.T) { + var client = &kinesisClientMock{ + getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) { + return &kinesis.GetShardIteratorOutput{ + ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"), + }, nil + }, + getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) { + return &kinesis.GetRecordsOutput{ + NextShardIterator: nil, + Records: make([]*kinesis.Record, 0), + }, nil + }, + } + + var fn = func(r *Record) error { + return nil + } + + c, err := New("myStreamName", + WithClient(client), + WithShardClosedHandler(func(streamName, shardID string) error { + return fmt.Errorf("closed shard error") + })) + if err != nil { + t.Fatalf("new consumer error: %v", err) + } + + err = c.ScanShard(context.Background(), "myShard", fn) + if err == nil { + t.Fatal("expected an error but didn't get one") + } + if err.Error() != "shard closed handler error: closed shard error" { + t.Fatalf("unexpected error: %s", err.Error()) + } +} + func TestScanShard_GetRecordsError(t *testing.T) { var client = &kinesisClientMock{ getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) { diff --git a/options.go b/options.go index 093547b..c1080bf 100644 --- a/options.go +++ b/options.go @@ -79,3 +79,14 @@ func WithAggregation(a bool) Option { c.isAggregated = a } } + +// ShardClosedHandler is a handler that will be called when the consumer has reached the end of a closed shard. +// No more records for that shard will be provided by the consumer. +// An error can be returned to stop the consumer. +type ShardClosedHandler = func(streamName, shardID string) error + +func WithShardClosedHandler(h ShardClosedHandler) Option { + return func(c *Consumer) { + c.shardClosedHandler = h + } +}