From 14db23eaf34c2d559a90fa985beea710a831d10b Mon Sep 17 00:00:00 2001 From: Andrew Shannon Brown Date: Wed, 14 Aug 2019 09:33:35 -0700 Subject: [PATCH] Support creating an iterator with an initial timestamp (#99) * Allow setting initial timestamp * Fix writing to closed channel * Allow cancelling of request --- consumer.go | 25 +++++++++++++++++++------ consumer_test.go | 5 +++++ options.go | 13 ++++++++++++- 3 files changed, 36 insertions(+), 7 deletions(-) diff --git a/consumer.go b/consumer.go index 507076b..6d56060 100644 --- a/consumer.go +++ b/consumer.go @@ -6,6 +6,8 @@ import ( "fmt" "io/ioutil" "log" + "sync" + "time" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" @@ -61,6 +63,7 @@ func New(streamName string, opts ...Option) (*Consumer, error) { type Consumer struct { streamName string initialShardIteratorType string + initialTimestamp *time.Time client kinesisiface.KinesisAPI counter Counter group Group @@ -98,22 +101,29 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error { close(shardc) }() + wg := new(sync.WaitGroup) // process each of the shards for shard := range shardc { + wg.Add(1) go func(shardID string) { + defer wg.Done() if err := c.ScanShard(ctx, shardID, fn); err != nil { select { case errc <- fmt.Errorf("shard %s error: %v", shardID, err): // first error to occur cancel() default: - // error has already occured + // error has already occurred } } }(aws.StringValue(shard.ShardId)) } - close(errc) + go func() { + wg.Wait() + close(errc) + }() + return <-errc } @@ -127,7 +137,7 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e } // get shard iterator - shardIterator, err := c.getShardIterator(c.streamName, shardID, lastSeqNum) + shardIterator, err := c.getShardIterator(ctx, c.streamName, shardID, lastSeqNum) if err != nil { return fmt.Errorf("get shard iterator error: %v", err) } @@ -156,7 +166,7 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e } } - shardIterator, err = c.getShardIterator(c.streamName, shardID, lastSeqNum) + shardIterator, err = c.getShardIterator(ctx, c.streamName, shardID, lastSeqNum) if err != nil { return fmt.Errorf("get shard iterator error: %v", err) } @@ -205,7 +215,7 @@ func isShardClosed(nextShardIterator, currentShardIterator *string) bool { return nextShardIterator == nil || currentShardIterator == nextShardIterator } -func (c *Consumer) getShardIterator(streamName, shardID, seqNum string) (*string, error) { +func (c *Consumer) getShardIterator(ctx context.Context, streamName, shardID, seqNum string) (*string, error) { params := &kinesis.GetShardIteratorInput{ ShardId: aws.String(shardID), StreamName: aws.String(streamName), @@ -214,10 +224,13 @@ func (c *Consumer) getShardIterator(streamName, shardID, seqNum string) (*string if seqNum != "" { params.ShardIteratorType = aws.String(kinesis.ShardIteratorTypeAfterSequenceNumber) params.StartingSequenceNumber = aws.String(seqNum) + } else if c.initialTimestamp != nil { + params.ShardIteratorType = aws.String(kinesis.ShardIteratorTypeAtTimestamp) + params.Timestamp = c.initialTimestamp } else { params.ShardIteratorType = aws.String(c.initialShardIteratorType) } - res, err := c.client.GetShardIterator(params) + res, err := c.client.GetShardIteratorWithContext(aws.Context(ctx), params) return res.ShardIterator, err } diff --git a/consumer_test.go b/consumer_test.go index 48696b6..2ec2ec4 100644 --- a/consumer_test.go +++ b/consumer_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/service/kinesis" "github.com/aws/aws-sdk-go/service/kinesis/kinesisiface" ) @@ -330,6 +331,10 @@ func (c *kinesisClientMock) GetShardIterator(in *kinesis.GetShardIteratorInput) return c.getShardIteratorMock(in) } +func (c *kinesisClientMock) GetShardIteratorWithContext(ctx aws.Context, in *kinesis.GetShardIteratorInput, options ...request.Option) (*kinesis.GetShardIteratorOutput, error) { + return c.getShardIteratorMock(in) +} + // implementation of checkpoint type fakeCheckpoint struct { cache map[string]string diff --git a/options.go b/options.go index dd77da0..be306da 100644 --- a/options.go +++ b/options.go @@ -1,6 +1,10 @@ package consumer -import "github.com/aws/aws-sdk-go/service/kinesis/kinesisiface" +import ( + "time" + + "github.com/aws/aws-sdk-go/service/kinesis/kinesisiface" +) // Option is used to override defaults when creating a new Consumer type Option func(*Consumer) @@ -46,3 +50,10 @@ func WithShardIteratorType(t string) Option { c.initialShardIteratorType = t } } + +// Timestamp overrides the starting point for the consumer +func WithTimestamp(t time.Time) Option { + return func(c *Consumer) { + c.initialTimestamp = &t + } +}