Control flow with custom errors types

Major changes:

* Remove the concept of `ScanStatus` in favor of custom errors

Minor changes:

* Move optional config to new file

https://github.com/harlow/kinesis-consumer/issues/75
This commit is contained in:
Harlow Ward 2018-12-29 20:54:39 -08:00
parent 8fd7675ea4
commit 7d5601fbde
9 changed files with 142 additions and 155 deletions

View file

@ -6,6 +6,14 @@ All notable changes to this project will be documented in this file.
Major changes: Major changes:
* Remove the concept of `ScanStatus` to simplify the scanning interface
For more context on this change see: https://github.com/harlow/kinesis-consumer/issues/75
## v0.3.0 - 2018-12-28
Major changes:
* Remove concept of `Client` it was confusing as it wasn't a direct standin for a Kinesis client. * Remove concept of `Client` it was confusing as it wasn't a direct standin for a Kinesis client.
* Rename `ScanError` to `ScanStatus` as it's not always an error. * Rename `ScanError` to `ScanStatus` as it's not always an error.

View file

@ -40,39 +40,41 @@ func main() {
} }
// start scan // start scan
err = c.Scan(context.TODO(), func(r *consumer.Record) consumer.ScanStatus { err = c.Scan(context.TODO(), func(r *consumer.Record) error {
fmt.Println(string(r.Data)) fmt.Println(string(r.Data))
return nil // continue scanning
return consumer.ScanStatus{
StopScan: false, // true to stop scan
SkipCheckpoint: false, // true to skip checkpoint
}
}) })
if err != nil { if err != nil {
log.Fatalf("scan error: %v", err) log.Fatalf("scan error: %v", err)
} }
// Note: If you need to aggregate based on a specific shard the `ScanShard` // Note: If you need to aggregate based on a specific shard
// method should be leverged instead. // the `ScanShard` function should be used instead.
} }
``` ```
## Scan status ## ScanFunc
The scan func returns a `consumer.ScanStatus` the struct allows some basic flow control. The `ScanFunc` receives a Kinesis Record and returns an `error`
```go
type ScanFunc func(*Record) error
```
Return `nil` to continue scanning, or choose from the custom error types for additional flow control.
```go ```go
// continue scanning // continue scanning
return consumer.ScanStatus{} return nil
// continue scanning, skip saving checkpoint // continue scanning, skip checkpoint
return consumer.ScanStatus{SkipCheckpoint: true} return consumer.SkipCheckpoint
// stop scanning, return nil // stop scanning, return nil
return consumer.ScanStatus{StopScan: true} return consumer.StopScan
// stop scanning, return error // stop scanning, return error
return consumer.ScanStatus{Error: err} return errors.New("my error, exit all scans")
``` ```
## Checkpoint ## Checkpoint
@ -182,7 +184,7 @@ Override the Kinesis client if there is any special config needed:
```go ```go
// client // client
client := kinesis.New(session.New(aws.NewConfig())) client := kinesis.New(session.NewSession(aws.NewConfig()))
// consumer // consumer
c, err := consumer.New(streamName, consumer.WithClient(client)) c, err := consumer.New(streamName, consumer.WithClient(client))

View file

@ -2,6 +2,7 @@ package consumer
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"log" "log"
@ -16,52 +17,6 @@ import (
// Record is an alias of record returned from kinesis library // Record is an alias of record returned from kinesis library
type Record = kinesis.Record type Record = kinesis.Record
// Option is used to override defaults when creating a new Consumer
type Option func(*Consumer)
// WithCheckpoint overrides the default checkpoint
func WithCheckpoint(checkpoint Checkpoint) Option {
return func(c *Consumer) {
c.checkpoint = checkpoint
}
}
// WithLogger overrides the default logger
func WithLogger(logger Logger) Option {
return func(c *Consumer) {
c.logger = logger
}
}
// WithCounter overrides the default counter
func WithCounter(counter Counter) Option {
return func(c *Consumer) {
c.counter = counter
}
}
// WithClient overrides the default client
func WithClient(client kinesisiface.KinesisAPI) Option {
return func(c *Consumer) {
c.client = client
}
}
// ShardIteratorType overrides the starting point for the consumer
func WithShardIteratorType(t string) Option {
return func(c *Consumer) {
c.initialShardIteratorType = t
}
}
// ScanStatus signals the consumer if we should continue scanning for next record
// and whether to checkpoint.
type ScanStatus struct {
Error error
StopScan bool
SkipCheckpoint bool
}
// New creates a kinesis consumer with default settings. Use Option to override // New creates a kinesis consumer with default settings. Use Option to override
// any of the optional attributes. // any of the optional attributes.
func New(streamName string, opts ...Option) (*Consumer, error) { func New(streamName string, opts ...Option) (*Consumer, error) {
@ -107,9 +62,28 @@ type Consumer struct {
counter Counter counter Counter
} }
// Scan scans each of the shards of the stream, calls the callback // ScanFunc is the type of the function called for each message read
// func with each of the kinesis records. // from the stream. The record argument contains the original record
func (c *Consumer) Scan(ctx context.Context, fn func(*Record) ScanStatus) error { // returned from the AWS Kinesis library.
//
// If an error is returned, scanning stops. The sole exception is when the
// function returns the special value SkipCheckpoint or StopScan.
type ScanFunc func(*Record) error
// SkipCheckpoint is used as a return value from ScanFuncs to indicate that
// the current checkpoint should be skipped skipped. It is not returned
// as an error by any function.
var SkipCheckpoint = errors.New("skip checkpoint")
// StopScan is used as a return value from ScanFuncs to indicate that
// the we should stop scanning the current shard. It is not returned
// as an error by any function.
var StopScan = errors.New("stop scan")
// Scan launches a goroutine to process each of the shards in the stream. The ScanFunc
// is passed through to each of the goroutines and called with each message pulled from
// the stream.
func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
@ -153,14 +127,10 @@ func (c *Consumer) Scan(ctx context.Context, fn func(*Record) ScanStatus) error
return <-errc return <-errc
} }
// ScanShard loops over records on a specific shard, calls the callback func // ScanShard loops over records on a specific shard, calls the ScanFunc callback
// for each record and checkpoints the progress of scan. // func for each record and checkpoints the progress of scan.
func (c *Consumer) ScanShard( func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error {
ctx context.Context, // get last seq number from checkpoint
shardID string,
fn func(*Record) ScanStatus,
) error {
// get checkpoint
lastSeqNum, err := c.checkpoint.Get(c.streamName, shardID) lastSeqNum, err := c.checkpoint.Get(c.streamName, shardID)
if err != nil { if err != nil {
return fmt.Errorf("get checkpoint error: %v", err) return fmt.Errorf("get checkpoint error: %v", err)
@ -174,10 +144,7 @@ func (c *Consumer) ScanShard(
c.logger.Log("scanning", shardID, lastSeqNum) c.logger.Log("scanning", shardID, lastSeqNum)
return c.scanPagesOfShard(ctx, shardID, lastSeqNum, shardIterator, fn) // loop until
}
func (c *Consumer) scanPagesOfShard(ctx context.Context, shardID, lastSeqNum string, shardIterator *string, fn func(*Record) ScanStatus) error {
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
@ -187,6 +154,8 @@ func (c *Consumer) scanPagesOfShard(ctx context.Context, shardID, lastSeqNum str
ShardIterator: shardIterator, ShardIterator: shardIterator,
}) })
// often we can recover from GetRecords error by getting a
// new shard iterator, else return error
if err != nil { if err != nil {
shardIterator, err = c.getShardIterator(c.streamName, shardID, lastSeqNum) shardIterator, err = c.getShardIterator(c.streamName, shardID, lastSeqNum)
if err != nil { if err != nil {
@ -195,21 +164,31 @@ func (c *Consumer) scanPagesOfShard(ctx context.Context, shardID, lastSeqNum str
continue continue
} }
// loop records of page // call callback func with each record from response
for _, r := range resp.Records { for _, r := range resp.Records {
isScanStopped, err := c.handleRecord(shardID, r, fn) lastSeqNum = *r.SequenceNumber
if err != nil { c.counter.Add("records", 1)
if err := fn(r); err != nil {
switch err {
case StopScan:
return nil
case SkipCheckpoint:
continue
default:
return err
}
}
if err := c.checkpoint.Set(c.streamName, shardID, *r.SequenceNumber); err != nil {
return err return err
} }
if isScanStopped {
return nil
}
lastSeqNum = *r.SequenceNumber
} }
if isShardClosed(resp.NextShardIterator, shardIterator) { if isShardClosed(resp.NextShardIterator, shardIterator) {
return nil return nil
} }
shardIterator = resp.NextShardIterator shardIterator = resp.NextShardIterator
} }
} }
@ -219,27 +198,6 @@ func isShardClosed(nextShardIterator, currentShardIterator *string) bool {
return nextShardIterator == nil || currentShardIterator == nextShardIterator return nextShardIterator == nil || currentShardIterator == nextShardIterator
} }
func (c *Consumer) handleRecord(shardID string, r *Record, fn func(*Record) ScanStatus) (isScanStopped bool, err error) {
status := fn(r)
if !status.SkipCheckpoint {
if err := c.checkpoint.Set(c.streamName, shardID, *r.SequenceNumber); err != nil {
return false, err
}
}
if err := status.Error; err != nil {
return false, err
}
c.counter.Add("records", 1)
if status.StopScan {
return true, nil
}
return false, nil
}
func (c *Consumer) getShardIDs(streamName string) ([]string, error) { func (c *Consumer) getShardIDs(streamName string) ([]string, error) {
var ss []string var ss []string
var listShardsInput = &kinesis.ListShardsInput{ var listShardsInput = &kinesis.ListShardsInput{

View file

@ -64,10 +64,10 @@ func TestConsumer_Scan(t *testing.T) {
var resultData string var resultData string
var fnCallCounter int var fnCallCounter int
var fn = func(r *Record) ScanStatus { var fn = func(r *Record) error {
fnCallCounter++ fnCallCounter++
resultData += string(r.Data) resultData += string(r.Data)
return ScanStatus{} return nil
} }
if err := c.Scan(context.Background(), fn); err != nil { if err := c.Scan(context.Background(), fn); err != nil {
@ -98,40 +98,19 @@ func TestConsumer_Scan_NoShardsAvailable(t *testing.T) {
}, nil }, nil
}, },
} }
var (
cp = &fakeCheckpoint{cache: map[string]string{}}
ctr = &fakeCounter{}
)
c, err := New("myStreamName", var fn = func(r *Record) error {
WithClient(client), return nil
WithCounter(ctr), }
WithCheckpoint(cp),
) c, err := New("myStreamName", WithClient(client))
if err != nil { if err != nil {
t.Fatalf("new consumer error: %v", err) t.Fatalf("new consumer error: %v", err)
} }
var fnCallCounter int
var fn = func(r *Record) ScanStatus {
fnCallCounter++
return ScanStatus{}
}
if err := c.Scan(context.Background(), fn); err == nil { if err := c.Scan(context.Background(), fn); err == nil {
t.Errorf("scan shard error expected not nil. got %v", err) t.Errorf("scan shard error expected not nil. got %v", err)
} }
if fnCallCounter != 0 {
t.Errorf("the callback function expects %v, got %v", 0, fnCallCounter)
}
if val := ctr.counter; val != 0 {
t.Errorf("counter error expected %d, got %d", 0, val)
}
val, err := cp.Get("myStreamName", "myShard")
if err != nil && val != "" {
t.Errorf("checkout error expected %s, got %s", "", val)
}
} }
func TestScanShard(t *testing.T) { func TestScanShard(t *testing.T) {
@ -176,9 +155,9 @@ func TestScanShard(t *testing.T) {
// callback fn appends record data // callback fn appends record data
var res string var res string
var fn = func(r *Record) ScanStatus { var fn = func(r *Record) error {
res += string(r.Data) res += string(r.Data)
return ScanStatus{} return nil
} }
// scan shard // scan shard
@ -236,9 +215,9 @@ func TestScanShard_StopScan(t *testing.T) {
// callback fn appends record data // callback fn appends record data
var res string var res string
var fn = func(r *Record) ScanStatus { var fn = func(r *Record) error {
res += string(r.Data) res += string(r.Data)
return ScanStatus{StopScan: true} return StopScan
} }
if err := c.ScanShard(context.Background(), "myShard", fn); err != nil { if err := c.ScanShard(context.Background(), "myShard", fn); err != nil {
@ -270,8 +249,8 @@ func TestScanShard_ShardIsClosed(t *testing.T) {
t.Fatalf("new consumer error: %v", err) t.Fatalf("new consumer error: %v", err)
} }
var fn = func(r *Record) ScanStatus { var fn = func(r *Record) error {
return ScanStatus{StopScan: true} return StopScan
} }
if err := c.ScanShard(context.Background(), "myShard", fn); err != nil { if err := c.ScanShard(context.Background(), "myShard", fn); err != nil {

View file

@ -54,7 +54,7 @@ func main() {
} }
var ( var (
app = flag.String("app", "", "App name") app = flag.String("app", "", "Consumer app name")
stream = flag.String("stream", "", "Stream name") stream = flag.String("stream", "", "Stream name")
table = flag.String("table", "", "Checkpoint table name") table = flag.String("table", "", "Checkpoint table name")
) )
@ -103,11 +103,9 @@ func main() {
}() }()
// scan stream // scan stream
err = c.Scan(ctx, func(r *consumer.Record) consumer.ScanStatus { err = c.Scan(ctx, func(r *consumer.Record) error {
fmt.Println(string(r.Data)) fmt.Println(string(r.Data))
return nil // continue scanning
// continue scanning
return consumer.ScanStatus{}
}) })
if err != nil { if err != nil {
log.Log("scan error: %v", err) log.Log("scan error: %v", err)

View file

@ -15,7 +15,7 @@ import (
func main() { func main() {
var ( var (
app = flag.String("app", "", "App name") app = flag.String("app", "", "Consumer app name")
stream = flag.String("stream", "", "Stream name") stream = flag.String("stream", "", "Stream name")
table = flag.String("table", "", "Table name") table = flag.String("table", "", "Table name")
connStr = flag.String("connection", "", "Connection Str") connStr = flag.String("connection", "", "Connection Str")
@ -53,11 +53,9 @@ func main() {
}() }()
// scan stream // scan stream
err = c.Scan(ctx, func(r *consumer.Record) consumer.ScanStatus { err = c.Scan(ctx, func(r *consumer.Record) error {
fmt.Println(string(r.Data)) fmt.Println(string(r.Data))
return nil // continue scanning
// continue scanning
return consumer.ScanStatus{}
}) })
if err != nil { if err != nil {

View file

@ -14,7 +14,7 @@ import (
func main() { func main() {
var ( var (
app = flag.String("app", "", "App name") app = flag.String("app", "", "Consumer app name")
stream = flag.String("stream", "", "Stream name") stream = flag.String("stream", "", "Stream name")
) )
flag.Parse() flag.Parse()
@ -46,11 +46,9 @@ func main() {
}() }()
// scan stream // scan stream
err = c.Scan(ctx, func(r *consumer.Record) consumer.ScanStatus { err = c.Scan(ctx, func(r *consumer.Record) error {
fmt.Println(string(r.Data)) fmt.Println(string(r.Data))
return nil // continue scanning
// continue scanning
return consumer.ScanStatus{}
}) })
if err != nil { if err != nil {
log.Fatalf("scan error: %v", err) log.Fatalf("scan error: %v", err)

View file

@ -25,7 +25,12 @@ func main() {
defer f.Close() defer f.Close()
var records []*kinesis.PutRecordsRequestEntry var records []*kinesis.PutRecordsRequestEntry
var client = kinesis.New(session.New())
sess, err := session.NewSession(aws.NewConfig())
if err != nil {
log.Fatal(err)
}
var client = kinesis.New(sess)
// loop over file data // loop over file data
b := bufio.NewScanner(f) b := bufio.NewScanner(f)

41
options.go Normal file
View file

@ -0,0 +1,41 @@
package consumer
import "github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
// Option is used to override defaults when creating a new Consumer
type Option func(*Consumer)
// WithCheckpoint overrides the default checkpoint
func WithCheckpoint(checkpoint Checkpoint) Option {
return func(c *Consumer) {
c.checkpoint = checkpoint
}
}
// WithLogger overrides the default logger
func WithLogger(logger Logger) Option {
return func(c *Consumer) {
c.logger = logger
}
}
// WithCounter overrides the default counter
func WithCounter(counter Counter) Option {
return func(c *Consumer) {
c.counter = counter
}
}
// WithClient overrides the default client
func WithClient(client kinesisiface.KinesisAPI) Option {
return func(c *Consumer) {
c.client = client
}
}
// ShardIteratorType overrides the starting point for the consumer
func WithShardIteratorType(t string) Option {
return func(c *Consumer) {
c.initialShardIteratorType = t
}
}