Control flow with custom errors types
Major changes: * Remove the concept of `ScanStatus` in favor of custom errors Minor changes: * Move optional config to new file https://github.com/harlow/kinesis-consumer/issues/75
This commit is contained in:
parent
8fd7675ea4
commit
7d5601fbde
9 changed files with 142 additions and 155 deletions
|
|
@ -6,6 +6,14 @@ All notable changes to this project will be documented in this file.
|
|||
|
||||
Major changes:
|
||||
|
||||
* Remove the concept of `ScanStatus` to simplify the scanning interface
|
||||
|
||||
For more context on this change see: https://github.com/harlow/kinesis-consumer/issues/75
|
||||
|
||||
## v0.3.0 - 2018-12-28
|
||||
|
||||
Major changes:
|
||||
|
||||
* Remove concept of `Client` it was confusing as it wasn't a direct standin for a Kinesis client.
|
||||
* Rename `ScanError` to `ScanStatus` as it's not always an error.
|
||||
|
||||
|
|
|
|||
34
README.md
34
README.md
|
|
@ -40,39 +40,41 @@ func main() {
|
|||
}
|
||||
|
||||
// start scan
|
||||
err = c.Scan(context.TODO(), func(r *consumer.Record) consumer.ScanStatus {
|
||||
err = c.Scan(context.TODO(), func(r *consumer.Record) error {
|
||||
fmt.Println(string(r.Data))
|
||||
|
||||
return consumer.ScanStatus{
|
||||
StopScan: false, // true to stop scan
|
||||
SkipCheckpoint: false, // true to skip checkpoint
|
||||
}
|
||||
return nil // continue scanning
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatalf("scan error: %v", err)
|
||||
}
|
||||
|
||||
// Note: If you need to aggregate based on a specific shard the `ScanShard`
|
||||
// method should be leverged instead.
|
||||
// Note: If you need to aggregate based on a specific shard
|
||||
// the `ScanShard` function should be used instead.
|
||||
}
|
||||
```
|
||||
|
||||
## Scan status
|
||||
## ScanFunc
|
||||
|
||||
The scan func returns a `consumer.ScanStatus` the struct allows some basic flow control.
|
||||
The `ScanFunc` receives a Kinesis Record and returns an `error`
|
||||
|
||||
```go
|
||||
type ScanFunc func(*Record) error
|
||||
```
|
||||
|
||||
Return `nil` to continue scanning, or choose from the custom error types for additional flow control.
|
||||
|
||||
```go
|
||||
// continue scanning
|
||||
return consumer.ScanStatus{}
|
||||
return nil
|
||||
|
||||
// continue scanning, skip saving checkpoint
|
||||
return consumer.ScanStatus{SkipCheckpoint: true}
|
||||
// continue scanning, skip checkpoint
|
||||
return consumer.SkipCheckpoint
|
||||
|
||||
// stop scanning, return nil
|
||||
return consumer.ScanStatus{StopScan: true}
|
||||
return consumer.StopScan
|
||||
|
||||
// stop scanning, return error
|
||||
return consumer.ScanStatus{Error: err}
|
||||
return errors.New("my error, exit all scans")
|
||||
```
|
||||
|
||||
## Checkpoint
|
||||
|
|
@ -182,7 +184,7 @@ Override the Kinesis client if there is any special config needed:
|
|||
|
||||
```go
|
||||
// client
|
||||
client := kinesis.New(session.New(aws.NewConfig()))
|
||||
client := kinesis.New(session.NewSession(aws.NewConfig()))
|
||||
|
||||
// consumer
|
||||
c, err := consumer.New(streamName, consumer.WithClient(client))
|
||||
|
|
|
|||
136
consumer.go
136
consumer.go
|
|
@ -2,6 +2,7 @@ package consumer
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
|
|
@ -16,52 +17,6 @@ import (
|
|||
// Record is an alias of record returned from kinesis library
|
||||
type Record = kinesis.Record
|
||||
|
||||
// Option is used to override defaults when creating a new Consumer
|
||||
type Option func(*Consumer)
|
||||
|
||||
// WithCheckpoint overrides the default checkpoint
|
||||
func WithCheckpoint(checkpoint Checkpoint) Option {
|
||||
return func(c *Consumer) {
|
||||
c.checkpoint = checkpoint
|
||||
}
|
||||
}
|
||||
|
||||
// WithLogger overrides the default logger
|
||||
func WithLogger(logger Logger) Option {
|
||||
return func(c *Consumer) {
|
||||
c.logger = logger
|
||||
}
|
||||
}
|
||||
|
||||
// WithCounter overrides the default counter
|
||||
func WithCounter(counter Counter) Option {
|
||||
return func(c *Consumer) {
|
||||
c.counter = counter
|
||||
}
|
||||
}
|
||||
|
||||
// WithClient overrides the default client
|
||||
func WithClient(client kinesisiface.KinesisAPI) Option {
|
||||
return func(c *Consumer) {
|
||||
c.client = client
|
||||
}
|
||||
}
|
||||
|
||||
// ShardIteratorType overrides the starting point for the consumer
|
||||
func WithShardIteratorType(t string) Option {
|
||||
return func(c *Consumer) {
|
||||
c.initialShardIteratorType = t
|
||||
}
|
||||
}
|
||||
|
||||
// ScanStatus signals the consumer if we should continue scanning for next record
|
||||
// and whether to checkpoint.
|
||||
type ScanStatus struct {
|
||||
Error error
|
||||
StopScan bool
|
||||
SkipCheckpoint bool
|
||||
}
|
||||
|
||||
// New creates a kinesis consumer with default settings. Use Option to override
|
||||
// any of the optional attributes.
|
||||
func New(streamName string, opts ...Option) (*Consumer, error) {
|
||||
|
|
@ -107,9 +62,28 @@ type Consumer struct {
|
|||
counter Counter
|
||||
}
|
||||
|
||||
// Scan scans each of the shards of the stream, calls the callback
|
||||
// func with each of the kinesis records.
|
||||
func (c *Consumer) Scan(ctx context.Context, fn func(*Record) ScanStatus) error {
|
||||
// ScanFunc is the type of the function called for each message read
|
||||
// from the stream. The record argument contains the original record
|
||||
// returned from the AWS Kinesis library.
|
||||
//
|
||||
// If an error is returned, scanning stops. The sole exception is when the
|
||||
// function returns the special value SkipCheckpoint or StopScan.
|
||||
type ScanFunc func(*Record) error
|
||||
|
||||
// SkipCheckpoint is used as a return value from ScanFuncs to indicate that
|
||||
// the current checkpoint should be skipped skipped. It is not returned
|
||||
// as an error by any function.
|
||||
var SkipCheckpoint = errors.New("skip checkpoint")
|
||||
|
||||
// StopScan is used as a return value from ScanFuncs to indicate that
|
||||
// the we should stop scanning the current shard. It is not returned
|
||||
// as an error by any function.
|
||||
var StopScan = errors.New("stop scan")
|
||||
|
||||
// Scan launches a goroutine to process each of the shards in the stream. The ScanFunc
|
||||
// is passed through to each of the goroutines and called with each message pulled from
|
||||
// the stream.
|
||||
func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
|
|
@ -153,14 +127,10 @@ func (c *Consumer) Scan(ctx context.Context, fn func(*Record) ScanStatus) error
|
|||
return <-errc
|
||||
}
|
||||
|
||||
// ScanShard loops over records on a specific shard, calls the callback func
|
||||
// for each record and checkpoints the progress of scan.
|
||||
func (c *Consumer) ScanShard(
|
||||
ctx context.Context,
|
||||
shardID string,
|
||||
fn func(*Record) ScanStatus,
|
||||
) error {
|
||||
// get checkpoint
|
||||
// ScanShard loops over records on a specific shard, calls the ScanFunc callback
|
||||
// func for each record and checkpoints the progress of scan.
|
||||
func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error {
|
||||
// get last seq number from checkpoint
|
||||
lastSeqNum, err := c.checkpoint.Get(c.streamName, shardID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get checkpoint error: %v", err)
|
||||
|
|
@ -174,10 +144,7 @@ func (c *Consumer) ScanShard(
|
|||
|
||||
c.logger.Log("scanning", shardID, lastSeqNum)
|
||||
|
||||
return c.scanPagesOfShard(ctx, shardID, lastSeqNum, shardIterator, fn)
|
||||
}
|
||||
|
||||
func (c *Consumer) scanPagesOfShard(ctx context.Context, shardID, lastSeqNum string, shardIterator *string, fn func(*Record) ScanStatus) error {
|
||||
// loop until
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
|
|
@ -187,6 +154,8 @@ func (c *Consumer) scanPagesOfShard(ctx context.Context, shardID, lastSeqNum str
|
|||
ShardIterator: shardIterator,
|
||||
})
|
||||
|
||||
// often we can recover from GetRecords error by getting a
|
||||
// new shard iterator, else return error
|
||||
if err != nil {
|
||||
shardIterator, err = c.getShardIterator(c.streamName, shardID, lastSeqNum)
|
||||
if err != nil {
|
||||
|
|
@ -195,21 +164,31 @@ func (c *Consumer) scanPagesOfShard(ctx context.Context, shardID, lastSeqNum str
|
|||
continue
|
||||
}
|
||||
|
||||
// loop records of page
|
||||
// call callback func with each record from response
|
||||
for _, r := range resp.Records {
|
||||
isScanStopped, err := c.handleRecord(shardID, r, fn)
|
||||
if err != nil {
|
||||
lastSeqNum = *r.SequenceNumber
|
||||
c.counter.Add("records", 1)
|
||||
|
||||
if err := fn(r); err != nil {
|
||||
switch err {
|
||||
case StopScan:
|
||||
return nil
|
||||
case SkipCheckpoint:
|
||||
continue
|
||||
default:
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := c.checkpoint.Set(c.streamName, shardID, *r.SequenceNumber); err != nil {
|
||||
return err
|
||||
}
|
||||
if isScanStopped {
|
||||
return nil
|
||||
}
|
||||
lastSeqNum = *r.SequenceNumber
|
||||
}
|
||||
|
||||
if isShardClosed(resp.NextShardIterator, shardIterator) {
|
||||
return nil
|
||||
}
|
||||
|
||||
shardIterator = resp.NextShardIterator
|
||||
}
|
||||
}
|
||||
|
|
@ -219,27 +198,6 @@ func isShardClosed(nextShardIterator, currentShardIterator *string) bool {
|
|||
return nextShardIterator == nil || currentShardIterator == nextShardIterator
|
||||
}
|
||||
|
||||
func (c *Consumer) handleRecord(shardID string, r *Record, fn func(*Record) ScanStatus) (isScanStopped bool, err error) {
|
||||
status := fn(r)
|
||||
|
||||
if !status.SkipCheckpoint {
|
||||
if err := c.checkpoint.Set(c.streamName, shardID, *r.SequenceNumber); err != nil {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
if err := status.Error; err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
c.counter.Add("records", 1)
|
||||
|
||||
if status.StopScan {
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (c *Consumer) getShardIDs(streamName string) ([]string, error) {
|
||||
var ss []string
|
||||
var listShardsInput = &kinesis.ListShardsInput{
|
||||
|
|
|
|||
|
|
@ -64,10 +64,10 @@ func TestConsumer_Scan(t *testing.T) {
|
|||
|
||||
var resultData string
|
||||
var fnCallCounter int
|
||||
var fn = func(r *Record) ScanStatus {
|
||||
var fn = func(r *Record) error {
|
||||
fnCallCounter++
|
||||
resultData += string(r.Data)
|
||||
return ScanStatus{}
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := c.Scan(context.Background(), fn); err != nil {
|
||||
|
|
@ -98,40 +98,19 @@ func TestConsumer_Scan_NoShardsAvailable(t *testing.T) {
|
|||
}, nil
|
||||
},
|
||||
}
|
||||
var (
|
||||
cp = &fakeCheckpoint{cache: map[string]string{}}
|
||||
ctr = &fakeCounter{}
|
||||
)
|
||||
|
||||
c, err := New("myStreamName",
|
||||
WithClient(client),
|
||||
WithCounter(ctr),
|
||||
WithCheckpoint(cp),
|
||||
)
|
||||
var fn = func(r *Record) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
c, err := New("myStreamName", WithClient(client))
|
||||
if err != nil {
|
||||
t.Fatalf("new consumer error: %v", err)
|
||||
}
|
||||
|
||||
var fnCallCounter int
|
||||
var fn = func(r *Record) ScanStatus {
|
||||
fnCallCounter++
|
||||
return ScanStatus{}
|
||||
}
|
||||
|
||||
if err := c.Scan(context.Background(), fn); err == nil {
|
||||
t.Errorf("scan shard error expected not nil. got %v", err)
|
||||
}
|
||||
|
||||
if fnCallCounter != 0 {
|
||||
t.Errorf("the callback function expects %v, got %v", 0, fnCallCounter)
|
||||
}
|
||||
if val := ctr.counter; val != 0 {
|
||||
t.Errorf("counter error expected %d, got %d", 0, val)
|
||||
}
|
||||
val, err := cp.Get("myStreamName", "myShard")
|
||||
if err != nil && val != "" {
|
||||
t.Errorf("checkout error expected %s, got %s", "", val)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScanShard(t *testing.T) {
|
||||
|
|
@ -176,9 +155,9 @@ func TestScanShard(t *testing.T) {
|
|||
|
||||
// callback fn appends record data
|
||||
var res string
|
||||
var fn = func(r *Record) ScanStatus {
|
||||
var fn = func(r *Record) error {
|
||||
res += string(r.Data)
|
||||
return ScanStatus{}
|
||||
return nil
|
||||
}
|
||||
|
||||
// scan shard
|
||||
|
|
@ -236,9 +215,9 @@ func TestScanShard_StopScan(t *testing.T) {
|
|||
|
||||
// callback fn appends record data
|
||||
var res string
|
||||
var fn = func(r *Record) ScanStatus {
|
||||
var fn = func(r *Record) error {
|
||||
res += string(r.Data)
|
||||
return ScanStatus{StopScan: true}
|
||||
return StopScan
|
||||
}
|
||||
|
||||
if err := c.ScanShard(context.Background(), "myShard", fn); err != nil {
|
||||
|
|
@ -270,8 +249,8 @@ func TestScanShard_ShardIsClosed(t *testing.T) {
|
|||
t.Fatalf("new consumer error: %v", err)
|
||||
}
|
||||
|
||||
var fn = func(r *Record) ScanStatus {
|
||||
return ScanStatus{StopScan: true}
|
||||
var fn = func(r *Record) error {
|
||||
return StopScan
|
||||
}
|
||||
|
||||
if err := c.ScanShard(context.Background(), "myShard", fn); err != nil {
|
||||
|
|
|
|||
|
|
@ -54,7 +54,7 @@ func main() {
|
|||
}
|
||||
|
||||
var (
|
||||
app = flag.String("app", "", "App name")
|
||||
app = flag.String("app", "", "Consumer app name")
|
||||
stream = flag.String("stream", "", "Stream name")
|
||||
table = flag.String("table", "", "Checkpoint table name")
|
||||
)
|
||||
|
|
@ -103,11 +103,9 @@ func main() {
|
|||
}()
|
||||
|
||||
// scan stream
|
||||
err = c.Scan(ctx, func(r *consumer.Record) consumer.ScanStatus {
|
||||
err = c.Scan(ctx, func(r *consumer.Record) error {
|
||||
fmt.Println(string(r.Data))
|
||||
|
||||
// continue scanning
|
||||
return consumer.ScanStatus{}
|
||||
return nil // continue scanning
|
||||
})
|
||||
if err != nil {
|
||||
log.Log("scan error: %v", err)
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ import (
|
|||
|
||||
func main() {
|
||||
var (
|
||||
app = flag.String("app", "", "App name")
|
||||
app = flag.String("app", "", "Consumer app name")
|
||||
stream = flag.String("stream", "", "Stream name")
|
||||
table = flag.String("table", "", "Table name")
|
||||
connStr = flag.String("connection", "", "Connection Str")
|
||||
|
|
@ -53,11 +53,9 @@ func main() {
|
|||
}()
|
||||
|
||||
// scan stream
|
||||
err = c.Scan(ctx, func(r *consumer.Record) consumer.ScanStatus {
|
||||
err = c.Scan(ctx, func(r *consumer.Record) error {
|
||||
fmt.Println(string(r.Data))
|
||||
|
||||
// continue scanning
|
||||
return consumer.ScanStatus{}
|
||||
return nil // continue scanning
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ import (
|
|||
|
||||
func main() {
|
||||
var (
|
||||
app = flag.String("app", "", "App name")
|
||||
app = flag.String("app", "", "Consumer app name")
|
||||
stream = flag.String("stream", "", "Stream name")
|
||||
)
|
||||
flag.Parse()
|
||||
|
|
@ -46,11 +46,9 @@ func main() {
|
|||
}()
|
||||
|
||||
// scan stream
|
||||
err = c.Scan(ctx, func(r *consumer.Record) consumer.ScanStatus {
|
||||
err = c.Scan(ctx, func(r *consumer.Record) error {
|
||||
fmt.Println(string(r.Data))
|
||||
|
||||
// continue scanning
|
||||
return consumer.ScanStatus{}
|
||||
return nil // continue scanning
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatalf("scan error: %v", err)
|
||||
|
|
|
|||
|
|
@ -25,7 +25,12 @@ func main() {
|
|||
defer f.Close()
|
||||
|
||||
var records []*kinesis.PutRecordsRequestEntry
|
||||
var client = kinesis.New(session.New())
|
||||
|
||||
sess, err := session.NewSession(aws.NewConfig())
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
var client = kinesis.New(sess)
|
||||
|
||||
// loop over file data
|
||||
b := bufio.NewScanner(f)
|
||||
|
|
|
|||
41
options.go
Normal file
41
options.go
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
package consumer
|
||||
|
||||
import "github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
|
||||
|
||||
// Option is used to override defaults when creating a new Consumer
|
||||
type Option func(*Consumer)
|
||||
|
||||
// WithCheckpoint overrides the default checkpoint
|
||||
func WithCheckpoint(checkpoint Checkpoint) Option {
|
||||
return func(c *Consumer) {
|
||||
c.checkpoint = checkpoint
|
||||
}
|
||||
}
|
||||
|
||||
// WithLogger overrides the default logger
|
||||
func WithLogger(logger Logger) Option {
|
||||
return func(c *Consumer) {
|
||||
c.logger = logger
|
||||
}
|
||||
}
|
||||
|
||||
// WithCounter overrides the default counter
|
||||
func WithCounter(counter Counter) Option {
|
||||
return func(c *Consumer) {
|
||||
c.counter = counter
|
||||
}
|
||||
}
|
||||
|
||||
// WithClient overrides the default client
|
||||
func WithClient(client kinesisiface.KinesisAPI) Option {
|
||||
return func(c *Consumer) {
|
||||
c.client = client
|
||||
}
|
||||
}
|
||||
|
||||
// ShardIteratorType overrides the starting point for the consumer
|
||||
func WithShardIteratorType(t string) Option {
|
||||
return func(c *Consumer) {
|
||||
c.initialShardIteratorType = t
|
||||
}
|
||||
}
|
||||
Loading…
Reference in a new issue