Support creating an iterator with an initial timestamp (#99)
* Allow setting initial timestamp * Fix writing to closed channel * Allow cancelling of request
This commit is contained in:
parent
81a8ac4221
commit
14db23eaf3
3 changed files with 36 additions and 7 deletions
25
consumer.go
25
consumer.go
|
|
@ -6,6 +6,8 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"log"
|
"log"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/aws/aws-sdk-go/aws"
|
"github.com/aws/aws-sdk-go/aws"
|
||||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||||
|
|
@ -61,6 +63,7 @@ func New(streamName string, opts ...Option) (*Consumer, error) {
|
||||||
type Consumer struct {
|
type Consumer struct {
|
||||||
streamName string
|
streamName string
|
||||||
initialShardIteratorType string
|
initialShardIteratorType string
|
||||||
|
initialTimestamp *time.Time
|
||||||
client kinesisiface.KinesisAPI
|
client kinesisiface.KinesisAPI
|
||||||
counter Counter
|
counter Counter
|
||||||
group Group
|
group Group
|
||||||
|
|
@ -98,22 +101,29 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
|
||||||
close(shardc)
|
close(shardc)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
wg := new(sync.WaitGroup)
|
||||||
// process each of the shards
|
// process each of the shards
|
||||||
for shard := range shardc {
|
for shard := range shardc {
|
||||||
|
wg.Add(1)
|
||||||
go func(shardID string) {
|
go func(shardID string) {
|
||||||
|
defer wg.Done()
|
||||||
if err := c.ScanShard(ctx, shardID, fn); err != nil {
|
if err := c.ScanShard(ctx, shardID, fn); err != nil {
|
||||||
select {
|
select {
|
||||||
case errc <- fmt.Errorf("shard %s error: %v", shardID, err):
|
case errc <- fmt.Errorf("shard %s error: %v", shardID, err):
|
||||||
// first error to occur
|
// first error to occur
|
||||||
cancel()
|
cancel()
|
||||||
default:
|
default:
|
||||||
// error has already occured
|
// error has already occurred
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}(aws.StringValue(shard.ShardId))
|
}(aws.StringValue(shard.ShardId))
|
||||||
}
|
}
|
||||||
|
|
||||||
close(errc)
|
go func() {
|
||||||
|
wg.Wait()
|
||||||
|
close(errc)
|
||||||
|
}()
|
||||||
|
|
||||||
return <-errc
|
return <-errc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -127,7 +137,7 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
|
||||||
}
|
}
|
||||||
|
|
||||||
// get shard iterator
|
// get shard iterator
|
||||||
shardIterator, err := c.getShardIterator(c.streamName, shardID, lastSeqNum)
|
shardIterator, err := c.getShardIterator(ctx, c.streamName, shardID, lastSeqNum)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("get shard iterator error: %v", err)
|
return fmt.Errorf("get shard iterator error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
@ -156,7 +166,7 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
shardIterator, err = c.getShardIterator(c.streamName, shardID, lastSeqNum)
|
shardIterator, err = c.getShardIterator(ctx, c.streamName, shardID, lastSeqNum)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("get shard iterator error: %v", err)
|
return fmt.Errorf("get shard iterator error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
@ -205,7 +215,7 @@ func isShardClosed(nextShardIterator, currentShardIterator *string) bool {
|
||||||
return nextShardIterator == nil || currentShardIterator == nextShardIterator
|
return nextShardIterator == nil || currentShardIterator == nextShardIterator
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Consumer) getShardIterator(streamName, shardID, seqNum string) (*string, error) {
|
func (c *Consumer) getShardIterator(ctx context.Context, streamName, shardID, seqNum string) (*string, error) {
|
||||||
params := &kinesis.GetShardIteratorInput{
|
params := &kinesis.GetShardIteratorInput{
|
||||||
ShardId: aws.String(shardID),
|
ShardId: aws.String(shardID),
|
||||||
StreamName: aws.String(streamName),
|
StreamName: aws.String(streamName),
|
||||||
|
|
@ -214,10 +224,13 @@ func (c *Consumer) getShardIterator(streamName, shardID, seqNum string) (*string
|
||||||
if seqNum != "" {
|
if seqNum != "" {
|
||||||
params.ShardIteratorType = aws.String(kinesis.ShardIteratorTypeAfterSequenceNumber)
|
params.ShardIteratorType = aws.String(kinesis.ShardIteratorTypeAfterSequenceNumber)
|
||||||
params.StartingSequenceNumber = aws.String(seqNum)
|
params.StartingSequenceNumber = aws.String(seqNum)
|
||||||
|
} else if c.initialTimestamp != nil {
|
||||||
|
params.ShardIteratorType = aws.String(kinesis.ShardIteratorTypeAtTimestamp)
|
||||||
|
params.Timestamp = c.initialTimestamp
|
||||||
} else {
|
} else {
|
||||||
params.ShardIteratorType = aws.String(c.initialShardIteratorType)
|
params.ShardIteratorType = aws.String(c.initialShardIteratorType)
|
||||||
}
|
}
|
||||||
|
|
||||||
res, err := c.client.GetShardIterator(params)
|
res, err := c.client.GetShardIteratorWithContext(aws.Context(ctx), params)
|
||||||
return res.ShardIterator, err
|
return res.ShardIterator, err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ import (
|
||||||
|
|
||||||
"github.com/aws/aws-sdk-go/aws"
|
"github.com/aws/aws-sdk-go/aws"
|
||||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||||
|
"github.com/aws/aws-sdk-go/aws/request"
|
||||||
"github.com/aws/aws-sdk-go/service/kinesis"
|
"github.com/aws/aws-sdk-go/service/kinesis"
|
||||||
"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
|
"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
|
||||||
)
|
)
|
||||||
|
|
@ -330,6 +331,10 @@ func (c *kinesisClientMock) GetShardIterator(in *kinesis.GetShardIteratorInput)
|
||||||
return c.getShardIteratorMock(in)
|
return c.getShardIteratorMock(in)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *kinesisClientMock) GetShardIteratorWithContext(ctx aws.Context, in *kinesis.GetShardIteratorInput, options ...request.Option) (*kinesis.GetShardIteratorOutput, error) {
|
||||||
|
return c.getShardIteratorMock(in)
|
||||||
|
}
|
||||||
|
|
||||||
// implementation of checkpoint
|
// implementation of checkpoint
|
||||||
type fakeCheckpoint struct {
|
type fakeCheckpoint struct {
|
||||||
cache map[string]string
|
cache map[string]string
|
||||||
|
|
|
||||||
13
options.go
13
options.go
|
|
@ -1,6 +1,10 @@
|
||||||
package consumer
|
package consumer
|
||||||
|
|
||||||
import "github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
|
||||||
|
)
|
||||||
|
|
||||||
// Option is used to override defaults when creating a new Consumer
|
// Option is used to override defaults when creating a new Consumer
|
||||||
type Option func(*Consumer)
|
type Option func(*Consumer)
|
||||||
|
|
@ -46,3 +50,10 @@ func WithShardIteratorType(t string) Option {
|
||||||
c.initialShardIteratorType = t
|
c.initialShardIteratorType = t
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Timestamp overrides the starting point for the consumer
|
||||||
|
func WithTimestamp(t time.Time) Option {
|
||||||
|
return func(c *Consumer) {
|
||||||
|
c.initialTimestamp = &t
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue