Introduce ScanFunc signature and remove ScanStatus (#77)
Major changes: ```go type ScanFunc func(r *Record) error ``` * Simplify the callback func signature by removing `ScanStatus` * Leverage context for cancellation * Add custom error `SkipCheckpoint` for special cases when we don't want to checkpoint Minor changes: * Use kinesis package constants for shard iterator types * Move optional config to new file See conversation on #75 for more details
This commit is contained in:
parent
24de74fd14
commit
76158d24ab
9 changed files with 235 additions and 208 deletions
|
|
@ -6,6 +6,14 @@ All notable changes to this project will be documented in this file.
|
||||||
|
|
||||||
Major changes:
|
Major changes:
|
||||||
|
|
||||||
|
* Remove the concept of `ScanStatus` to simplify the scanning interface
|
||||||
|
|
||||||
|
For more context on this change see: https://github.com/harlow/kinesis-consumer/issues/75
|
||||||
|
|
||||||
|
## v0.3.0 - 2018-12-28
|
||||||
|
|
||||||
|
Major changes:
|
||||||
|
|
||||||
* Remove concept of `Client` it was confusing as it wasn't a direct standin for a Kinesis client.
|
* Remove concept of `Client` it was confusing as it wasn't a direct standin for a Kinesis client.
|
||||||
* Rename `ScanError` to `ScanStatus` as it's not always an error.
|
* Rename `ScanError` to `ScanStatus` as it's not always an error.
|
||||||
|
|
||||||
|
|
|
||||||
59
README.md
59
README.md
|
|
@ -40,39 +40,62 @@ func main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// start scan
|
// start scan
|
||||||
err = c.Scan(context.TODO(), func(r *consumer.Record) consumer.ScanStatus {
|
err = c.Scan(context.TODO(), func(r *consumer.Record) error {
|
||||||
fmt.Println(string(r.Data))
|
fmt.Println(string(r.Data))
|
||||||
|
return nil // continue scanning
|
||||||
return consumer.ScanStatus{
|
|
||||||
StopScan: false, // true to stop scan
|
|
||||||
SkipCheckpoint: false, // true to skip checkpoint
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("scan error: %v", err)
|
log.Fatalf("scan error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Note: If you need to aggregate based on a specific shard the `ScanShard`
|
// Note: If you need to aggregate based on a specific shard
|
||||||
// method should be leverged instead.
|
// the `ScanShard` function should be used instead.
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
## Scan status
|
## ScanFunc
|
||||||
|
|
||||||
The scan func returns a `consumer.ScanStatus` the struct allows some basic flow control.
|
ScanFunc is the type of the function called for each message read
|
||||||
|
from the stream. The record argument contains the original record
|
||||||
|
returned from the AWS Kinesis library.
|
||||||
|
|
||||||
|
```go
|
||||||
|
type ScanFunc func(r *Record) error
|
||||||
|
```
|
||||||
|
|
||||||
|
If an error is returned, scanning stops. The sole exception is when the
|
||||||
|
function returns the special value SkipCheckpoint.
|
||||||
|
|
||||||
```go
|
```go
|
||||||
// continue scanning
|
// continue scanning
|
||||||
return consumer.ScanStatus{}
|
return nil
|
||||||
|
|
||||||
// continue scanning, skip saving checkpoint
|
// continue scanning, skip checkpoint
|
||||||
return consumer.ScanStatus{SkipCheckpoint: true}
|
return consumer.SkipCheckpoint
|
||||||
|
|
||||||
// stop scanning, return nil
|
|
||||||
return consumer.ScanStatus{StopScan: true}
|
|
||||||
|
|
||||||
// stop scanning, return error
|
// stop scanning, return error
|
||||||
return consumer.ScanStatus{Error: err}
|
return errors.New("my error, exit all scans")
|
||||||
|
```
|
||||||
|
|
||||||
|
Use context cancel to signal the scan to exit without error. For example if we wanted to gracefulloy exit the scan on interrupt.
|
||||||
|
|
||||||
|
```go
|
||||||
|
// trap SIGINT, wait to trigger shutdown
|
||||||
|
signals := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(signals, os.Interrupt)
|
||||||
|
|
||||||
|
// context with cancel
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
<-signals
|
||||||
|
cancel() // call cancellation
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := c.Scan(ctx, func(r *consumer.Record) error {
|
||||||
|
fmt.Println(string(r.Data))
|
||||||
|
return nil // continue scanning
|
||||||
|
})
|
||||||
```
|
```
|
||||||
|
|
||||||
## Checkpoint
|
## Checkpoint
|
||||||
|
|
@ -182,7 +205,7 @@ Override the Kinesis client if there is any special config needed:
|
||||||
|
|
||||||
```go
|
```go
|
||||||
// client
|
// client
|
||||||
client := kinesis.New(session.New(aws.NewConfig()))
|
client := kinesis.New(session.NewSession(aws.NewConfig()))
|
||||||
|
|
||||||
// consumer
|
// consumer
|
||||||
c, err := consumer.New(streamName, consumer.WithClient(client))
|
c, err := consumer.New(streamName, consumer.WithClient(client))
|
||||||
|
|
|
||||||
141
consumer.go
141
consumer.go
|
|
@ -2,6 +2,7 @@ package consumer
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"log"
|
"log"
|
||||||
|
|
@ -16,52 +17,6 @@ import (
|
||||||
// Record is an alias of record returned from kinesis library
|
// Record is an alias of record returned from kinesis library
|
||||||
type Record = kinesis.Record
|
type Record = kinesis.Record
|
||||||
|
|
||||||
// Option is used to override defaults when creating a new Consumer
|
|
||||||
type Option func(*Consumer)
|
|
||||||
|
|
||||||
// WithCheckpoint overrides the default checkpoint
|
|
||||||
func WithCheckpoint(checkpoint Checkpoint) Option {
|
|
||||||
return func(c *Consumer) {
|
|
||||||
c.checkpoint = checkpoint
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithLogger overrides the default logger
|
|
||||||
func WithLogger(logger Logger) Option {
|
|
||||||
return func(c *Consumer) {
|
|
||||||
c.logger = logger
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithCounter overrides the default counter
|
|
||||||
func WithCounter(counter Counter) Option {
|
|
||||||
return func(c *Consumer) {
|
|
||||||
c.counter = counter
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithClient overrides the default client
|
|
||||||
func WithClient(client kinesisiface.KinesisAPI) Option {
|
|
||||||
return func(c *Consumer) {
|
|
||||||
c.client = client
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ShardIteratorType overrides the starting point for the consumer
|
|
||||||
func WithShardIteratorType(t string) Option {
|
|
||||||
return func(c *Consumer) {
|
|
||||||
c.initialShardIteratorType = t
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ScanStatus signals the consumer if we should continue scanning for next record
|
|
||||||
// and whether to checkpoint.
|
|
||||||
type ScanStatus struct {
|
|
||||||
Error error
|
|
||||||
StopScan bool
|
|
||||||
SkipCheckpoint bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// New creates a kinesis consumer with default settings. Use Option to override
|
// New creates a kinesis consumer with default settings. Use Option to override
|
||||||
// any of the optional attributes.
|
// any of the optional attributes.
|
||||||
func New(streamName string, opts ...Option) (*Consumer, error) {
|
func New(streamName string, opts ...Option) (*Consumer, error) {
|
||||||
|
|
@ -107,9 +62,23 @@ type Consumer struct {
|
||||||
counter Counter
|
counter Counter
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scan scans each of the shards of the stream, calls the callback
|
// ScanFunc is the type of the function called for each message read
|
||||||
// func with each of the kinesis records.
|
// from the stream. The record argument contains the original record
|
||||||
func (c *Consumer) Scan(ctx context.Context, fn func(*Record) ScanStatus) error {
|
// returned from the AWS Kinesis library.
|
||||||
|
//
|
||||||
|
// If an error is returned, scanning stops. The sole exception is when the
|
||||||
|
// function returns the special value SkipCheckpoint.
|
||||||
|
type ScanFunc func(*Record) error
|
||||||
|
|
||||||
|
// SkipCheckpoint is used as a return value from ScanFuncs to indicate that
|
||||||
|
// the current checkpoint should be skipped skipped. It is not returned
|
||||||
|
// as an error by any function.
|
||||||
|
var SkipCheckpoint = errors.New("skip checkpoint")
|
||||||
|
|
||||||
|
// Scan launches a goroutine to process each of the shards in the stream. The ScanFunc
|
||||||
|
// is passed through to each of the goroutines and called with each message pulled from
|
||||||
|
// the stream.
|
||||||
|
func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
|
|
@ -153,14 +122,10 @@ func (c *Consumer) Scan(ctx context.Context, fn func(*Record) ScanStatus) error
|
||||||
return <-errc
|
return <-errc
|
||||||
}
|
}
|
||||||
|
|
||||||
// ScanShard loops over records on a specific shard, calls the callback func
|
// ScanShard loops over records on a specific shard, calls the ScanFunc callback
|
||||||
// for each record and checkpoints the progress of scan.
|
// func for each record and checkpoints the progress of scan.
|
||||||
func (c *Consumer) ScanShard(
|
func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error {
|
||||||
ctx context.Context,
|
// get last seq number from checkpoint
|
||||||
shardID string,
|
|
||||||
fn func(*Record) ScanStatus,
|
|
||||||
) error {
|
|
||||||
// get checkpoint
|
|
||||||
lastSeqNum, err := c.checkpoint.Get(c.streamName, shardID)
|
lastSeqNum, err := c.checkpoint.Get(c.streamName, shardID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("get checkpoint error: %v", err)
|
return fmt.Errorf("get checkpoint error: %v", err)
|
||||||
|
|
@ -174,10 +139,6 @@ func (c *Consumer) ScanShard(
|
||||||
|
|
||||||
c.logger.Log("scanning", shardID, lastSeqNum)
|
c.logger.Log("scanning", shardID, lastSeqNum)
|
||||||
|
|
||||||
return c.scanPagesOfShard(ctx, shardID, lastSeqNum, shardIterator, fn)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Consumer) scanPagesOfShard(ctx context.Context, shardID, lastSeqNum string, shardIterator *string, fn func(*Record) ScanStatus) error {
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
|
|
@ -187,6 +148,8 @@ func (c *Consumer) scanPagesOfShard(ctx context.Context, shardID, lastSeqNum str
|
||||||
ShardIterator: shardIterator,
|
ShardIterator: shardIterator,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// often we can recover from GetRecords error by getting a
|
||||||
|
// new shard iterator, else return error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
shardIterator, err = c.getShardIterator(c.streamName, shardID, lastSeqNum)
|
shardIterator, err = c.getShardIterator(c.streamName, shardID, lastSeqNum)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -195,21 +158,32 @@ func (c *Consumer) scanPagesOfShard(ctx context.Context, shardID, lastSeqNum str
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// loop records of page
|
// loop over records, call callback func
|
||||||
for _, r := range resp.Records {
|
for _, r := range resp.Records {
|
||||||
isScanStopped, err := c.handleRecord(shardID, r, fn)
|
select {
|
||||||
if err != nil {
|
case <-ctx.Done():
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
err := fn(r)
|
||||||
|
if err != nil && err != SkipCheckpoint {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if isScanStopped {
|
|
||||||
return nil
|
if err != SkipCheckpoint {
|
||||||
|
if err := c.checkpoint.Set(c.streamName, shardID, *r.SequenceNumber); err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.counter.Add("records", 1)
|
||||||
lastSeqNum = *r.SequenceNumber
|
lastSeqNum = *r.SequenceNumber
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if isShardClosed(resp.NextShardIterator, shardIterator) {
|
if isShardClosed(resp.NextShardIterator, shardIterator) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
shardIterator = resp.NextShardIterator
|
shardIterator = resp.NextShardIterator
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -219,32 +193,12 @@ func isShardClosed(nextShardIterator, currentShardIterator *string) bool {
|
||||||
return nextShardIterator == nil || currentShardIterator == nextShardIterator
|
return nextShardIterator == nil || currentShardIterator == nextShardIterator
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Consumer) handleRecord(shardID string, r *Record, fn func(*Record) ScanStatus) (isScanStopped bool, err error) {
|
|
||||||
status := fn(r)
|
|
||||||
|
|
||||||
if !status.SkipCheckpoint {
|
|
||||||
if err := c.checkpoint.Set(c.streamName, shardID, *r.SequenceNumber); err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := status.Error; err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
c.counter.Add("records", 1)
|
|
||||||
|
|
||||||
if status.StopScan {
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Consumer) getShardIDs(streamName string) ([]string, error) {
|
func (c *Consumer) getShardIDs(streamName string) ([]string, error) {
|
||||||
var ss []string
|
var ss []string
|
||||||
var listShardsInput = &kinesis.ListShardsInput{
|
var listShardsInput = &kinesis.ListShardsInput{
|
||||||
StreamName: aws.String(streamName),
|
StreamName: aws.String(streamName),
|
||||||
}
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
resp, err := c.client.ListShards(listShardsInput)
|
resp, err := c.client.ListShards(listShardsInput)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -265,22 +219,19 @@ func (c *Consumer) getShardIDs(streamName string) ([]string, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Consumer) getShardIterator(streamName, shardID, lastSeqNum string) (*string, error) {
|
func (c *Consumer) getShardIterator(streamName, shardID, seqNum string) (*string, error) {
|
||||||
params := &kinesis.GetShardIteratorInput{
|
params := &kinesis.GetShardIteratorInput{
|
||||||
ShardId: aws.String(shardID),
|
ShardId: aws.String(shardID),
|
||||||
StreamName: aws.String(streamName),
|
StreamName: aws.String(streamName),
|
||||||
}
|
}
|
||||||
|
|
||||||
if lastSeqNum != "" {
|
if seqNum != "" {
|
||||||
params.ShardIteratorType = aws.String(kinesis.ShardIteratorTypeAfterSequenceNumber)
|
params.ShardIteratorType = aws.String(kinesis.ShardIteratorTypeAfterSequenceNumber)
|
||||||
params.StartingSequenceNumber = aws.String(lastSeqNum)
|
params.StartingSequenceNumber = aws.String(seqNum)
|
||||||
} else {
|
} else {
|
||||||
params.ShardIteratorType = aws.String(c.initialShardIteratorType)
|
params.ShardIteratorType = aws.String(c.initialShardIteratorType)
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := c.client.GetShardIterator(params)
|
res, err := c.client.GetShardIterator(params)
|
||||||
if err != nil {
|
return res.ShardIterator, err
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return resp.ShardIterator, nil
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
153
consumer_test.go
153
consumer_test.go
|
|
@ -11,14 +11,7 @@ import (
|
||||||
"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
|
"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNew(t *testing.T) {
|
var records = []*kinesis.Record{
|
||||||
if _, err := New("myStreamName"); err != nil {
|
|
||||||
t.Fatalf("new consumer error: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConsumer_Scan(t *testing.T) {
|
|
||||||
records := []*kinesis.Record{
|
|
||||||
{
|
{
|
||||||
Data: []byte("firstData"),
|
Data: []byte("firstData"),
|
||||||
SequenceNumber: aws.String("firstSeqNum"),
|
SequenceNumber: aws.String("firstSeqNum"),
|
||||||
|
|
@ -27,7 +20,15 @@ func TestConsumer_Scan(t *testing.T) {
|
||||||
Data: []byte("lastData"),
|
Data: []byte("lastData"),
|
||||||
SequenceNumber: aws.String("lastSeqNum"),
|
SequenceNumber: aws.String("lastSeqNum"),
|
||||||
},
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNew(t *testing.T) {
|
||||||
|
if _, err := New("myStreamName"); err != nil {
|
||||||
|
t.Fatalf("new consumer error: %v", err)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestScan(t *testing.T) {
|
||||||
client := &kinesisClientMock{
|
client := &kinesisClientMock{
|
||||||
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
|
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
|
||||||
return &kinesis.GetShardIteratorOutput{
|
return &kinesis.GetShardIteratorOutput{
|
||||||
|
|
@ -64,10 +65,10 @@ func TestConsumer_Scan(t *testing.T) {
|
||||||
|
|
||||||
var resultData string
|
var resultData string
|
||||||
var fnCallCounter int
|
var fnCallCounter int
|
||||||
var fn = func(r *Record) ScanStatus {
|
var fn = func(r *Record) error {
|
||||||
fnCallCounter++
|
fnCallCounter++
|
||||||
resultData += string(r.Data)
|
resultData += string(r.Data)
|
||||||
return ScanStatus{}
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := c.Scan(context.Background(), fn); err != nil {
|
if err := c.Scan(context.Background(), fn); err != nil {
|
||||||
|
|
@ -75,11 +76,13 @@ func TestConsumer_Scan(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if resultData != "firstDatalastData" {
|
if resultData != "firstDatalastData" {
|
||||||
t.Errorf("callback error expected %s, got %s", "firstDatalastData", resultData)
|
t.Errorf("callback error expected %s, got %s", "FirstLast", resultData)
|
||||||
}
|
}
|
||||||
|
|
||||||
if fnCallCounter != 2 {
|
if fnCallCounter != 2 {
|
||||||
t.Errorf("the callback function expects %v, got %v", 2, fnCallCounter)
|
t.Errorf("the callback function expects %v, got %v", 2, fnCallCounter)
|
||||||
}
|
}
|
||||||
|
|
||||||
if val := ctr.counter; val != 2 {
|
if val := ctr.counter; val != 2 {
|
||||||
t.Errorf("counter error expected %d, got %d", 2, val)
|
t.Errorf("counter error expected %d, got %d", 2, val)
|
||||||
}
|
}
|
||||||
|
|
@ -90,7 +93,7 @@ func TestConsumer_Scan(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConsumer_Scan_NoShardsAvailable(t *testing.T) {
|
func TestScan_NoShardsAvailable(t *testing.T) {
|
||||||
client := &kinesisClientMock{
|
client := &kinesisClientMock{
|
||||||
listShardsMock: func(input *kinesis.ListShardsInput) (*kinesis.ListShardsOutput, error) {
|
listShardsMock: func(input *kinesis.ListShardsInput) (*kinesis.ListShardsOutput, error) {
|
||||||
return &kinesis.ListShardsOutput{
|
return &kinesis.ListShardsOutput{
|
||||||
|
|
@ -98,54 +101,22 @@ func TestConsumer_Scan_NoShardsAvailable(t *testing.T) {
|
||||||
}, nil
|
}, nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
var (
|
|
||||||
cp = &fakeCheckpoint{cache: map[string]string{}}
|
|
||||||
ctr = &fakeCounter{}
|
|
||||||
)
|
|
||||||
|
|
||||||
c, err := New("myStreamName",
|
var fn = func(r *Record) error {
|
||||||
WithClient(client),
|
return nil
|
||||||
WithCounter(ctr),
|
|
||||||
WithCheckpoint(cp),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("new consumer error: %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var fnCallCounter int
|
c, err := New("myStreamName", WithClient(client))
|
||||||
var fn = func(r *Record) ScanStatus {
|
if err != nil {
|
||||||
fnCallCounter++
|
t.Fatalf("new consumer error: %v", err)
|
||||||
return ScanStatus{}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := c.Scan(context.Background(), fn); err == nil {
|
if err := c.Scan(context.Background(), fn); err == nil {
|
||||||
t.Errorf("scan shard error expected not nil. got %v", err)
|
t.Errorf("scan shard error expected not nil. got %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if fnCallCounter != 0 {
|
|
||||||
t.Errorf("the callback function expects %v, got %v", 0, fnCallCounter)
|
|
||||||
}
|
|
||||||
if val := ctr.counter; val != 0 {
|
|
||||||
t.Errorf("counter error expected %d, got %d", 0, val)
|
|
||||||
}
|
|
||||||
val, err := cp.Get("myStreamName", "myShard")
|
|
||||||
if err != nil && val != "" {
|
|
||||||
t.Errorf("checkout error expected %s, got %s", "", val)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestScanShard(t *testing.T) {
|
func TestScanShard(t *testing.T) {
|
||||||
var records = []*kinesis.Record{
|
|
||||||
{
|
|
||||||
Data: []byte("firstData"),
|
|
||||||
SequenceNumber: aws.String("firstSeqNum"),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Data: []byte("lastData"),
|
|
||||||
SequenceNumber: aws.String("lastSeqNum"),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
var client = &kinesisClientMock{
|
var client = &kinesisClientMock{
|
||||||
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
|
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
|
||||||
return &kinesis.GetShardIteratorOutput{
|
return &kinesis.GetShardIteratorOutput{
|
||||||
|
|
@ -176,9 +147,9 @@ func TestScanShard(t *testing.T) {
|
||||||
|
|
||||||
// callback fn appends record data
|
// callback fn appends record data
|
||||||
var res string
|
var res string
|
||||||
var fn = func(r *Record) ScanStatus {
|
var fn = func(r *Record) error {
|
||||||
res += string(r.Data)
|
res += string(r.Data)
|
||||||
return ScanStatus{}
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// scan shard
|
// scan shard
|
||||||
|
|
@ -203,18 +174,7 @@ func TestScanShard(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestScanShard_StopScan(t *testing.T) {
|
func TestScanShard_Cancellation(t *testing.T) {
|
||||||
var records = []*kinesis.Record{
|
|
||||||
{
|
|
||||||
Data: []byte("firstData"),
|
|
||||||
SequenceNumber: aws.String("firstSeqNum"),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Data: []byte("lastData"),
|
|
||||||
SequenceNumber: aws.String("lastSeqNum"),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
var client = &kinesisClientMock{
|
var client = &kinesisClientMock{
|
||||||
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
|
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
|
||||||
return &kinesis.GetShardIteratorOutput{
|
return &kinesis.GetShardIteratorOutput{
|
||||||
|
|
@ -229,19 +189,23 @@ func TestScanShard_StopScan(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// use cancel func to signal shutdown
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
var res string
|
||||||
|
var fn = func(r *Record) error {
|
||||||
|
res += string(r.Data)
|
||||||
|
cancel() // simulate cancellation while processing first record
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
c, err := New("myStreamName", WithClient(client))
|
c, err := New("myStreamName", WithClient(client))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("new consumer error: %v", err)
|
t.Fatalf("new consumer error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// callback fn appends record data
|
err = c.ScanShard(ctx, "myShard", fn)
|
||||||
var res string
|
if err != nil {
|
||||||
var fn = func(r *Record) ScanStatus {
|
|
||||||
res += string(r.Data)
|
|
||||||
return ScanStatus{StopScan: true}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := c.ScanShard(context.Background(), "myShard", fn); err != nil {
|
|
||||||
t.Fatalf("scan shard error: %v", err)
|
t.Fatalf("scan shard error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -250,6 +214,46 @@ func TestScanShard_StopScan(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestScanShard_SkipCheckpoint(t *testing.T) {
|
||||||
|
var client = &kinesisClientMock{
|
||||||
|
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
|
||||||
|
return &kinesis.GetShardIteratorOutput{
|
||||||
|
ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
|
||||||
|
return &kinesis.GetRecordsOutput{
|
||||||
|
NextShardIterator: nil,
|
||||||
|
Records: records,
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var cp = &fakeCheckpoint{cache: map[string]string{}}
|
||||||
|
|
||||||
|
c, err := New("myStreamName", WithClient(client), WithCheckpoint(cp))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("new consumer error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var fn = func(r *Record) error {
|
||||||
|
if aws.StringValue(r.SequenceNumber) == "lastSeqNum" {
|
||||||
|
return SkipCheckpoint
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err = c.ScanShard(context.Background(), "myShard", fn)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("scan shard error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
val, err := cp.Get("myStreamName", "myShard")
|
||||||
|
if err != nil && val != "firstSeqNum" {
|
||||||
|
t.Fatalf("checkout error expected %s, got %s", "firstSeqNum", val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestScanShard_ShardIsClosed(t *testing.T) {
|
func TestScanShard_ShardIsClosed(t *testing.T) {
|
||||||
var client = &kinesisClientMock{
|
var client = &kinesisClientMock{
|
||||||
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
|
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
|
||||||
|
|
@ -270,11 +274,12 @@ func TestScanShard_ShardIsClosed(t *testing.T) {
|
||||||
t.Fatalf("new consumer error: %v", err)
|
t.Fatalf("new consumer error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var fn = func(r *Record) ScanStatus {
|
var fn = func(r *Record) error {
|
||||||
return ScanStatus{StopScan: true}
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := c.ScanShard(context.Background(), "myShard", fn); err != nil {
|
err = c.ScanShard(context.Background(), "myShard", fn)
|
||||||
|
if err != nil {
|
||||||
t.Fatalf("scan shard error: %v", err)
|
t.Fatalf("scan shard error: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -54,7 +54,7 @@ func main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
app = flag.String("app", "", "App name")
|
app = flag.String("app", "", "Consumer app name")
|
||||||
stream = flag.String("stream", "", "Stream name")
|
stream = flag.String("stream", "", "Stream name")
|
||||||
table = flag.String("table", "", "Checkpoint table name")
|
table = flag.String("table", "", "Checkpoint table name")
|
||||||
)
|
)
|
||||||
|
|
@ -103,11 +103,9 @@ func main() {
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// scan stream
|
// scan stream
|
||||||
err = c.Scan(ctx, func(r *consumer.Record) consumer.ScanStatus {
|
err = c.Scan(ctx, func(r *consumer.Record) error {
|
||||||
fmt.Println(string(r.Data))
|
fmt.Println(string(r.Data))
|
||||||
|
return nil // continue scanning
|
||||||
// continue scanning
|
|
||||||
return consumer.ScanStatus{}
|
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Log("scan error: %v", err)
|
log.Log("scan error: %v", err)
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@ import (
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
var (
|
var (
|
||||||
app = flag.String("app", "", "App name")
|
app = flag.String("app", "", "Consumer app name")
|
||||||
stream = flag.String("stream", "", "Stream name")
|
stream = flag.String("stream", "", "Stream name")
|
||||||
table = flag.String("table", "", "Table name")
|
table = flag.String("table", "", "Table name")
|
||||||
connStr = flag.String("connection", "", "Connection Str")
|
connStr = flag.String("connection", "", "Connection Str")
|
||||||
|
|
@ -53,11 +53,9 @@ func main() {
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// scan stream
|
// scan stream
|
||||||
err = c.Scan(ctx, func(r *consumer.Record) consumer.ScanStatus {
|
err = c.Scan(ctx, func(r *consumer.Record) error {
|
||||||
fmt.Println(string(r.Data))
|
fmt.Println(string(r.Data))
|
||||||
|
return nil // continue scanning
|
||||||
// continue scanning
|
|
||||||
return consumer.ScanStatus{}
|
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@ import (
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
var (
|
var (
|
||||||
app = flag.String("app", "", "App name")
|
app = flag.String("app", "", "Consumer app name")
|
||||||
stream = flag.String("stream", "", "Stream name")
|
stream = flag.String("stream", "", "Stream name")
|
||||||
)
|
)
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
@ -46,11 +46,9 @@ func main() {
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// scan stream
|
// scan stream
|
||||||
err = c.Scan(ctx, func(r *consumer.Record) consumer.ScanStatus {
|
err = c.Scan(ctx, func(r *consumer.Record) error {
|
||||||
fmt.Println(string(r.Data))
|
fmt.Println(string(r.Data))
|
||||||
|
return nil // continue scanning
|
||||||
// continue scanning
|
|
||||||
return consumer.ScanStatus{}
|
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("scan error: %v", err)
|
log.Fatalf("scan error: %v", err)
|
||||||
|
|
|
||||||
|
|
@ -25,7 +25,12 @@ func main() {
|
||||||
defer f.Close()
|
defer f.Close()
|
||||||
|
|
||||||
var records []*kinesis.PutRecordsRequestEntry
|
var records []*kinesis.PutRecordsRequestEntry
|
||||||
var client = kinesis.New(session.New())
|
|
||||||
|
sess, err := session.NewSession(aws.NewConfig())
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
var client = kinesis.New(sess)
|
||||||
|
|
||||||
// loop over file data
|
// loop over file data
|
||||||
b := bufio.NewScanner(f)
|
b := bufio.NewScanner(f)
|
||||||
|
|
|
||||||
41
options.go
Normal file
41
options.go
Normal file
|
|
@ -0,0 +1,41 @@
|
||||||
|
package consumer
|
||||||
|
|
||||||
|
import "github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
|
||||||
|
|
||||||
|
// Option is used to override defaults when creating a new Consumer
|
||||||
|
type Option func(*Consumer)
|
||||||
|
|
||||||
|
// WithCheckpoint overrides the default checkpoint
|
||||||
|
func WithCheckpoint(checkpoint Checkpoint) Option {
|
||||||
|
return func(c *Consumer) {
|
||||||
|
c.checkpoint = checkpoint
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithLogger overrides the default logger
|
||||||
|
func WithLogger(logger Logger) Option {
|
||||||
|
return func(c *Consumer) {
|
||||||
|
c.logger = logger
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithCounter overrides the default counter
|
||||||
|
func WithCounter(counter Counter) Option {
|
||||||
|
return func(c *Consumer) {
|
||||||
|
c.counter = counter
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithClient overrides the default client
|
||||||
|
func WithClient(client kinesisiface.KinesisAPI) Option {
|
||||||
|
return func(c *Consumer) {
|
||||||
|
c.client = client
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ShardIteratorType overrides the starting point for the consumer
|
||||||
|
func WithShardIteratorType(t string) Option {
|
||||||
|
return func(c *Consumer) {
|
||||||
|
c.initialShardIteratorType = t
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Reference in a new issue