Support creating an iterator with an initial timestamp (#99)

* Allow setting initial timestamp

* Fix writing to closed channel

* Allow cancelling of request
This commit is contained in:
Andrew Shannon Brown 2019-08-14 09:33:35 -07:00 committed by Harlow Ward
parent 81a8ac4221
commit 14db23eaf3
3 changed files with 36 additions and 7 deletions

View file

@ -6,6 +6,8 @@ import (
"fmt"
"io/ioutil"
"log"
"sync"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
@ -61,6 +63,7 @@ func New(streamName string, opts ...Option) (*Consumer, error) {
type Consumer struct {
streamName string
initialShardIteratorType string
initialTimestamp *time.Time
client kinesisiface.KinesisAPI
counter Counter
group Group
@ -98,22 +101,29 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
close(shardc)
}()
wg := new(sync.WaitGroup)
// process each of the shards
for shard := range shardc {
wg.Add(1)
go func(shardID string) {
defer wg.Done()
if err := c.ScanShard(ctx, shardID, fn); err != nil {
select {
case errc <- fmt.Errorf("shard %s error: %v", shardID, err):
// first error to occur
cancel()
default:
// error has already occured
// error has already occurred
}
}
}(aws.StringValue(shard.ShardId))
}
close(errc)
go func() {
wg.Wait()
close(errc)
}()
return <-errc
}
@ -127,7 +137,7 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
}
// get shard iterator
shardIterator, err := c.getShardIterator(c.streamName, shardID, lastSeqNum)
shardIterator, err := c.getShardIterator(ctx, c.streamName, shardID, lastSeqNum)
if err != nil {
return fmt.Errorf("get shard iterator error: %v", err)
}
@ -156,7 +166,7 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
}
}
shardIterator, err = c.getShardIterator(c.streamName, shardID, lastSeqNum)
shardIterator, err = c.getShardIterator(ctx, c.streamName, shardID, lastSeqNum)
if err != nil {
return fmt.Errorf("get shard iterator error: %v", err)
}
@ -205,7 +215,7 @@ func isShardClosed(nextShardIterator, currentShardIterator *string) bool {
return nextShardIterator == nil || currentShardIterator == nextShardIterator
}
func (c *Consumer) getShardIterator(streamName, shardID, seqNum string) (*string, error) {
func (c *Consumer) getShardIterator(ctx context.Context, streamName, shardID, seqNum string) (*string, error) {
params := &kinesis.GetShardIteratorInput{
ShardId: aws.String(shardID),
StreamName: aws.String(streamName),
@ -214,10 +224,13 @@ func (c *Consumer) getShardIterator(streamName, shardID, seqNum string) (*string
if seqNum != "" {
params.ShardIteratorType = aws.String(kinesis.ShardIteratorTypeAfterSequenceNumber)
params.StartingSequenceNumber = aws.String(seqNum)
} else if c.initialTimestamp != nil {
params.ShardIteratorType = aws.String(kinesis.ShardIteratorTypeAtTimestamp)
params.Timestamp = c.initialTimestamp
} else {
params.ShardIteratorType = aws.String(c.initialShardIteratorType)
}
res, err := c.client.GetShardIterator(params)
res, err := c.client.GetShardIteratorWithContext(aws.Context(ctx), params)
return res.ShardIterator, err
}

View file

@ -8,6 +8,7 @@ import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/kinesis"
"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
)
@ -330,6 +331,10 @@ func (c *kinesisClientMock) GetShardIterator(in *kinesis.GetShardIteratorInput)
return c.getShardIteratorMock(in)
}
func (c *kinesisClientMock) GetShardIteratorWithContext(ctx aws.Context, in *kinesis.GetShardIteratorInput, options ...request.Option) (*kinesis.GetShardIteratorOutput, error) {
return c.getShardIteratorMock(in)
}
// implementation of checkpoint
type fakeCheckpoint struct {
cache map[string]string

View file

@ -1,6 +1,10 @@
package consumer
import "github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
import (
"time"
"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
)
// Option is used to override defaults when creating a new Consumer
type Option func(*Consumer)
@ -46,3 +50,10 @@ func WithShardIteratorType(t string) Option {
c.initialShardIteratorType = t
}
}
// Timestamp overrides the starting point for the consumer
func WithTimestamp(t time.Time) Option {
return func(c *Consumer) {
c.initialTimestamp = &t
}
}