Leverage context cancellation for stopping scan

This commit is contained in:
Harlow Ward 2018-12-30 21:54:33 -08:00
parent 7d5601fbde
commit 5112f448ac
3 changed files with 79 additions and 77 deletions

View file

@ -55,13 +55,16 @@ func main() {
## ScanFunc ## ScanFunc
The `ScanFunc` receives a Kinesis Record and returns an `error` ScanFunc is the type of the function called for each message read
from the stream. The record argument contains the original record
returned from the AWS Kinesis library.
```go ```go
type ScanFunc func(*Record) error type ScanFunc func(r *Record) error
``` ```
Return `nil` to continue scanning, or choose from the custom error types for additional flow control. If an error is returned, scanning stops. The sole exception is when the
function returns the special value SkipCheckpoint.
```go ```go
// continue scanning // continue scanning
@ -70,13 +73,31 @@ return nil
// continue scanning, skip checkpoint // continue scanning, skip checkpoint
return consumer.SkipCheckpoint return consumer.SkipCheckpoint
// stop scanning, return nil
return consumer.StopScan
// stop scanning, return error // stop scanning, return error
return errors.New("my error, exit all scans") return errors.New("my error, exit all scans")
``` ```
Use context cancel to signal the scan to exit without error. For example if we wanted to gracefulloy exit the scan on interrupt.
```go
// trap SIGINT, wait to trigger shutdown
signals := make(chan os.Signal, 1)
signal.Notify(signals, os.Interrupt)
// context with cancel
ctx, cancel := context.WithCancel(context.Background())
go func() {
<-signals
cancel() // call cancellation
}()
err := c.Scan(ctx, func(r *consumer.Record) error {
fmt.Println(string(r.Data))
return nil // continue scanning
})
```
## Checkpoint ## Checkpoint
To record the progress of the consumer in the stream we use a checkpoint to store the last sequence number the consumer has read from a particular shard. The boolean value SkipCheckpoint of consumer.ScanError determines if checkpoint will be activated. ScanError is returned by the record processing callback. To record the progress of the consumer in the stream we use a checkpoint to store the last sequence number the consumer has read from a particular shard. The boolean value SkipCheckpoint of consumer.ScanError determines if checkpoint will be activated. ScanError is returned by the record processing callback.

View file

@ -67,7 +67,7 @@ type Consumer struct {
// returned from the AWS Kinesis library. // returned from the AWS Kinesis library.
// //
// If an error is returned, scanning stops. The sole exception is when the // If an error is returned, scanning stops. The sole exception is when the
// function returns the special value SkipCheckpoint or StopScan. // function returns the special value SkipCheckpoint.
type ScanFunc func(*Record) error type ScanFunc func(*Record) error
// SkipCheckpoint is used as a return value from ScanFuncs to indicate that // SkipCheckpoint is used as a return value from ScanFuncs to indicate that
@ -75,11 +75,6 @@ type ScanFunc func(*Record) error
// as an error by any function. // as an error by any function.
var SkipCheckpoint = errors.New("skip checkpoint") var SkipCheckpoint = errors.New("skip checkpoint")
// StopScan is used as a return value from ScanFuncs to indicate that
// the we should stop scanning the current shard. It is not returned
// as an error by any function.
var StopScan = errors.New("stop scan")
// Scan launches a goroutine to process each of the shards in the stream. The ScanFunc // Scan launches a goroutine to process each of the shards in the stream. The ScanFunc
// is passed through to each of the goroutines and called with each message pulled from // is passed through to each of the goroutines and called with each message pulled from
// the stream. // the stream.
@ -164,24 +159,26 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
continue continue
} }
// call callback func with each record from response // callback func with each record
for _, r := range resp.Records { for _, r := range resp.Records {
lastSeqNum = *r.SequenceNumber select {
c.counter.Add("records", 1) case <-ctx.Done():
return nil
default:
err := fn(r)
if err := fn(r); err != nil { if err != nil && err != SkipCheckpoint {
switch err {
case StopScan:
return nil
case SkipCheckpoint:
continue
default:
return err return err
} }
}
if err := c.checkpoint.Set(c.streamName, shardID, *r.SequenceNumber); err != nil { if err != SkipCheckpoint {
return err if err := c.checkpoint.Set(c.streamName, shardID, *r.SequenceNumber); err != nil {
return err
}
}
c.counter.Add("records", 1)
lastSeqNum = *r.SequenceNumber
} }
} }
@ -221,9 +218,10 @@ func (c *Consumer) getShardIDs(streamName string) ([]string, error) {
NextToken: resp.NextToken, NextToken: resp.NextToken,
} }
} }
return ss, nil
} }
func (c *Consumer) getShardIterator(streamName, shardID, lastSeqNum string) (*string, error) { func (c *Consumer) getShardIterator(streamName, shardID, seqNum string) (*string, error) {
params := &kinesis.GetShardIteratorInput{ params := &kinesis.GetShardIteratorInput{
ShardId: aws.String(shardID), ShardId: aws.String(shardID),
StreamName: aws.String(streamName), StreamName: aws.String(streamName),
@ -236,9 +234,6 @@ func (c *Consumer) getShardIterator(streamName, shardID, lastSeqNum string) (*st
params.ShardIteratorType = aws.String(c.initialShardIteratorType) params.ShardIteratorType = aws.String(c.initialShardIteratorType)
} }
resp, err := c.client.GetShardIterator(params) res, err := c.client.GetShardIterator(params)
if err != nil { return res.ShardIterator, err
return nil, err
}
return resp.ShardIterator, nil
} }

View file

@ -11,23 +11,24 @@ import (
"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface" "github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
) )
var records = []*kinesis.Record{
{
Data: []byte("firstData"),
SequenceNumber: aws.String("firstSeqNum"),
},
{
Data: []byte("lastData"),
SequenceNumber: aws.String("lastSeqNum"),
},
}
func TestNew(t *testing.T) { func TestNew(t *testing.T) {
if _, err := New("myStreamName"); err != nil { if _, err := New("myStreamName"); err != nil {
t.Fatalf("new consumer error: %v", err) t.Fatalf("new consumer error: %v", err)
} }
} }
func TestConsumer_Scan(t *testing.T) { func TestScan(t *testing.T) {
records := []*kinesis.Record{
{
Data: []byte("firstData"),
SequenceNumber: aws.String("firstSeqNum"),
},
{
Data: []byte("lastData"),
SequenceNumber: aws.String("lastSeqNum"),
},
}
client := &kinesisClientMock{ client := &kinesisClientMock{
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) { getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
return &kinesis.GetShardIteratorOutput{ return &kinesis.GetShardIteratorOutput{
@ -75,11 +76,13 @@ func TestConsumer_Scan(t *testing.T) {
} }
if resultData != "firstDatalastData" { if resultData != "firstDatalastData" {
t.Errorf("callback error expected %s, got %s", "firstDatalastData", resultData) t.Errorf("callback error expected %s, got %s", "FirstLast", resultData)
} }
if fnCallCounter != 2 { if fnCallCounter != 2 {
t.Errorf("the callback function expects %v, got %v", 2, fnCallCounter) t.Errorf("the callback function expects %v, got %v", 2, fnCallCounter)
} }
if val := ctr.counter; val != 2 { if val := ctr.counter; val != 2 {
t.Errorf("counter error expected %d, got %d", 2, val) t.Errorf("counter error expected %d, got %d", 2, val)
} }
@ -90,7 +93,7 @@ func TestConsumer_Scan(t *testing.T) {
} }
} }
func TestConsumer_Scan_NoShardsAvailable(t *testing.T) { func TestScan_NoShardsAvailable(t *testing.T) {
client := &kinesisClientMock{ client := &kinesisClientMock{
listShardsMock: func(input *kinesis.ListShardsInput) (*kinesis.ListShardsOutput, error) { listShardsMock: func(input *kinesis.ListShardsInput) (*kinesis.ListShardsOutput, error) {
return &kinesis.ListShardsOutput{ return &kinesis.ListShardsOutput{
@ -114,17 +117,6 @@ func TestConsumer_Scan_NoShardsAvailable(t *testing.T) {
} }
func TestScanShard(t *testing.T) { func TestScanShard(t *testing.T) {
var records = []*kinesis.Record{
{
Data: []byte("firstData"),
SequenceNumber: aws.String("firstSeqNum"),
},
{
Data: []byte("lastData"),
SequenceNumber: aws.String("lastSeqNum"),
},
}
var client = &kinesisClientMock{ var client = &kinesisClientMock{
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) { getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
return &kinesis.GetShardIteratorOutput{ return &kinesis.GetShardIteratorOutput{
@ -182,18 +174,7 @@ func TestScanShard(t *testing.T) {
} }
} }
func TestScanShard_StopScan(t *testing.T) { func TestScanShard_Cancellation(t *testing.T) {
var records = []*kinesis.Record{
{
Data: []byte("firstData"),
SequenceNumber: aws.String("firstSeqNum"),
},
{
Data: []byte("lastData"),
SequenceNumber: aws.String("lastSeqNum"),
},
}
var client = &kinesisClientMock{ var client = &kinesisClientMock{
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) { getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
return &kinesis.GetShardIteratorOutput{ return &kinesis.GetShardIteratorOutput{
@ -208,19 +189,23 @@ func TestScanShard_StopScan(t *testing.T) {
}, },
} }
// use cancel func to signal shutdown
ctx, cancel := context.WithCancel(context.Background())
var res string
var fn = func(r *Record) error {
res += string(r.Data)
cancel() // simulate cancellation while processing first record
return nil
}
c, err := New("myStreamName", WithClient(client)) c, err := New("myStreamName", WithClient(client))
if err != nil { if err != nil {
t.Fatalf("new consumer error: %v", err) t.Fatalf("new consumer error: %v", err)
} }
// callback fn appends record data err = c.ScanShard(ctx, "myShard", fn)
var res string if err != nil {
var fn = func(r *Record) error {
res += string(r.Data)
return StopScan
}
if err := c.ScanShard(context.Background(), "myShard", fn); err != nil {
t.Fatalf("scan shard error: %v", err) t.Fatalf("scan shard error: %v", err)
} }
@ -250,10 +235,11 @@ func TestScanShard_ShardIsClosed(t *testing.T) {
} }
var fn = func(r *Record) error { var fn = func(r *Record) error {
return StopScan return nil
} }
if err := c.ScanShard(context.Background(), "myShard", fn); err != nil { err = c.ScanShard(context.Background(), "myShard", fn)
if err != nil {
t.Fatalf("scan shard error: %v", err) t.Fatalf("scan shard error: %v", err)
} }
} }