From 2f0c13ed7264d0bd56daa19d5431ede6083bc93e Mon Sep 17 00:00:00 2001 From: Farhan Date: Sun, 3 Feb 2019 13:27:58 +0530 Subject: [PATCH] Fix getShardIds --- consumer.go | 40 ++++++++++++++++++++++++++++------------ consumer_test.go | 2 +- 2 files changed, 29 insertions(+), 13 deletions(-) diff --git a/consumer.go b/consumer.go index d33a4ec..ec3ab7d 100644 --- a/consumer.go +++ b/consumer.go @@ -241,20 +241,36 @@ func (c *Consumer) handleRecord(shardID string, r *Record, fn func(*Record) Scan } func (c *Consumer) getShardIDs(streamName string) ([]string, error) { - resp, err := c.client.DescribeStream( - &kinesis.DescribeStreamInput{ - StreamName: aws.String(streamName), - }, - ) - if err != nil { - return nil, fmt.Errorf("describe stream error: %v", err) - } - var ss []string - for _, shard := range resp.StreamDescription.Shards { - ss = append(ss, *shard.ShardId) + var exclusiveStartShardId *string + for { + resp, err := c.client.DescribeStream( + &kinesis.DescribeStreamInput{ + StreamName: aws.String(streamName), + ExclusiveStartShardId: exclusiveStartShardId, + }, + ) + if err != nil { + return nil, fmt.Errorf("describe stream error: %v", err) + } + + streamDescription := resp.StreamDescription + shards := streamDescription.Shards + + if len(shards) == 0 { + return ss, nil + } + + for _, shard := range shards { + ss = append(ss, *shard.ShardId) + } + + exclusiveStartShardId = shards[len(shards)-1].ShardId + + if *streamDescription.HasMoreShards == false { + return ss, nil + } } - return ss, nil } func (c *Consumer) getShardIterator(streamName, shardID, lastSeqNum string) (*string, error) { diff --git a/consumer_test.go b/consumer_test.go index 36d6703..0744e09 100644 --- a/consumer_test.go +++ b/consumer_test.go @@ -46,7 +46,7 @@ func TestConsumer_Scan(t *testing.T) { Shards: []*kinesis.Shard{ {ShardId: aws.String("myShard")}, }, - }, + HasMoreShards: aws.Bool(false)}, }, nil }, }