diff --git a/consumer.go b/consumer.go index d33a4ec..cedae08 100644 --- a/consumer.go +++ b/consumer.go @@ -72,7 +72,7 @@ func New(streamName string, opts ...Option) (*Consumer, error) { // new consumer with no-op checkpoint, counter, and logger c := &Consumer{ streamName: streamName, - initialShardIteratorType: "TRIM_HORIZON", + initialShardIteratorType: kinesis.ShardIteratorTypeTrimHorizon, checkpoint: &noopCheckpoint{}, counter: &noopCounter{}, logger: &noopLogger{ @@ -241,20 +241,28 @@ func (c *Consumer) handleRecord(shardID string, r *Record, fn func(*Record) Scan } func (c *Consumer) getShardIDs(streamName string) ([]string, error) { - resp, err := c.client.DescribeStream( - &kinesis.DescribeStreamInput{ - StreamName: aws.String(streamName), - }, - ) - if err != nil { - return nil, fmt.Errorf("describe stream error: %v", err) - } - var ss []string - for _, shard := range resp.StreamDescription.Shards { - ss = append(ss, *shard.ShardId) + var listShardsInput = &kinesis.ListShardsInput{ + StreamName: aws.String(streamName), + } + for { + resp, err := c.client.ListShards(listShardsInput) + if err != nil { + return nil, fmt.Errorf("ListShards error: %v", err) + } + + for _, shard := range resp.Shards { + ss = append(ss, *shard.ShardId) + } + + if resp.NextToken == nil { + return ss, nil + } + + listShardsInput = &kinesis.ListShardsInput{ + NextToken: resp.NextToken, + } } - return ss, nil } func (c *Consumer) getShardIterator(streamName, shardID, lastSeqNum string) (*string, error) { @@ -264,7 +272,7 @@ func (c *Consumer) getShardIterator(streamName, shardID, lastSeqNum string) (*st } if lastSeqNum != "" { - params.ShardIteratorType = aws.String("AFTER_SEQUENCE_NUMBER") + params.ShardIteratorType = aws.String(kinesis.ShardIteratorTypeAfterSequenceNumber) params.StartingSequenceNumber = aws.String(lastSeqNum) } else { params.ShardIteratorType = aws.String(c.initialShardIteratorType) diff --git a/consumer_test.go b/consumer_test.go index 36d6703..f669381 100644 --- a/consumer_test.go +++ b/consumer_test.go @@ -40,12 +40,10 @@ func TestConsumer_Scan(t *testing.T) { Records: records, }, nil }, - describeStreamMock: func(input *kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) { - return &kinesis.DescribeStreamOutput{ - StreamDescription: &kinesis.StreamDescription{ - Shards: []*kinesis.Shard{ - {ShardId: aws.String("myShard")}, - }, + listShardsMock: func(input *kinesis.ListShardsInput) (*kinesis.ListShardsOutput, error) { + return &kinesis.ListShardsOutput{ + Shards: []*kinesis.Shard{ + {ShardId: aws.String("myShard")}, }, }, nil }, @@ -94,11 +92,9 @@ func TestConsumer_Scan(t *testing.T) { func TestConsumer_Scan_NoShardsAvailable(t *testing.T) { client := &kinesisClientMock{ - describeStreamMock: func(input *kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) { - return &kinesis.DescribeStreamOutput{ - StreamDescription: &kinesis.StreamDescription{ - Shards: make([]*kinesis.Shard, 0), - }, + listShardsMock: func(input *kinesis.ListShardsInput) (*kinesis.ListShardsOutput, error) { + return &kinesis.ListShardsOutput{ + Shards: make([]*kinesis.Shard, 0), }, nil }, } @@ -287,7 +283,11 @@ type kinesisClientMock struct { kinesisiface.KinesisAPI getShardIteratorMock func(*kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) getRecordsMock func(*kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) - describeStreamMock func(*kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) + listShardsMock func(*kinesis.ListShardsInput) (*kinesis.ListShardsOutput, error) +} + +func (c *kinesisClientMock) ListShards(in *kinesis.ListShardsInput) (*kinesis.ListShardsOutput, error) { + return c.listShardsMock(in) } func (c *kinesisClientMock) GetRecords(in *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) { @@ -298,10 +298,6 @@ func (c *kinesisClientMock) GetShardIterator(in *kinesis.GetShardIteratorInput) return c.getShardIteratorMock(in) } -func (c *kinesisClientMock) DescribeStream(in *kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) { - return c.describeStreamMock(in) -} - // implementation of checkpoint type fakeCheckpoint struct { cache map[string]string