diff --git a/consumer.go b/consumer.go index f16dcf4..cedae08 100644 --- a/consumer.go +++ b/consumer.go @@ -242,34 +242,26 @@ func (c *Consumer) handleRecord(shardID string, r *Record, fn func(*Record) Scan func (c *Consumer) getShardIDs(streamName string) ([]string, error) { var ss []string - var exclusiveStartShardId *string + var listShardsInput = &kinesis.ListShardsInput{ + StreamName: aws.String(streamName), + } for { - resp, err := c.client.DescribeStream( - &kinesis.DescribeStreamInput{ - StreamName: aws.String(streamName), - ExclusiveStartShardId: exclusiveStartShardId, - }, - ) + resp, err := c.client.ListShards(listShardsInput) if err != nil { - return nil, fmt.Errorf("describe stream error: %v", err) + return nil, fmt.Errorf("ListShards error: %v", err) } - streamDescription := resp.StreamDescription - shards := streamDescription.Shards - - if len(shards) == 0 { - return ss, nil - } - - for _, shard := range shards { + for _, shard := range resp.Shards { ss = append(ss, *shard.ShardId) } - exclusiveStartShardId = shards[len(shards)-1].ShardId - - if *streamDescription.HasMoreShards == false { + if resp.NextToken == nil { return ss, nil } + + listShardsInput = &kinesis.ListShardsInput{ + NextToken: resp.NextToken, + } } } diff --git a/consumer_test.go b/consumer_test.go index 0744e09..f669381 100644 --- a/consumer_test.go +++ b/consumer_test.go @@ -40,13 +40,11 @@ func TestConsumer_Scan(t *testing.T) { Records: records, }, nil }, - describeStreamMock: func(input *kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) { - return &kinesis.DescribeStreamOutput{ - StreamDescription: &kinesis.StreamDescription{ - Shards: []*kinesis.Shard{ - {ShardId: aws.String("myShard")}, - }, - HasMoreShards: aws.Bool(false)}, + listShardsMock: func(input *kinesis.ListShardsInput) (*kinesis.ListShardsOutput, error) { + return &kinesis.ListShardsOutput{ + Shards: []*kinesis.Shard{ + {ShardId: aws.String("myShard")}, + }, }, nil }, } @@ -94,11 +92,9 @@ func TestConsumer_Scan(t *testing.T) { func TestConsumer_Scan_NoShardsAvailable(t *testing.T) { client := &kinesisClientMock{ - describeStreamMock: func(input *kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) { - return &kinesis.DescribeStreamOutput{ - StreamDescription: &kinesis.StreamDescription{ - Shards: make([]*kinesis.Shard, 0), - }, + listShardsMock: func(input *kinesis.ListShardsInput) (*kinesis.ListShardsOutput, error) { + return &kinesis.ListShardsOutput{ + Shards: make([]*kinesis.Shard, 0), }, nil }, } @@ -287,7 +283,11 @@ type kinesisClientMock struct { kinesisiface.KinesisAPI getShardIteratorMock func(*kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) getRecordsMock func(*kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) - describeStreamMock func(*kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) + listShardsMock func(*kinesis.ListShardsInput) (*kinesis.ListShardsOutput, error) +} + +func (c *kinesisClientMock) ListShards(in *kinesis.ListShardsInput) (*kinesis.ListShardsOutput, error) { + return c.listShardsMock(in) } func (c *kinesisClientMock) GetRecords(in *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) { @@ -298,10 +298,6 @@ func (c *kinesisClientMock) GetShardIterator(in *kinesis.GetShardIteratorInput) return c.getShardIteratorMock(in) } -func (c *kinesisClientMock) DescribeStream(in *kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) { - return c.describeStreamMock(in) -} - // implementation of checkpoint type fakeCheckpoint struct { cache map[string]string