Introduce ScanFunc signature and remove ScanStatus (#77)

Major changes:

```go
type ScanFunc func(r *Record) error
```

* Simplify the callback func signature by removing `ScanStatus` 
* Leverage context for cancellation 
* Add custom error `SkipCheckpoint` for special cases when we don't want to checkpoint

Minor changes:

* Use kinesis package constants for shard iterator types
* Move optional config to new file

See conversation on #75 for more details
This commit is contained in:
Harlow Ward 2019-04-07 16:29:12 -07:00 committed by GitHub
parent 24de74fd14
commit 76158d24ab
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 235 additions and 208 deletions

View file

@ -6,6 +6,14 @@ All notable changes to this project will be documented in this file.
Major changes:
* Remove the concept of `ScanStatus` to simplify the scanning interface
For more context on this change see: https://github.com/harlow/kinesis-consumer/issues/75
## v0.3.0 - 2018-12-28
Major changes:
* Remove concept of `Client` it was confusing as it wasn't a direct standin for a Kinesis client.
* Rename `ScanError` to `ScanStatus` as it's not always an error.

View file

@ -40,39 +40,62 @@ func main() {
}
// start scan
err = c.Scan(context.TODO(), func(r *consumer.Record) consumer.ScanStatus {
err = c.Scan(context.TODO(), func(r *consumer.Record) error {
fmt.Println(string(r.Data))
return consumer.ScanStatus{
StopScan: false, // true to stop scan
SkipCheckpoint: false, // true to skip checkpoint
}
return nil // continue scanning
})
if err != nil {
log.Fatalf("scan error: %v", err)
}
// Note: If you need to aggregate based on a specific shard the `ScanShard`
// method should be leverged instead.
// Note: If you need to aggregate based on a specific shard
// the `ScanShard` function should be used instead.
}
```
## Scan status
## ScanFunc
The scan func returns a `consumer.ScanStatus` the struct allows some basic flow control.
ScanFunc is the type of the function called for each message read
from the stream. The record argument contains the original record
returned from the AWS Kinesis library.
```go
type ScanFunc func(r *Record) error
```
If an error is returned, scanning stops. The sole exception is when the
function returns the special value SkipCheckpoint.
```go
// continue scanning
return consumer.ScanStatus{}
return nil
// continue scanning, skip saving checkpoint
return consumer.ScanStatus{SkipCheckpoint: true}
// stop scanning, return nil
return consumer.ScanStatus{StopScan: true}
// continue scanning, skip checkpoint
return consumer.SkipCheckpoint
// stop scanning, return error
return consumer.ScanStatus{Error: err}
return errors.New("my error, exit all scans")
```
Use context cancel to signal the scan to exit without error. For example, if we wanted to gracefully exit the scan on interrupt:
```go
// trap SIGINT, wait to trigger shutdown
signals := make(chan os.Signal, 1)
signal.Notify(signals, os.Interrupt)
// context with cancel
ctx, cancel := context.WithCancel(context.Background())
go func() {
<-signals
cancel() // call cancellation
}()
err := c.Scan(ctx, func(r *consumer.Record) error {
fmt.Println(string(r.Data))
return nil // continue scanning
})
```
## Checkpoint
@ -182,7 +205,7 @@ Override the Kinesis client if there is any special config needed:
```go
// client
client := kinesis.New(session.New(aws.NewConfig()))
client := kinesis.New(session.Must(session.NewSession(aws.NewConfig())))
// consumer
c, err := consumer.New(streamName, consumer.WithClient(client))

View file

@ -2,6 +2,7 @@ package consumer
import (
"context"
"errors"
"fmt"
"io/ioutil"
"log"
@ -16,52 +17,6 @@ import (
// Record is an alias of record returned from kinesis library
type Record = kinesis.Record
// Option is used to override defaults when creating a new Consumer
type Option func(*Consumer)
// WithCheckpoint overrides the default checkpoint
func WithCheckpoint(checkpoint Checkpoint) Option {
return func(c *Consumer) {
c.checkpoint = checkpoint
}
}
// WithLogger overrides the default logger
func WithLogger(logger Logger) Option {
return func(c *Consumer) {
c.logger = logger
}
}
// WithCounter overrides the default counter
func WithCounter(counter Counter) Option {
return func(c *Consumer) {
c.counter = counter
}
}
// WithClient overrides the default client
func WithClient(client kinesisiface.KinesisAPI) Option {
return func(c *Consumer) {
c.client = client
}
}
// ShardIteratorType overrides the starting point for the consumer
func WithShardIteratorType(t string) Option {
return func(c *Consumer) {
c.initialShardIteratorType = t
}
}
// ScanStatus signals the consumer if we should continue scanning for next record
// and whether to checkpoint.
type ScanStatus struct {
Error error
StopScan bool
SkipCheckpoint bool
}
// New creates a kinesis consumer with default settings. Use Option to override
// any of the optional attributes.
func New(streamName string, opts ...Option) (*Consumer, error) {
@ -107,9 +62,23 @@ type Consumer struct {
counter Counter
}
// Scan scans each of the shards of the stream, calls the callback
// func with each of the kinesis records.
func (c *Consumer) Scan(ctx context.Context, fn func(*Record) ScanStatus) error {
// ScanFunc is the type of the function called for each message read
// from the stream. The record argument contains the original record
// returned from the AWS Kinesis library.
//
// The return value controls the scan:
//   - nil: the record is checkpointed and scanning continues.
//   - SkipCheckpoint: scanning continues without saving a checkpoint
//     for the record.
//   - any other error: scanning stops and the error is returned.
type ScanFunc func(*Record) error
// SkipCheckpoint is used as a return value from ScanFuncs to indicate that
// the checkpoint for the current record should be skipped. Scanning continues
// with the next record. It is not returned as an error by any function.
var SkipCheckpoint = errors.New("skip checkpoint")
// Scan launches a goroutine to process each of the shards in the stream. The ScanFunc
// is passed through to each of the goroutines and called with each message pulled from
// the stream.
func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
@ -153,14 +122,10 @@ func (c *Consumer) Scan(ctx context.Context, fn func(*Record) ScanStatus) error
return <-errc
}
// ScanShard loops over records on a specific shard, calls the callback func
// for each record and checkpoints the progress of scan.
func (c *Consumer) ScanShard(
ctx context.Context,
shardID string,
fn func(*Record) ScanStatus,
) error {
// get checkpoint
// ScanShard loops over records on a specific shard, calls the ScanFunc callback
// func for each record and checkpoints the progress of scan.
func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error {
// get last seq number from checkpoint
lastSeqNum, err := c.checkpoint.Get(c.streamName, shardID)
if err != nil {
return fmt.Errorf("get checkpoint error: %v", err)
@ -174,10 +139,6 @@ func (c *Consumer) ScanShard(
c.logger.Log("scanning", shardID, lastSeqNum)
return c.scanPagesOfShard(ctx, shardID, lastSeqNum, shardIterator, fn)
}
func (c *Consumer) scanPagesOfShard(ctx context.Context, shardID, lastSeqNum string, shardIterator *string, fn func(*Record) ScanStatus) error {
for {
select {
case <-ctx.Done():
@ -187,6 +148,8 @@ func (c *Consumer) scanPagesOfShard(ctx context.Context, shardID, lastSeqNum str
ShardIterator: shardIterator,
})
// often we can recover from GetRecords error by getting a
// new shard iterator, else return error
if err != nil {
shardIterator, err = c.getShardIterator(c.streamName, shardID, lastSeqNum)
if err != nil {
@ -195,21 +158,32 @@ func (c *Consumer) scanPagesOfShard(ctx context.Context, shardID, lastSeqNum str
continue
}
// loop records of page
// loop over records, call callback func
for _, r := range resp.Records {
isScanStopped, err := c.handleRecord(shardID, r, fn)
if err != nil {
return err
}
if isScanStopped {
select {
case <-ctx.Done():
return nil
default:
err := fn(r)
if err != nil && err != SkipCheckpoint {
return err
}
if err != SkipCheckpoint {
if err := c.checkpoint.Set(c.streamName, shardID, *r.SequenceNumber); err != nil {
return err
}
}
c.counter.Add("records", 1)
lastSeqNum = *r.SequenceNumber
}
lastSeqNum = *r.SequenceNumber
}
if isShardClosed(resp.NextShardIterator, shardIterator) {
return nil
}
shardIterator = resp.NextShardIterator
}
}
@ -219,32 +193,12 @@ func isShardClosed(nextShardIterator, currentShardIterator *string) bool {
return nextShardIterator == nil || currentShardIterator == nextShardIterator
}
func (c *Consumer) handleRecord(shardID string, r *Record, fn func(*Record) ScanStatus) (isScanStopped bool, err error) {
status := fn(r)
if !status.SkipCheckpoint {
if err := c.checkpoint.Set(c.streamName, shardID, *r.SequenceNumber); err != nil {
return false, err
}
}
if err := status.Error; err != nil {
return false, err
}
c.counter.Add("records", 1)
if status.StopScan {
return true, nil
}
return false, nil
}
func (c *Consumer) getShardIDs(streamName string) ([]string, error) {
var ss []string
var listShardsInput = &kinesis.ListShardsInput{
StreamName: aws.String(streamName),
}
for {
resp, err := c.client.ListShards(listShardsInput)
if err != nil {
@ -265,22 +219,19 @@ func (c *Consumer) getShardIDs(streamName string) ([]string, error) {
}
}
func (c *Consumer) getShardIterator(streamName, shardID, lastSeqNum string) (*string, error) {
func (c *Consumer) getShardIterator(streamName, shardID, seqNum string) (*string, error) {
params := &kinesis.GetShardIteratorInput{
ShardId: aws.String(shardID),
StreamName: aws.String(streamName),
}
if lastSeqNum != "" {
if seqNum != "" {
params.ShardIteratorType = aws.String(kinesis.ShardIteratorTypeAfterSequenceNumber)
params.StartingSequenceNumber = aws.String(lastSeqNum)
params.StartingSequenceNumber = aws.String(seqNum)
} else {
params.ShardIteratorType = aws.String(c.initialShardIteratorType)
}
resp, err := c.client.GetShardIterator(params)
if err != nil {
return nil, err
}
return resp.ShardIterator, nil
res, err := c.client.GetShardIterator(params)
return res.ShardIterator, err
}

View file

@ -11,23 +11,24 @@ import (
"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
)
// records is the two-record fixture shared by the tests in this file; a test
// can return it from its mocked GetRecords call as a single page of results.
var records = []*kinesis.Record{
{
Data: []byte("firstData"),
SequenceNumber: aws.String("firstSeqNum"),
},
{
Data: []byte("lastData"),
SequenceNumber: aws.String("lastSeqNum"),
},
}
// TestNew checks that a consumer can be constructed with only a stream name.
func TestNew(t *testing.T) {
	_, err := New("myStreamName")
	if err != nil {
		t.Fatalf("new consumer error: %v", err)
	}
}
func TestConsumer_Scan(t *testing.T) {
records := []*kinesis.Record{
{
Data: []byte("firstData"),
SequenceNumber: aws.String("firstSeqNum"),
},
{
Data: []byte("lastData"),
SequenceNumber: aws.String("lastSeqNum"),
},
}
func TestScan(t *testing.T) {
client := &kinesisClientMock{
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
return &kinesis.GetShardIteratorOutput{
@ -64,10 +65,10 @@ func TestConsumer_Scan(t *testing.T) {
var resultData string
var fnCallCounter int
var fn = func(r *Record) ScanStatus {
var fn = func(r *Record) error {
fnCallCounter++
resultData += string(r.Data)
return ScanStatus{}
return nil
}
if err := c.Scan(context.Background(), fn); err != nil {
@ -75,11 +76,13 @@ func TestConsumer_Scan(t *testing.T) {
}
if resultData != "firstDatalastData" {
t.Errorf("callback error expected %s, got %s", "firstDatalastData", resultData)
t.Errorf("callback error expected %s, got %s", "FirstLast", resultData)
}
if fnCallCounter != 2 {
t.Errorf("the callback function expects %v, got %v", 2, fnCallCounter)
}
if val := ctr.counter; val != 2 {
t.Errorf("counter error expected %d, got %d", 2, val)
}
@ -90,7 +93,7 @@ func TestConsumer_Scan(t *testing.T) {
}
}
func TestConsumer_Scan_NoShardsAvailable(t *testing.T) {
func TestScan_NoShardsAvailable(t *testing.T) {
client := &kinesisClientMock{
listShardsMock: func(input *kinesis.ListShardsInput) (*kinesis.ListShardsOutput, error) {
return &kinesis.ListShardsOutput{
@ -98,54 +101,22 @@ func TestConsumer_Scan_NoShardsAvailable(t *testing.T) {
}, nil
},
}
var (
cp = &fakeCheckpoint{cache: map[string]string{}}
ctr = &fakeCounter{}
)
c, err := New("myStreamName",
WithClient(client),
WithCounter(ctr),
WithCheckpoint(cp),
)
if err != nil {
t.Fatalf("new consumer error: %v", err)
var fn = func(r *Record) error {
return nil
}
var fnCallCounter int
var fn = func(r *Record) ScanStatus {
fnCallCounter++
return ScanStatus{}
c, err := New("myStreamName", WithClient(client))
if err != nil {
t.Fatalf("new consumer error: %v", err)
}
if err := c.Scan(context.Background(), fn); err == nil {
t.Errorf("scan shard error expected not nil. got %v", err)
}
if fnCallCounter != 0 {
t.Errorf("the callback function expects %v, got %v", 0, fnCallCounter)
}
if val := ctr.counter; val != 0 {
t.Errorf("counter error expected %d, got %d", 0, val)
}
val, err := cp.Get("myStreamName", "myShard")
if err != nil && val != "" {
t.Errorf("checkout error expected %s, got %s", "", val)
}
}
func TestScanShard(t *testing.T) {
var records = []*kinesis.Record{
{
Data: []byte("firstData"),
SequenceNumber: aws.String("firstSeqNum"),
},
{
Data: []byte("lastData"),
SequenceNumber: aws.String("lastSeqNum"),
},
}
var client = &kinesisClientMock{
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
return &kinesis.GetShardIteratorOutput{
@ -176,9 +147,9 @@ func TestScanShard(t *testing.T) {
// callback fn appends record data
var res string
var fn = func(r *Record) ScanStatus {
var fn = func(r *Record) error {
res += string(r.Data)
return ScanStatus{}
return nil
}
// scan shard
@ -203,18 +174,7 @@ func TestScanShard(t *testing.T) {
}
}
func TestScanShard_StopScan(t *testing.T) {
var records = []*kinesis.Record{
{
Data: []byte("firstData"),
SequenceNumber: aws.String("firstSeqNum"),
},
{
Data: []byte("lastData"),
SequenceNumber: aws.String("lastSeqNum"),
},
}
func TestScanShard_Cancellation(t *testing.T) {
var client = &kinesisClientMock{
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
return &kinesis.GetShardIteratorOutput{
@ -229,19 +189,23 @@ func TestScanShard_StopScan(t *testing.T) {
},
}
// use cancel func to signal shutdown
ctx, cancel := context.WithCancel(context.Background())
var res string
var fn = func(r *Record) error {
res += string(r.Data)
cancel() // simulate cancellation while processing first record
return nil
}
c, err := New("myStreamName", WithClient(client))
if err != nil {
t.Fatalf("new consumer error: %v", err)
}
// callback fn appends record data
var res string
var fn = func(r *Record) ScanStatus {
res += string(r.Data)
return ScanStatus{StopScan: true}
}
if err := c.ScanShard(context.Background(), "myShard", fn); err != nil {
err = c.ScanShard(ctx, "myShard", fn)
if err != nil {
t.Fatalf("scan shard error: %v", err)
}
@ -250,6 +214,46 @@ func TestScanShard_StopScan(t *testing.T) {
}
}
// TestScanShard_SkipCheckpoint verifies that returning SkipCheckpoint from the
// ScanFunc does not abort the scan and does not save a checkpoint for that
// record: after scanning both fixture records, the stored checkpoint must
// still be the sequence number of the first record.
func TestScanShard_SkipCheckpoint(t *testing.T) {
	var client = &kinesisClientMock{
		getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
			return &kinesis.GetShardIteratorOutput{
				ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
			}, nil
		},
		getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
			return &kinesis.GetRecordsOutput{
				// a nil next iterator ends the scan after this single page
				NextShardIterator: nil,
				Records:           records,
			}, nil
		},
	}

	var cp = &fakeCheckpoint{cache: map[string]string{}}

	c, err := New("myStreamName", WithClient(client), WithCheckpoint(cp))
	if err != nil {
		t.Fatalf("new consumer error: %v", err)
	}

	// checkpoint the first record, skip the checkpoint for the last one
	var fn = func(r *Record) error {
		if aws.StringValue(r.SequenceNumber) == "lastSeqNum" {
			return SkipCheckpoint
		}
		return nil
	}

	err = c.ScanShard(context.Background(), "myShard", fn)
	if err != nil {
		t.Fatalf("scan shard error: %v", err)
	}

	// The error and the value are checked independently; combining them with
	// && would let a wrong checkpoint pass whenever Get succeeds.
	val, err := cp.Get("myStreamName", "myShard")
	if err != nil {
		t.Fatalf("get checkpoint error: %v", err)
	}
	if val != "firstSeqNum" {
		t.Fatalf("checkpoint error expected %s, got %s", "firstSeqNum", val)
	}
}
func TestScanShard_ShardIsClosed(t *testing.T) {
var client = &kinesisClientMock{
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
@ -270,11 +274,12 @@ func TestScanShard_ShardIsClosed(t *testing.T) {
t.Fatalf("new consumer error: %v", err)
}
var fn = func(r *Record) ScanStatus {
return ScanStatus{StopScan: true}
var fn = func(r *Record) error {
return nil
}
if err := c.ScanShard(context.Background(), "myShard", fn); err != nil {
err = c.ScanShard(context.Background(), "myShard", fn)
if err != nil {
t.Fatalf("scan shard error: %v", err)
}
}

View file

@ -54,7 +54,7 @@ func main() {
}
var (
app = flag.String("app", "", "App name")
app = flag.String("app", "", "Consumer app name")
stream = flag.String("stream", "", "Stream name")
table = flag.String("table", "", "Checkpoint table name")
)
@ -103,11 +103,9 @@ func main() {
}()
// scan stream
err = c.Scan(ctx, func(r *consumer.Record) consumer.ScanStatus {
err = c.Scan(ctx, func(r *consumer.Record) error {
fmt.Println(string(r.Data))
// continue scanning
return consumer.ScanStatus{}
return nil // continue scanning
})
if err != nil {
log.Log("scan error: %v", err)

View file

@ -15,7 +15,7 @@ import (
func main() {
var (
app = flag.String("app", "", "App name")
app = flag.String("app", "", "Consumer app name")
stream = flag.String("stream", "", "Stream name")
table = flag.String("table", "", "Table name")
connStr = flag.String("connection", "", "Connection Str")
@ -53,11 +53,9 @@ func main() {
}()
// scan stream
err = c.Scan(ctx, func(r *consumer.Record) consumer.ScanStatus {
err = c.Scan(ctx, func(r *consumer.Record) error {
fmt.Println(string(r.Data))
// continue scanning
return consumer.ScanStatus{}
return nil // continue scanning
})
if err != nil {

View file

@ -14,7 +14,7 @@ import (
func main() {
var (
app = flag.String("app", "", "App name")
app = flag.String("app", "", "Consumer app name")
stream = flag.String("stream", "", "Stream name")
)
flag.Parse()
@ -46,11 +46,9 @@ func main() {
}()
// scan stream
err = c.Scan(ctx, func(r *consumer.Record) consumer.ScanStatus {
err = c.Scan(ctx, func(r *consumer.Record) error {
fmt.Println(string(r.Data))
// continue scanning
return consumer.ScanStatus{}
return nil // continue scanning
})
if err != nil {
log.Fatalf("scan error: %v", err)

View file

@ -25,7 +25,12 @@ func main() {
defer f.Close()
var records []*kinesis.PutRecordsRequestEntry
var client = kinesis.New(session.New())
sess, err := session.NewSession(aws.NewConfig())
if err != nil {
log.Fatal(err)
}
var client = kinesis.New(sess)
// loop over file data
b := bufio.NewScanner(f)

41
options.go Normal file
View file

@ -0,0 +1,41 @@
package consumer
import "github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
// Option is a function that modifies a Consumer. Options are passed to New
// to override the defaults when creating a new Consumer.
type Option func(*Consumer)
// WithCheckpoint returns an Option that replaces the consumer's default
// checkpoint with the one supplied.
func WithCheckpoint(checkpoint Checkpoint) Option {
	setCheckpoint := func(consumer *Consumer) {
		consumer.checkpoint = checkpoint
	}
	return setCheckpoint
}
// WithLogger returns an Option that replaces the consumer's default logger
// with the one supplied.
func WithLogger(logger Logger) Option {
	setLogger := func(consumer *Consumer) {
		consumer.logger = logger
	}
	return setLogger
}
// WithCounter returns an Option that replaces the consumer's default counter
// with the one supplied.
func WithCounter(counter Counter) Option {
	setCounter := func(consumer *Consumer) {
		consumer.counter = counter
	}
	return setCounter
}
// WithClient returns an Option that replaces the consumer's default Kinesis
// client with the one supplied.
func WithClient(client kinesisiface.KinesisAPI) Option {
	setClient := func(consumer *Consumer) {
		consumer.client = client
	}
	return setClient
}
// WithShardIteratorType returns an Option that overrides the starting point
// for the consumer by setting its initial shard iterator type to t.
func WithShardIteratorType(t string) Option {
	setIteratorType := func(consumer *Consumer) {
		consumer.initialShardIteratorType = t
	}
	return setIteratorType
}