Leverage context cancellation for stopping scan

This commit is contained in:
Harlow Ward 2018-12-30 21:54:33 -08:00
parent 7d5601fbde
commit 5112f448ac
3 changed files with 79 additions and 77 deletions

View file

@ -55,13 +55,16 @@ func main() {
## ScanFunc
The `ScanFunc` receives a Kinesis Record and returns an `error`
ScanFunc is the type of the function called for each message read
from the stream. The record argument contains the original record
returned from the AWS Kinesis library.
```go
type ScanFunc func(*Record) error
type ScanFunc func(r *Record) error
```
Return `nil` to continue scanning, or choose from the custom error types for additional flow control.
If an error is returned, scanning stops. The sole exception is when the
function returns the special value SkipCheckpoint.
```go
// continue scanning
@ -70,13 +73,31 @@ return nil
// continue scanning, skip checkpoint
return consumer.SkipCheckpoint
// stop scanning, return nil
return consumer.StopScan
// stop scanning, return error
return errors.New("my error, exit all scans")
```
Use context cancellation to signal the scan to exit without error. For example, to gracefully exit the scan on interrupt:
```go
// trap SIGINT, wait to trigger shutdown
signals := make(chan os.Signal, 1)
signal.Notify(signals, os.Interrupt)
// context with cancel
ctx, cancel := context.WithCancel(context.Background())
go func() {
<-signals
cancel() // call cancellation
}()
err := c.Scan(ctx, func(r *consumer.Record) error {
fmt.Println(string(r.Data))
return nil // continue scanning
})
```
## Checkpoint
To record the progress of the consumer in the stream we use a checkpoint to store the last sequence number the consumer has read from a particular shard. The record processing callback controls whether the checkpoint is written: when it returns `nil` the checkpoint is set for that record, and when it returns the special value `consumer.SkipCheckpoint` scanning continues without setting the checkpoint.

View file

@ -67,7 +67,7 @@ type Consumer struct {
// returned from the AWS Kinesis library.
//
// If an error is returned, scanning stops. The sole exception is when the
// function returns the special value SkipCheckpoint or StopScan.
// function returns the special value SkipCheckpoint.
type ScanFunc func(*Record) error
// SkipCheckpoint is used as a return value from ScanFuncs to indicate that
@ -75,11 +75,6 @@ type ScanFunc func(*Record) error
// as an error by any function.
var SkipCheckpoint = errors.New("skip checkpoint")
// StopScan is used as a return value from ScanFuncs to indicate that
// we should stop scanning the current shard. It is not returned
// as an error by any function.
var StopScan = errors.New("stop scan")
// Scan launches a goroutine to process each of the shards in the stream. The ScanFunc
// is passed through to each of the goroutines and called with each message pulled from
// the stream.
@ -164,24 +159,26 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
continue
}
// call callback func with each record from response
// callback func with each record
for _, r := range resp.Records {
lastSeqNum = *r.SequenceNumber
c.counter.Add("records", 1)
select {
case <-ctx.Done():
return nil
default:
err := fn(r)
if err := fn(r); err != nil {
switch err {
case StopScan:
return nil
case SkipCheckpoint:
continue
default:
if err != nil && err != SkipCheckpoint {
return err
}
}
if err := c.checkpoint.Set(c.streamName, shardID, *r.SequenceNumber); err != nil {
return err
if err != SkipCheckpoint {
if err := c.checkpoint.Set(c.streamName, shardID, *r.SequenceNumber); err != nil {
return err
}
}
c.counter.Add("records", 1)
lastSeqNum = *r.SequenceNumber
}
}
@ -221,9 +218,10 @@ func (c *Consumer) getShardIDs(streamName string) ([]string, error) {
NextToken: resp.NextToken,
}
}
return ss, nil
}
func (c *Consumer) getShardIterator(streamName, shardID, lastSeqNum string) (*string, error) {
func (c *Consumer) getShardIterator(streamName, shardID, seqNum string) (*string, error) {
params := &kinesis.GetShardIteratorInput{
ShardId: aws.String(shardID),
StreamName: aws.String(streamName),
@ -236,9 +234,6 @@ func (c *Consumer) getShardIterator(streamName, shardID, lastSeqNum string) (*st
params.ShardIteratorType = aws.String(c.initialShardIteratorType)
}
resp, err := c.client.GetShardIterator(params)
if err != nil {
return nil, err
}
return resp.ShardIterator, nil
res, err := c.client.GetShardIterator(params)
return res.ShardIterator, err
}

View file

@ -11,23 +11,24 @@ import (
"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
)
// records is a package-level fixture of two Kinesis records with fixed
// payloads ("firstData", "lastData") and sequence numbers ("firstSeqNum",
// "lastSeqNum").
// NOTE(review): presumably shared by the tests in this file in place of the
// per-test copies — confirm against the mock bodies.
var records = []*kinesis.Record{
{
Data: []byte("firstData"),
SequenceNumber: aws.String("firstSeqNum"),
},
{
Data: []byte("lastData"),
SequenceNumber: aws.String("lastSeqNum"),
},
}
// TestNew checks that constructing a consumer with only a stream name
// succeeds; any error from New aborts the test immediately.
func TestNew(t *testing.T) {
	_, err := New("myStreamName")
	if err != nil {
		t.Fatalf("new consumer error: %v", err)
	}
}
func TestConsumer_Scan(t *testing.T) {
records := []*kinesis.Record{
{
Data: []byte("firstData"),
SequenceNumber: aws.String("firstSeqNum"),
},
{
Data: []byte("lastData"),
SequenceNumber: aws.String("lastSeqNum"),
},
}
func TestScan(t *testing.T) {
client := &kinesisClientMock{
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
return &kinesis.GetShardIteratorOutput{
@ -75,11 +76,13 @@ func TestConsumer_Scan(t *testing.T) {
}
if resultData != "firstDatalastData" {
t.Errorf("callback error expected %s, got %s", "firstDatalastData", resultData)
t.Errorf("callback error expected %s, got %s", "FirstLast", resultData)
}
if fnCallCounter != 2 {
t.Errorf("the callback function expects %v, got %v", 2, fnCallCounter)
}
if val := ctr.counter; val != 2 {
t.Errorf("counter error expected %d, got %d", 2, val)
}
@ -90,7 +93,7 @@ func TestConsumer_Scan(t *testing.T) {
}
}
func TestConsumer_Scan_NoShardsAvailable(t *testing.T) {
func TestScan_NoShardsAvailable(t *testing.T) {
client := &kinesisClientMock{
listShardsMock: func(input *kinesis.ListShardsInput) (*kinesis.ListShardsOutput, error) {
return &kinesis.ListShardsOutput{
@ -114,17 +117,6 @@ func TestConsumer_Scan_NoShardsAvailable(t *testing.T) {
}
func TestScanShard(t *testing.T) {
var records = []*kinesis.Record{
{
Data: []byte("firstData"),
SequenceNumber: aws.String("firstSeqNum"),
},
{
Data: []byte("lastData"),
SequenceNumber: aws.String("lastSeqNum"),
},
}
var client = &kinesisClientMock{
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
return &kinesis.GetShardIteratorOutput{
@ -182,18 +174,7 @@ func TestScanShard(t *testing.T) {
}
}
func TestScanShard_StopScan(t *testing.T) {
var records = []*kinesis.Record{
{
Data: []byte("firstData"),
SequenceNumber: aws.String("firstSeqNum"),
},
{
Data: []byte("lastData"),
SequenceNumber: aws.String("lastSeqNum"),
},
}
func TestScanShard_Cancellation(t *testing.T) {
var client = &kinesisClientMock{
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
return &kinesis.GetShardIteratorOutput{
@ -208,19 +189,23 @@ func TestScanShard_StopScan(t *testing.T) {
},
}
// use cancel func to signal shutdown
ctx, cancel := context.WithCancel(context.Background())
var res string
var fn = func(r *Record) error {
res += string(r.Data)
cancel() // simulate cancellation while processing first record
return nil
}
c, err := New("myStreamName", WithClient(client))
if err != nil {
t.Fatalf("new consumer error: %v", err)
}
// callback fn appends record data
var res string
var fn = func(r *Record) error {
res += string(r.Data)
return StopScan
}
if err := c.ScanShard(context.Background(), "myShard", fn); err != nil {
err = c.ScanShard(ctx, "myShard", fn)
if err != nil {
t.Fatalf("scan shard error: %v", err)
}
@ -250,10 +235,11 @@ func TestScanShard_ShardIsClosed(t *testing.T) {
}
var fn = func(r *Record) error {
return StopScan
return nil
}
if err := c.ScanShard(context.Background(), "myShard", fn); err != nil {
err = c.ScanShard(context.Background(), "myShard", fn)
if err != nil {
t.Fatalf("scan shard error: %v", err)
}
}