diff --git a/README.md b/README.md index 9ced52c..6052941 100644 --- a/README.md +++ b/README.md @@ -55,13 +55,16 @@ func main() { ## ScanFunc -The `ScanFunc` receives a Kinesis Record and returns an `error` +ScanFunc is the type of the function called for each message read +from the stream. The record argument contains the original record +returned from the AWS Kinesis library. ```go -type ScanFunc func(*Record) error +type ScanFunc func(r *Record) error ``` -Return `nil` to continue scanning, or choose from the custom error types for additional flow control. +If an error is returned, scanning stops. The sole exception is when the +function returns the special value SkipCheckpoint. ```go // continue scanning @@ -70,13 +73,31 @@ return nil // continue scanning, skip checkpoint return consumer.SkipCheckpoint -// stop scanning, return nil -return consumer.StopScan - // stop scanning, return error return errors.New("my error, exit all scans") ``` +Use context cancel to signal the scan to exit without error. For example if we wanted to gracefulloy exit the scan on interrupt. + +```go +// trap SIGINT, wait to trigger shutdown +signals := make(chan os.Signal, 1) +signal.Notify(signals, os.Interrupt) + +// context with cancel +ctx, cancel := context.WithCancel(context.Background()) + +go func() { + <-signals + cancel() // call cancellation +}() + +err := c.Scan(ctx, func(r *consumer.Record) error { + fmt.Println(string(r.Data)) + return nil // continue scanning +}) +``` + ## Checkpoint To record the progress of the consumer in the stream we use a checkpoint to store the last sequence number the consumer has read from a particular shard. The boolean value SkipCheckpoint of consumer.ScanError determines if checkpoint will be activated. ScanError is returned by the record processing callback. diff --git a/consumer.go b/consumer.go index 01fe841..dd106f4 100644 --- a/consumer.go +++ b/consumer.go @@ -67,7 +67,7 @@ type Consumer struct { // returned from the AWS Kinesis library. // // If an error is returned, scanning stops. The sole exception is when the -// function returns the special value SkipCheckpoint or StopScan. +// function returns the special value SkipCheckpoint. type ScanFunc func(*Record) error // SkipCheckpoint is used as a return value from ScanFuncs to indicate that @@ -75,11 +75,6 @@ type ScanFunc func(*Record) error // as an error by any function. var SkipCheckpoint = errors.New("skip checkpoint") -// StopScan is used as a return value from ScanFuncs to indicate that -// the we should stop scanning the current shard. It is not returned -// as an error by any function. -var StopScan = errors.New("stop scan") - // Scan launches a goroutine to process each of the shards in the stream. The ScanFunc // is passed through to each of the goroutines and called with each message pulled from // the stream. @@ -164,24 +159,26 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e continue } - // call callback func with each record from response + // callback func with each record for _, r := range resp.Records { - lastSeqNum = *r.SequenceNumber - c.counter.Add("records", 1) + select { + case <-ctx.Done(): + return nil + default: + err := fn(r) - if err := fn(r); err != nil { - switch err { - case StopScan: - return nil - case SkipCheckpoint: - continue - default: + if err != nil && err != SkipCheckpoint { return err } - } - if err := c.checkpoint.Set(c.streamName, shardID, *r.SequenceNumber); err != nil { - return err + if err != SkipCheckpoint { + if err := c.checkpoint.Set(c.streamName, shardID, *r.SequenceNumber); err != nil { + return err + } + } + + c.counter.Add("records", 1) + lastSeqNum = *r.SequenceNumber } } @@ -221,9 +218,10 @@ func (c *Consumer) getShardIDs(streamName string) ([]string, error) { NextToken: resp.NextToken, } } + return ss, nil } -func (c *Consumer) getShardIterator(streamName, shardID, lastSeqNum string) (*string, error) { +func (c *Consumer) getShardIterator(streamName, shardID, seqNum string) (*string, error) { params := &kinesis.GetShardIteratorInput{ ShardId: aws.String(shardID), StreamName: aws.String(streamName), @@ -236,9 +234,6 @@ func (c *Consumer) getShardIterator(streamName, shardID, lastSeqNum string) (*st params.ShardIteratorType = aws.String(c.initialShardIteratorType) } - resp, err := c.client.GetShardIterator(params) - if err != nil { - return nil, err - } - return resp.ShardIterator, nil + res, err := c.client.GetShardIterator(params) + return res.ShardIterator, err } diff --git a/consumer_test.go b/consumer_test.go index 9f23160..d53bca5 100644 --- a/consumer_test.go +++ b/consumer_test.go @@ -11,23 +11,24 @@ import ( "github.com/aws/aws-sdk-go/service/kinesis/kinesisiface" ) +var records = []*kinesis.Record{ + { + Data: []byte("firstData"), + SequenceNumber: aws.String("firstSeqNum"), + }, + { + Data: []byte("lastData"), + SequenceNumber: aws.String("lastSeqNum"), + }, +} + func TestNew(t *testing.T) { if _, err := New("myStreamName"); err != nil { t.Fatalf("new consumer error: %v", err) } } -func TestConsumer_Scan(t *testing.T) { - records := []*kinesis.Record{ - { - Data: []byte("firstData"), - SequenceNumber: aws.String("firstSeqNum"), - }, - { - Data: []byte("lastData"), - SequenceNumber: aws.String("lastSeqNum"), - }, - } +func TestScan(t *testing.T) { client := &kinesisClientMock{ getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) { return &kinesis.GetShardIteratorOutput{ @@ -75,11 +76,13 @@ func TestConsumer_Scan(t *testing.T) { } if resultData != "firstDatalastData" { - t.Errorf("callback error expected %s, got %s", "firstDatalastData", resultData) + t.Errorf("callback error expected %s, got %s", "FirstLast", resultData) } + if fnCallCounter != 2 { t.Errorf("the callback function expects %v, got %v", 2, fnCallCounter) } + if val := ctr.counter; val != 2 { t.Errorf("counter error expected %d, got %d", 2, val) } @@ -90,7 +93,7 @@ func TestConsumer_Scan(t *testing.T) { } } -func TestConsumer_Scan_NoShardsAvailable(t *testing.T) { +func TestScan_NoShardsAvailable(t *testing.T) { client := &kinesisClientMock{ listShardsMock: func(input *kinesis.ListShardsInput) (*kinesis.ListShardsOutput, error) { return &kinesis.ListShardsOutput{ @@ -114,17 +117,6 @@ func TestConsumer_Scan_NoShardsAvailable(t *testing.T) { } func TestScanShard(t *testing.T) { - var records = []*kinesis.Record{ - { - Data: []byte("firstData"), - SequenceNumber: aws.String("firstSeqNum"), - }, - { - Data: []byte("lastData"), - SequenceNumber: aws.String("lastSeqNum"), - }, - } - var client = &kinesisClientMock{ getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) { return &kinesis.GetShardIteratorOutput{ @@ -182,18 +174,7 @@ func TestScanShard(t *testing.T) { } } -func TestScanShard_StopScan(t *testing.T) { - var records = []*kinesis.Record{ - { - Data: []byte("firstData"), - SequenceNumber: aws.String("firstSeqNum"), - }, - { - Data: []byte("lastData"), - SequenceNumber: aws.String("lastSeqNum"), - }, - } - +func TestScanShard_Cancellation(t *testing.T) { var client = &kinesisClientMock{ getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) { return &kinesis.GetShardIteratorOutput{ @@ -208,19 +189,23 @@ func TestScanShard_StopScan(t *testing.T) { }, } + // use cancel func to signal shutdown + ctx, cancel := context.WithCancel(context.Background()) + + var res string + var fn = func(r *Record) error { + res += string(r.Data) + cancel() // simulate cancellation while processing first record + return nil + } + c, err := New("myStreamName", WithClient(client)) if err != nil { t.Fatalf("new consumer error: %v", err) } - // callback fn appends record data - var res string - var fn = func(r *Record) error { - res += string(r.Data) - return StopScan - } - - if err := c.ScanShard(context.Background(), "myShard", fn); err != nil { + err = c.ScanShard(ctx, "myShard", fn) + if err != nil { t.Fatalf("scan shard error: %v", err) } @@ -250,10 +235,11 @@ func TestScanShard_ShardIsClosed(t *testing.T) { } var fn = func(r *Record) error { - return StopScan + return nil } - if err := c.ScanShard(context.Background(), "myShard", fn); err != nil { + err = c.ScanShard(context.Background(), "myShard", fn) + if err != nil { t.Fatalf("scan shard error: %v", err) } }