Fix getShardIds
This commit is contained in:
parent
2f58b136fe
commit
2f0c13ed72
2 changed files with 29 additions and 13 deletions
40
consumer.go
40
consumer.go
|
|
@ -241,20 +241,36 @@ func (c *Consumer) handleRecord(shardID string, r *Record, fn func(*Record) Scan
|
|||
}
|
||||
|
||||
func (c *Consumer) getShardIDs(streamName string) ([]string, error) {
|
||||
resp, err := c.client.DescribeStream(
|
||||
&kinesis.DescribeStreamInput{
|
||||
StreamName: aws.String(streamName),
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("describe stream error: %v", err)
|
||||
}
|
||||
|
||||
var ss []string
|
||||
for _, shard := range resp.StreamDescription.Shards {
|
||||
ss = append(ss, *shard.ShardId)
|
||||
var exclusiveStartShardId *string
|
||||
for {
|
||||
resp, err := c.client.DescribeStream(
|
||||
&kinesis.DescribeStreamInput{
|
||||
StreamName: aws.String(streamName),
|
||||
ExclusiveStartShardId: exclusiveStartShardId,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("describe stream error: %v", err)
|
||||
}
|
||||
|
||||
streamDescription := resp.StreamDescription
|
||||
shards := streamDescription.Shards
|
||||
|
||||
if len(shards) == 0 {
|
||||
return ss, nil
|
||||
}
|
||||
|
||||
for _, shard := range shards {
|
||||
ss = append(ss, *shard.ShardId)
|
||||
}
|
||||
|
||||
exclusiveStartShardId = shards[len(shards)-1].ShardId
|
||||
|
||||
if *streamDescription.HasMoreShards == false {
|
||||
return ss, nil
|
||||
}
|
||||
}
|
||||
return ss, nil
|
||||
}
|
||||
|
||||
func (c *Consumer) getShardIterator(streamName, shardID, lastSeqNum string) (*string, error) {
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ func TestConsumer_Scan(t *testing.T) {
|
|||
Shards: []*kinesis.Shard{
|
||||
{ShardId: aws.String("myShard")},
|
||||
},
|
||||
},
|
||||
HasMoreShards: aws.Bool(false)},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue