Refactor getShardIDs (#70)

* Refactor

* Use `nextToken` paramter as string.

Use `nextToken` paramter as string instead of pointer to match the original code base.

* Log the last shard token when failing.

* Use aws.StringValue to get the string pointer value.

Co-authored-by: Wesam Gerges <wesam.gerges.discovery@gmail.com>
This commit is contained in:
wgerges-discovery 2020-03-10 12:00:57 -04:00 committed by Tao Jiang
parent 5dd53bf731
commit 384482169c

View file

@ -28,7 +28,6 @@
package worker
import (
"errors"
"math/rand"
"sync"
"time"
@ -325,32 +324,27 @@ func (w *Worker) eventLoop() {
}
}
// List all ACTIVE shard and store them into shardStatus table
// List all shards and store them into shardStatus table
// If shard has been removed, need to exclude it from cached shard status.
func (w *Worker) getShardIDs(startShardID string, shardInfo map[string]bool) error {
func (w *Worker) getShardIDs(nextToken string, shardInfo map[string]bool) error {
log := w.kclConfig.Logger
// The default pagination limit is 100.
args := &kinesis.DescribeStreamInput{
StreamName: aws.String(w.streamName),
args := &kinesis.ListShardsInput{}
// When you have a nextToken, you can't set the streamName
if nextToken != "" {
args.NextToken = aws.String(nextToken)
} else {
args.StreamName = aws.String(w.streamName)
}
if startShardID != "" {
args.ExclusiveStartShardId = aws.String(startShardID)
}
streamDesc, err := w.kc.DescribeStream(args)
listShards, err := w.kc.ListShards(args)
if err != nil {
log.Errorf("Error in DescribeStream: %s Error: %+v Request: %s", w.streamName, err, args)
log.Errorf("Error in ListShards: %s Error: %+v Request: %s", w.streamName, err, args)
return err
}
if *streamDesc.StreamDescription.StreamStatus != "ACTIVE" {
log.Warnf("Stream %s is not active", w.streamName)
return errors.New("stream not active")
}
var lastShardID string
for _, s := range streamDesc.StreamDescription.Shards {
for _, s := range listShards.Shards {
// record avail shardId from fresh reading from Kinesis
shardInfo[*s.ShardId] = true
@ -365,13 +359,12 @@ func (w *Worker) getShardIDs(startShardID string, shardInfo map[string]bool) err
EndingSequenceNumber: aws.StringValue(s.SequenceNumberRange.EndingSequenceNumber),
}
}
lastShardID = *s.ShardId
}
if *streamDesc.StreamDescription.HasMoreShards {
err := w.getShardIDs(lastShardID, shardInfo)
if listShards.NextToken != nil {
err := w.getShardIDs(aws.StringValue(listShards.NextToken), shardInfo)
if err != nil {
log.Errorf("Error in getShardIDs: %s Error: %+v", lastShardID, err)
log.Errorf("Error in ListShards: %s Error: %+v Request: %s", w.streamName, err, args)
return err
}
}