diff --git a/CHANGELOG.md b/CHANGELOG.md index a943608..f6cd94f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,14 @@ All notable changes to this project will be documented in this file. Major changes: +* Remove the concept of `ScanStatus` to simplify the scanning interface + +For more context on this change see: https://github.com/harlow/kinesis-consumer/issues/75 + +## v0.3.0 - 2018-12-28 + +Major changes: + * Remove concept of `Client` it was confusing as it wasn't a direct standin for a Kinesis client. * Rename `ScanError` to `ScanStatus` as it's not always an error. diff --git a/README.md b/README.md index 757cf56..4ae70fa 100644 --- a/README.md +++ b/README.md @@ -40,39 +40,62 @@ func main() { } // start scan - err = c.Scan(context.TODO(), func(r *consumer.Record) consumer.ScanStatus { + err = c.Scan(context.TODO(), func(r *consumer.Record) error { fmt.Println(string(r.Data)) - - return consumer.ScanStatus{ - StopScan: false, // true to stop scan - SkipCheckpoint: false, // true to skip checkpoint - } + return nil // continue scanning }) if err != nil { log.Fatalf("scan error: %v", err) } - // Note: If you need to aggregate based on a specific shard the `ScanShard` - // method should be leverged instead. + // Note: If you need to aggregate based on a specific shard + // the `ScanShard` function should be used instead. } ``` -## Scan status +## ScanFunc -The scan func returns a `consumer.ScanStatus` the struct allows some basic flow control. +ScanFunc is the type of the function called for each message read +from the stream. The record argument contains the original record +returned from the AWS Kinesis library. + +```go +type ScanFunc func(r *Record) error +``` + +If an error is returned, scanning stops. The sole exception is when the +function returns the special value SkipCheckpoint. ```go // continue scanning -return consumer.ScanStatus{} +return nil -// continue scanning, skip saving checkpoint -return consumer.ScanStatus{SkipCheckpoint: true} - -// stop scanning, return nil -return consumer.ScanStatus{StopScan: true} +// continue scanning, skip checkpoint +return consumer.SkipCheckpoint // stop scanning, return error -return consumer.ScanStatus{Error: err} +return errors.New("my error, exit all scans") +``` + +Use context cancel to signal the scan to exit without error. For example if we wanted to gracefulloy exit the scan on interrupt. + +```go +// trap SIGINT, wait to trigger shutdown +signals := make(chan os.Signal, 1) +signal.Notify(signals, os.Interrupt) + +// context with cancel +ctx, cancel := context.WithCancel(context.Background()) + +go func() { + <-signals + cancel() // call cancellation +}() + +err := c.Scan(ctx, func(r *consumer.Record) error { + fmt.Println(string(r.Data)) + return nil // continue scanning +}) ``` ## Checkpoint @@ -182,7 +205,7 @@ Override the Kinesis client if there is any special config needed: ```go // client -client := kinesis.New(session.New(aws.NewConfig())) +client := kinesis.New(session.NewSession(aws.NewConfig())) // consumer c, err := consumer.New(streamName, consumer.WithClient(client)) diff --git a/consumer.go b/consumer.go index fb1cbd4..a60d378 100644 --- a/consumer.go +++ b/consumer.go @@ -2,6 +2,7 @@ package consumer import ( "context" + "errors" "fmt" "io/ioutil" "log" @@ -16,52 +17,6 @@ import ( // Record is an alias of record returned from kinesis library type Record = kinesis.Record -// Option is used to override defaults when creating a new Consumer -type Option func(*Consumer) - -// WithCheckpoint overrides the default checkpoint -func WithCheckpoint(checkpoint Checkpoint) Option { - return func(c *Consumer) { - c.checkpoint = checkpoint - } -} - -// WithLogger overrides the default logger -func WithLogger(logger Logger) Option { - return func(c *Consumer) { - c.logger = logger - } -} - -// WithCounter overrides the default counter -func WithCounter(counter Counter) Option { - return func(c *Consumer) { - c.counter = counter - } -} - -// WithClient overrides the default client -func WithClient(client kinesisiface.KinesisAPI) Option { - return func(c *Consumer) { - c.client = client - } -} - -// ShardIteratorType overrides the starting point for the consumer -func WithShardIteratorType(t string) Option { - return func(c *Consumer) { - c.initialShardIteratorType = t - } -} - -// ScanStatus signals the consumer if we should continue scanning for next record -// and whether to checkpoint. -type ScanStatus struct { - Error error - StopScan bool - SkipCheckpoint bool -} - // New creates a kinesis consumer with default settings. Use Option to override // any of the optional attributes. func New(streamName string, opts ...Option) (*Consumer, error) { @@ -107,9 +62,23 @@ type Consumer struct { counter Counter } -// Scan scans each of the shards of the stream, calls the callback -// func with each of the kinesis records. -func (c *Consumer) Scan(ctx context.Context, fn func(*Record) ScanStatus) error { +// ScanFunc is the type of the function called for each message read +// from the stream. The record argument contains the original record +// returned from the AWS Kinesis library. +// +// If an error is returned, scanning stops. The sole exception is when the +// function returns the special value SkipCheckpoint. +type ScanFunc func(*Record) error + +// SkipCheckpoint is used as a return value from ScanFuncs to indicate that +// the current checkpoint should be skipped skipped. It is not returned +// as an error by any function. +var SkipCheckpoint = errors.New("skip checkpoint") + +// Scan launches a goroutine to process each of the shards in the stream. The ScanFunc +// is passed through to each of the goroutines and called with each message pulled from +// the stream. +func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error { ctx, cancel := context.WithCancel(ctx) defer cancel() @@ -153,14 +122,10 @@ func (c *Consumer) Scan(ctx context.Context, fn func(*Record) ScanStatus) error return <-errc } -// ScanShard loops over records on a specific shard, calls the callback func -// for each record and checkpoints the progress of scan. -func (c *Consumer) ScanShard( - ctx context.Context, - shardID string, - fn func(*Record) ScanStatus, -) error { - // get checkpoint +// ScanShard loops over records on a specific shard, calls the ScanFunc callback +// func for each record and checkpoints the progress of scan. +func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error { + // get last seq number from checkpoint lastSeqNum, err := c.checkpoint.Get(c.streamName, shardID) if err != nil { return fmt.Errorf("get checkpoint error: %v", err) @@ -174,10 +139,6 @@ func (c *Consumer) ScanShard( c.logger.Log("scanning", shardID, lastSeqNum) - return c.scanPagesOfShard(ctx, shardID, lastSeqNum, shardIterator, fn) -} - -func (c *Consumer) scanPagesOfShard(ctx context.Context, shardID, lastSeqNum string, shardIterator *string, fn func(*Record) ScanStatus) error { for { select { case <-ctx.Done(): @@ -187,6 +148,8 @@ func (c *Consumer) scanPagesOfShard(ctx context.Context, shardID, lastSeqNum str ShardIterator: shardIterator, }) + // often we can recover from GetRecords error by getting a + // new shard iterator, else return error if err != nil { shardIterator, err = c.getShardIterator(c.streamName, shardID, lastSeqNum) if err != nil { @@ -195,21 +158,32 @@ func (c *Consumer) scanPagesOfShard(ctx context.Context, shardID, lastSeqNum str continue } - // loop records of page + // loop over records, call callback func for _, r := range resp.Records { - isScanStopped, err := c.handleRecord(shardID, r, fn) - if err != nil { - return err - } - if isScanStopped { + select { + case <-ctx.Done(): return nil + default: + err := fn(r) + if err != nil && err != SkipCheckpoint { + return err + } + + if err != SkipCheckpoint { + if err := c.checkpoint.Set(c.streamName, shardID, *r.SequenceNumber); err != nil { + return err + } + } + + c.counter.Add("records", 1) + lastSeqNum = *r.SequenceNumber } - lastSeqNum = *r.SequenceNumber } if isShardClosed(resp.NextShardIterator, shardIterator) { return nil } + shardIterator = resp.NextShardIterator } } @@ -219,32 +193,12 @@ func isShardClosed(nextShardIterator, currentShardIterator *string) bool { return nextShardIterator == nil || currentShardIterator == nextShardIterator } -func (c *Consumer) handleRecord(shardID string, r *Record, fn func(*Record) ScanStatus) (isScanStopped bool, err error) { - status := fn(r) - - if !status.SkipCheckpoint { - if err := c.checkpoint.Set(c.streamName, shardID, *r.SequenceNumber); err != nil { - return false, err - } - } - - if err := status.Error; err != nil { - return false, err - } - - c.counter.Add("records", 1) - - if status.StopScan { - return true, nil - } - return false, nil -} - func (c *Consumer) getShardIDs(streamName string) ([]string, error) { var ss []string var listShardsInput = &kinesis.ListShardsInput{ StreamName: aws.String(streamName), } + for { resp, err := c.client.ListShards(listShardsInput) if err != nil { @@ -265,22 +219,19 @@ func (c *Consumer) getShardIDs(streamName string) ([]string, error) { } } -func (c *Consumer) getShardIterator(streamName, shardID, lastSeqNum string) (*string, error) { +func (c *Consumer) getShardIterator(streamName, shardID, seqNum string) (*string, error) { params := &kinesis.GetShardIteratorInput{ ShardId: aws.String(shardID), StreamName: aws.String(streamName), } - if lastSeqNum != "" { + if seqNum != "" { params.ShardIteratorType = aws.String(kinesis.ShardIteratorTypeAfterSequenceNumber) - params.StartingSequenceNumber = aws.String(lastSeqNum) + params.StartingSequenceNumber = aws.String(seqNum) } else { params.ShardIteratorType = aws.String(c.initialShardIteratorType) } - resp, err := c.client.GetShardIterator(params) - if err != nil { - return nil, err - } - return resp.ShardIterator, nil + res, err := c.client.GetShardIterator(params) + return res.ShardIterator, err } diff --git a/consumer_test.go b/consumer_test.go index f669381..4712e3c 100644 --- a/consumer_test.go +++ b/consumer_test.go @@ -11,23 +11,24 @@ import ( "github.com/aws/aws-sdk-go/service/kinesis/kinesisiface" ) +var records = []*kinesis.Record{ + { + Data: []byte("firstData"), + SequenceNumber: aws.String("firstSeqNum"), + }, + { + Data: []byte("lastData"), + SequenceNumber: aws.String("lastSeqNum"), + }, +} + func TestNew(t *testing.T) { if _, err := New("myStreamName"); err != nil { t.Fatalf("new consumer error: %v", err) } } -func TestConsumer_Scan(t *testing.T) { - records := []*kinesis.Record{ - { - Data: []byte("firstData"), - SequenceNumber: aws.String("firstSeqNum"), - }, - { - Data: []byte("lastData"), - SequenceNumber: aws.String("lastSeqNum"), - }, - } +func TestScan(t *testing.T) { client := &kinesisClientMock{ getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) { return &kinesis.GetShardIteratorOutput{ @@ -64,10 +65,10 @@ func TestConsumer_Scan(t *testing.T) { var resultData string var fnCallCounter int - var fn = func(r *Record) ScanStatus { + var fn = func(r *Record) error { fnCallCounter++ resultData += string(r.Data) - return ScanStatus{} + return nil } if err := c.Scan(context.Background(), fn); err != nil { @@ -75,11 +76,13 @@ func TestConsumer_Scan(t *testing.T) { } if resultData != "firstDatalastData" { - t.Errorf("callback error expected %s, got %s", "firstDatalastData", resultData) + t.Errorf("callback error expected %s, got %s", "FirstLast", resultData) } + if fnCallCounter != 2 { t.Errorf("the callback function expects %v, got %v", 2, fnCallCounter) } + if val := ctr.counter; val != 2 { t.Errorf("counter error expected %d, got %d", 2, val) } @@ -90,7 +93,7 @@ func TestConsumer_Scan(t *testing.T) { } } -func TestConsumer_Scan_NoShardsAvailable(t *testing.T) { +func TestScan_NoShardsAvailable(t *testing.T) { client := &kinesisClientMock{ listShardsMock: func(input *kinesis.ListShardsInput) (*kinesis.ListShardsOutput, error) { return &kinesis.ListShardsOutput{ @@ -98,54 +101,22 @@ func TestConsumer_Scan_NoShardsAvailable(t *testing.T) { }, nil }, } - var ( - cp = &fakeCheckpoint{cache: map[string]string{}} - ctr = &fakeCounter{} - ) - c, err := New("myStreamName", - WithClient(client), - WithCounter(ctr), - WithCheckpoint(cp), - ) - if err != nil { - t.Fatalf("new consumer error: %v", err) + var fn = func(r *Record) error { + return nil } - var fnCallCounter int - var fn = func(r *Record) ScanStatus { - fnCallCounter++ - return ScanStatus{} + c, err := New("myStreamName", WithClient(client)) + if err != nil { + t.Fatalf("new consumer error: %v", err) } if err := c.Scan(context.Background(), fn); err == nil { t.Errorf("scan shard error expected not nil. got %v", err) } - - if fnCallCounter != 0 { - t.Errorf("the callback function expects %v, got %v", 0, fnCallCounter) - } - if val := ctr.counter; val != 0 { - t.Errorf("counter error expected %d, got %d", 0, val) - } - val, err := cp.Get("myStreamName", "myShard") - if err != nil && val != "" { - t.Errorf("checkout error expected %s, got %s", "", val) - } } func TestScanShard(t *testing.T) { - var records = []*kinesis.Record{ - { - Data: []byte("firstData"), - SequenceNumber: aws.String("firstSeqNum"), - }, - { - Data: []byte("lastData"), - SequenceNumber: aws.String("lastSeqNum"), - }, - } - var client = &kinesisClientMock{ getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) { return &kinesis.GetShardIteratorOutput{ @@ -176,9 +147,9 @@ func TestScanShard(t *testing.T) { // callback fn appends record data var res string - var fn = func(r *Record) ScanStatus { + var fn = func(r *Record) error { res += string(r.Data) - return ScanStatus{} + return nil } // scan shard @@ -203,18 +174,7 @@ func TestScanShard(t *testing.T) { } } -func TestScanShard_StopScan(t *testing.T) { - var records = []*kinesis.Record{ - { - Data: []byte("firstData"), - SequenceNumber: aws.String("firstSeqNum"), - }, - { - Data: []byte("lastData"), - SequenceNumber: aws.String("lastSeqNum"), - }, - } - +func TestScanShard_Cancellation(t *testing.T) { var client = &kinesisClientMock{ getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) { return &kinesis.GetShardIteratorOutput{ @@ -229,19 +189,23 @@ func TestScanShard_StopScan(t *testing.T) { }, } + // use cancel func to signal shutdown + ctx, cancel := context.WithCancel(context.Background()) + + var res string + var fn = func(r *Record) error { + res += string(r.Data) + cancel() // simulate cancellation while processing first record + return nil + } + c, err := New("myStreamName", WithClient(client)) if err != nil { t.Fatalf("new consumer error: %v", err) } - // callback fn appends record data - var res string - var fn = func(r *Record) ScanStatus { - res += string(r.Data) - return ScanStatus{StopScan: true} - } - - if err := c.ScanShard(context.Background(), "myShard", fn); err != nil { + err = c.ScanShard(ctx, "myShard", fn) + if err != nil { t.Fatalf("scan shard error: %v", err) } @@ -250,6 +214,46 @@ func TestScanShard_StopScan(t *testing.T) { } } +func TestScanShard_SkipCheckpoint(t *testing.T) { + var client = &kinesisClientMock{ + getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) { + return &kinesis.GetShardIteratorOutput{ + ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"), + }, nil + }, + getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) { + return &kinesis.GetRecordsOutput{ + NextShardIterator: nil, + Records: records, + }, nil + }, + } + + var cp = &fakeCheckpoint{cache: map[string]string{}} + + c, err := New("myStreamName", WithClient(client), WithCheckpoint(cp)) + if err != nil { + t.Fatalf("new consumer error: %v", err) + } + + var fn = func(r *Record) error { + if aws.StringValue(r.SequenceNumber) == "lastSeqNum" { + return SkipCheckpoint + } + return nil + } + + err = c.ScanShard(context.Background(), "myShard", fn) + if err != nil { + t.Fatalf("scan shard error: %v", err) + } + + val, err := cp.Get("myStreamName", "myShard") + if err != nil && val != "firstSeqNum" { + t.Fatalf("checkout error expected %s, got %s", "firstSeqNum", val) + } +} + func TestScanShard_ShardIsClosed(t *testing.T) { var client = &kinesisClientMock{ getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) { @@ -270,11 +274,12 @@ func TestScanShard_ShardIsClosed(t *testing.T) { t.Fatalf("new consumer error: %v", err) } - var fn = func(r *Record) ScanStatus { - return ScanStatus{StopScan: true} + var fn = func(r *Record) error { + return nil } - if err := c.ScanShard(context.Background(), "myShard", fn); err != nil { + err = c.ScanShard(context.Background(), "myShard", fn) + if err != nil { t.Fatalf("scan shard error: %v", err) } } diff --git a/examples/consumer/cp-dynamo/main.go b/examples/consumer/cp-dynamo/main.go index 98650e5..af885b3 100644 --- a/examples/consumer/cp-dynamo/main.go +++ b/examples/consumer/cp-dynamo/main.go @@ -54,7 +54,7 @@ func main() { } var ( - app = flag.String("app", "", "App name") + app = flag.String("app", "", "Consumer app name") stream = flag.String("stream", "", "Stream name") table = flag.String("table", "", "Checkpoint table name") ) @@ -103,11 +103,9 @@ func main() { }() // scan stream - err = c.Scan(ctx, func(r *consumer.Record) consumer.ScanStatus { + err = c.Scan(ctx, func(r *consumer.Record) error { fmt.Println(string(r.Data)) - - // continue scanning - return consumer.ScanStatus{} + return nil // continue scanning }) if err != nil { log.Log("scan error: %v", err) diff --git a/examples/consumer/cp-postgres/main.go b/examples/consumer/cp-postgres/main.go index c12a312..daf5720 100644 --- a/examples/consumer/cp-postgres/main.go +++ b/examples/consumer/cp-postgres/main.go @@ -15,7 +15,7 @@ import ( func main() { var ( - app = flag.String("app", "", "App name") + app = flag.String("app", "", "Consumer app name") stream = flag.String("stream", "", "Stream name") table = flag.String("table", "", "Table name") connStr = flag.String("connection", "", "Connection Str") @@ -53,11 +53,9 @@ func main() { }() // scan stream - err = c.Scan(ctx, func(r *consumer.Record) consumer.ScanStatus { + err = c.Scan(ctx, func(r *consumer.Record) error { fmt.Println(string(r.Data)) - - // continue scanning - return consumer.ScanStatus{} + return nil // continue scanning }) if err != nil { diff --git a/examples/consumer/cp-redis/main.go b/examples/consumer/cp-redis/main.go index b86d8c3..f199396 100644 --- a/examples/consumer/cp-redis/main.go +++ b/examples/consumer/cp-redis/main.go @@ -14,7 +14,7 @@ import ( func main() { var ( - app = flag.String("app", "", "App name") + app = flag.String("app", "", "Consumer app name") stream = flag.String("stream", "", "Stream name") ) flag.Parse() @@ -46,11 +46,9 @@ func main() { }() // scan stream - err = c.Scan(ctx, func(r *consumer.Record) consumer.ScanStatus { + err = c.Scan(ctx, func(r *consumer.Record) error { fmt.Println(string(r.Data)) - - // continue scanning - return consumer.ScanStatus{} + return nil // continue scanning }) if err != nil { log.Fatalf("scan error: %v", err) diff --git a/examples/producer/main.go b/examples/producer/main.go index 78c7e74..d59aa61 100644 --- a/examples/producer/main.go +++ b/examples/producer/main.go @@ -25,7 +25,12 @@ func main() { defer f.Close() var records []*kinesis.PutRecordsRequestEntry - var client = kinesis.New(session.New()) + + sess, err := session.NewSession(aws.NewConfig()) + if err != nil { + log.Fatal(err) + } + var client = kinesis.New(sess) // loop over file data b := bufio.NewScanner(f) diff --git a/options.go b/options.go new file mode 100644 index 0000000..0876931 --- /dev/null +++ b/options.go @@ -0,0 +1,41 @@ +package consumer + +import "github.com/aws/aws-sdk-go/service/kinesis/kinesisiface" + +// Option is used to override defaults when creating a new Consumer +type Option func(*Consumer) + +// WithCheckpoint overrides the default checkpoint +func WithCheckpoint(checkpoint Checkpoint) Option { + return func(c *Consumer) { + c.checkpoint = checkpoint + } +} + +// WithLogger overrides the default logger +func WithLogger(logger Logger) Option { + return func(c *Consumer) { + c.logger = logger + } +} + +// WithCounter overrides the default counter +func WithCounter(counter Counter) Option { + return func(c *Consumer) { + c.counter = counter + } +} + +// WithClient overrides the default client +func WithClient(client kinesisiface.KinesisAPI) Option { + return func(c *Consumer) { + c.client = client + } +} + +// ShardIteratorType overrides the starting point for the consumer +func WithShardIteratorType(t string) Option { + return func(c *Consumer) { + c.initialShardIteratorType = t + } +}