Fix getShardID does not return more than 100 shards (#81)
This commit is contained in:
parent
2f58b136fe
commit
2037463c62
2 changed files with 34 additions and 30 deletions
36
consumer.go
36
consumer.go
|
|
@ -72,7 +72,7 @@ func New(streamName string, opts ...Option) (*Consumer, error) {
|
|||
// new consumer with no-op checkpoint, counter, and logger
|
||||
c := &Consumer{
|
||||
streamName: streamName,
|
||||
initialShardIteratorType: "TRIM_HORIZON",
|
||||
initialShardIteratorType: kinesis.ShardIteratorTypeTrimHorizon,
|
||||
checkpoint: &noopCheckpoint{},
|
||||
counter: &noopCounter{},
|
||||
logger: &noopLogger{
|
||||
|
|
@ -241,20 +241,28 @@ func (c *Consumer) handleRecord(shardID string, r *Record, fn func(*Record) Scan
|
|||
}
|
||||
|
||||
func (c *Consumer) getShardIDs(streamName string) ([]string, error) {
|
||||
resp, err := c.client.DescribeStream(
|
||||
&kinesis.DescribeStreamInput{
|
||||
StreamName: aws.String(streamName),
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("describe stream error: %v", err)
|
||||
}
|
||||
|
||||
var ss []string
|
||||
for _, shard := range resp.StreamDescription.Shards {
|
||||
ss = append(ss, *shard.ShardId)
|
||||
var listShardsInput = &kinesis.ListShardsInput{
|
||||
StreamName: aws.String(streamName),
|
||||
}
|
||||
for {
|
||||
resp, err := c.client.ListShards(listShardsInput)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ListShards error: %v", err)
|
||||
}
|
||||
|
||||
for _, shard := range resp.Shards {
|
||||
ss = append(ss, *shard.ShardId)
|
||||
}
|
||||
|
||||
if resp.NextToken == nil {
|
||||
return ss, nil
|
||||
}
|
||||
|
||||
listShardsInput = &kinesis.ListShardsInput{
|
||||
NextToken: resp.NextToken,
|
||||
}
|
||||
}
|
||||
return ss, nil
|
||||
}
|
||||
|
||||
func (c *Consumer) getShardIterator(streamName, shardID, lastSeqNum string) (*string, error) {
|
||||
|
|
@ -264,7 +272,7 @@ func (c *Consumer) getShardIterator(streamName, shardID, lastSeqNum string) (*st
|
|||
}
|
||||
|
||||
if lastSeqNum != "" {
|
||||
params.ShardIteratorType = aws.String("AFTER_SEQUENCE_NUMBER")
|
||||
params.ShardIteratorType = aws.String(kinesis.ShardIteratorTypeAfterSequenceNumber)
|
||||
params.StartingSequenceNumber = aws.String(lastSeqNum)
|
||||
} else {
|
||||
params.ShardIteratorType = aws.String(c.initialShardIteratorType)
|
||||
|
|
|
|||
|
|
@ -40,12 +40,10 @@ func TestConsumer_Scan(t *testing.T) {
|
|||
Records: records,
|
||||
}, nil
|
||||
},
|
||||
describeStreamMock: func(input *kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) {
|
||||
return &kinesis.DescribeStreamOutput{
|
||||
StreamDescription: &kinesis.StreamDescription{
|
||||
Shards: []*kinesis.Shard{
|
||||
{ShardId: aws.String("myShard")},
|
||||
},
|
||||
listShardsMock: func(input *kinesis.ListShardsInput) (*kinesis.ListShardsOutput, error) {
|
||||
return &kinesis.ListShardsOutput{
|
||||
Shards: []*kinesis.Shard{
|
||||
{ShardId: aws.String("myShard")},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
|
|
@ -94,11 +92,9 @@ func TestConsumer_Scan(t *testing.T) {
|
|||
|
||||
func TestConsumer_Scan_NoShardsAvailable(t *testing.T) {
|
||||
client := &kinesisClientMock{
|
||||
describeStreamMock: func(input *kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) {
|
||||
return &kinesis.DescribeStreamOutput{
|
||||
StreamDescription: &kinesis.StreamDescription{
|
||||
Shards: make([]*kinesis.Shard, 0),
|
||||
},
|
||||
listShardsMock: func(input *kinesis.ListShardsInput) (*kinesis.ListShardsOutput, error) {
|
||||
return &kinesis.ListShardsOutput{
|
||||
Shards: make([]*kinesis.Shard, 0),
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
|
@ -287,7 +283,11 @@ type kinesisClientMock struct {
|
|||
kinesisiface.KinesisAPI
|
||||
getShardIteratorMock func(*kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error)
|
||||
getRecordsMock func(*kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error)
|
||||
describeStreamMock func(*kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error)
|
||||
listShardsMock func(*kinesis.ListShardsInput) (*kinesis.ListShardsOutput, error)
|
||||
}
|
||||
|
||||
func (c *kinesisClientMock) ListShards(in *kinesis.ListShardsInput) (*kinesis.ListShardsOutput, error) {
|
||||
return c.listShardsMock(in)
|
||||
}
|
||||
|
||||
func (c *kinesisClientMock) GetRecords(in *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
|
||||
|
|
@ -298,10 +298,6 @@ func (c *kinesisClientMock) GetShardIterator(in *kinesis.GetShardIteratorInput)
|
|||
return c.getShardIteratorMock(in)
|
||||
}
|
||||
|
||||
func (c *kinesisClientMock) DescribeStream(in *kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) {
|
||||
return c.describeStreamMock(in)
|
||||
}
|
||||
|
||||
// implementation of checkpoint
|
||||
type fakeCheckpoint struct {
|
||||
cache map[string]string
|
||||
|
|
|
|||
Loading…
Reference in a new issue