From 89b383161aa50bef7b94635f138494410344caaf Mon Sep 17 00:00:00 2001 From: "Andrew S. Brown" Date: Fri, 9 Aug 2019 08:04:36 -0700 Subject: [PATCH] Allow cancelling of request --- consumer.go | 8 ++++---- consumer_test.go | 5 +++++ 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/consumer.go b/consumer.go index 8121e95..6d56060 100644 --- a/consumer.go +++ b/consumer.go @@ -137,7 +137,7 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e } // get shard iterator - shardIterator, err := c.getShardIterator(c.streamName, shardID, lastSeqNum) + shardIterator, err := c.getShardIterator(ctx, c.streamName, shardID, lastSeqNum) if err != nil { return fmt.Errorf("get shard iterator error: %v", err) } @@ -166,7 +166,7 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e } } - shardIterator, err = c.getShardIterator(c.streamName, shardID, lastSeqNum) + shardIterator, err = c.getShardIterator(ctx, c.streamName, shardID, lastSeqNum) if err != nil { return fmt.Errorf("get shard iterator error: %v", err) } @@ -215,7 +215,7 @@ func isShardClosed(nextShardIterator, currentShardIterator *string) bool { return nextShardIterator == nil || currentShardIterator == nextShardIterator } -func (c *Consumer) getShardIterator(streamName, shardID, seqNum string) (*string, error) { +func (c *Consumer) getShardIterator(ctx context.Context, streamName, shardID, seqNum string) (*string, error) { params := &kinesis.GetShardIteratorInput{ ShardId: aws.String(shardID), StreamName: aws.String(streamName), @@ -231,6 +231,6 @@ func (c *Consumer) getShardIterator(streamName, shardID, seqNum string) (*string params.ShardIteratorType = aws.String(c.initialShardIteratorType) } - res, err := c.client.GetShardIterator(params) + res, err := c.client.GetShardIteratorWithContext(aws.Context(ctx), params) return res.ShardIterator, err } diff --git a/consumer_test.go b/consumer_test.go index 48696b6..2ec2ec4 100644 --- a/consumer_test.go +++ b/consumer_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/service/kinesis" "github.com/aws/aws-sdk-go/service/kinesis/kinesisiface" ) @@ -330,6 +331,10 @@ func (c *kinesisClientMock) GetShardIterator(in *kinesis.GetShardIteratorInput) return c.getShardIteratorMock(in) } +func (c *kinesisClientMock) GetShardIteratorWithContext(ctx aws.Context, in *kinesis.GetShardIteratorInput, options ...request.Option) (*kinesis.GetShardIteratorOutput, error) { + return c.getShardIteratorMock(in) +} + // implementation of checkpoint type fakeCheckpoint struct { cache map[string]string