diff --git a/README.md b/README.md index 6cf3ae6..2f67308 100644 --- a/README.md +++ b/README.md @@ -38,9 +38,13 @@ func main() { } // start - err = c.Scan(context.TODO(), func(r *consumer.Record) bool { + err = c.Scan(context.TODO(), func(r *consumer.Record) consumer.ScanError { fmt.Println(string(r.Data)) - return true // continue scanning + // continue scanning + return consumer.ScanError{ + StopScan: false, // true to stop scan + SkipCheckpoint: false, // true to skip checkpoint + } }) if err != nil { log.Fatalf("scan error: %v", err) @@ -53,7 +57,7 @@ func main() { ## Checkpoint -To record the progress of the consumer in the stream we use a checkpoint to store the last sequence number the consumer has read from a particular shard. +To record the progress of the consumer in the stream we use a checkpoint to store the last sequence number the consumer has read from a particular shard. The boolean value SkipCheckpoint of consumer.ScanError determines if checkpoint will be activated. ScanError is returned by the record processing callback. This will allow consumers to re-launch and pick up at the position in the stream where they left off. diff --git a/consumer.go b/consumer.go index cc766d7..7d69c99 100644 --- a/consumer.go +++ b/consumer.go @@ -10,6 +10,14 @@ import ( "github.com/aws/aws-sdk-go/service/kinesis" ) +// ScanError signals the consumer if we should continue scanning for next record +// and whether to checkpoint. +type ScanError struct { + Error error + StopScan bool + SkipCheckpoint bool +} + // Record is an alias of record returned from kinesis library type Record = kinesis.Record @@ -111,7 +119,7 @@ type Consumer struct { // Scan scans each of the shards of the stream, calls the callback // func with each of the kinesis records. -func (c *Consumer) Scan(ctx context.Context, fn func(*Record) bool) error { +func (c *Consumer) Scan(ctx context.Context, fn func(*Record) ScanError) error { shardIDs, err := c.client.GetShardIDs(c.streamName) if err != nil { return fmt.Errorf("get shards error: %v", err) @@ -156,31 +164,39 @@ func (c *Consumer) Scan(ctx context.Context, fn func(*Record) bool) error { // ScanShard loops over records on a specific shard, calls the callback func // for each record and checkpoints the progress of scan. // Note: Returning `false` from the callback func will end the scan. -func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn func(*Record) bool) (err error) { +func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn func(*Record) ScanError) (err error) { lastSeqNum, err := c.checkpoint.Get(c.streamName, shardID) if err != nil { return fmt.Errorf("get checkpoint error: %v", err) } c.logger.Println("scanning", shardID, lastSeqNum) - // get records recc, errc, err := c.client.GetRecords(ctx, c.streamName, shardID, lastSeqNum) if err != nil { return fmt.Errorf("get records error: %v", err) } - // loop records for r := range recc { - if ok := fn(r); !ok { + scanError := fn(r) + // It will be nicer if this can be reported with checkpoint error + err = scanError.Error + + // Skip invalid state + if scanError.StopScan && scanError.SkipCheckpoint { + continue + } + + if scanError.StopScan { break } - c.counter.Add("records", 1) - - err := c.checkpoint.Set(c.streamName, shardID, *r.SequenceNumber) - if err != nil { - return fmt.Errorf("set checkpoint error: %v", err) + if !scanError.SkipCheckpoint { + c.counter.Add("records", 1) + err := c.checkpoint.Set(c.streamName, shardID, *r.SequenceNumber) + if err != nil { + return fmt.Errorf("set checkpoint error: %v", err) + } } } diff --git a/consumer_test.go b/consumer_test.go index d3feb2c..25e8e53 100644 --- a/consumer_test.go +++ b/consumer_test.go @@ -45,9 +45,14 @@ func TestScanShard(t *testing.T) { // callback fn simply appends the record data to result string var ( resultData string - fn = func(r *Record) bool { + fn = func(r *Record) ScanError { resultData += string(r.Data) - return true + err := errors.New("some error happened") + return ScanError{ + Error: err, + StopScan: false, + SkipCheckpoint: false, + } } ) diff --git a/examples/consumer/main.go b/examples/consumer/main.go index dad2e30..4602564 100644 --- a/examples/consumer/main.go +++ b/examples/consumer/main.go @@ -2,6 +2,7 @@ package main import ( "context" + "errors" "expvar" "flag" "fmt" @@ -89,9 +90,15 @@ func main() { }() // scan stream - err = c.Scan(ctx, func(r *consumer.Record) bool { + err = c.Scan(ctx, func(r *consumer.Record) consumer.ScanError { fmt.Println(string(r.Data)) - return true // continue scanning + err := errors.New("some error happened") + // continue scanning + return consumer.ScanError{ + Error: err, + StopScan: true, + SkipCheckpoint: false, + } }) if err != nil { log.Fatalf("scan error: %v", err)