Fix getShardIds

This commit is contained in:
Farhan 2019-02-03 13:27:58 +05:30
parent 2f58b136fe
commit 2f0c13ed72
2 changed files with 29 additions and 13 deletions

View file

@ -241,20 +241,36 @@ func (c *Consumer) handleRecord(shardID string, r *Record, fn func(*Record) Scan
} }
func (c *Consumer) getShardIDs(streamName string) ([]string, error) { func (c *Consumer) getShardIDs(streamName string) ([]string, error) {
resp, err := c.client.DescribeStream(
&kinesis.DescribeStreamInput{
StreamName: aws.String(streamName),
},
)
if err != nil {
return nil, fmt.Errorf("describe stream error: %v", err)
}
var ss []string var ss []string
for _, shard := range resp.StreamDescription.Shards { var exclusiveStartShardId *string
ss = append(ss, *shard.ShardId) for {
resp, err := c.client.DescribeStream(
&kinesis.DescribeStreamInput{
StreamName: aws.String(streamName),
ExclusiveStartShardId: exclusiveStartShardId,
},
)
if err != nil {
return nil, fmt.Errorf("describe stream error: %v", err)
}
streamDescription := resp.StreamDescription
shards := streamDescription.Shards
if len(shards) == 0 {
return ss, nil
}
for _, shard := range shards {
ss = append(ss, *shard.ShardId)
}
exclusiveStartShardId = shards[len(shards)-1].ShardId
if *streamDescription.HasMoreShards == false {
return ss, nil
}
} }
return ss, nil
} }
func (c *Consumer) getShardIterator(streamName, shardID, lastSeqNum string) (*string, error) { func (c *Consumer) getShardIterator(streamName, shardID, lastSeqNum string) (*string, error) {

View file

@ -46,7 +46,7 @@ func TestConsumer_Scan(t *testing.T) {
Shards: []*kinesis.Shard{ Shards: []*kinesis.Shard{
{ShardId: aws.String("myShard")}, {ShardId: aws.String("myShard")},
}, },
}, HasMoreShards: aws.Bool(false)},
}, nil }, nil
}, },
} }