Leverage context cancellation for stopping scan
This commit is contained in:
parent
7d5601fbde
commit
5112f448ac
3 changed files with 79 additions and 77 deletions
33
README.md
33
README.md
|
|
@ -55,13 +55,16 @@ func main() {
|
||||||
|
|
||||||
## ScanFunc
|
## ScanFunc
|
||||||
|
|
||||||
The `ScanFunc` receives a Kinesis Record and returns an `error`
|
ScanFunc is the type of the function called for each message read
|
||||||
|
from the stream. The record argument contains the original record
|
||||||
|
returned from the AWS Kinesis library.
|
||||||
|
|
||||||
```go
|
```go
|
||||||
type ScanFunc func(*Record) error
|
type ScanFunc func(r *Record) error
|
||||||
```
|
```
|
||||||
|
|
||||||
Return `nil` to continue scanning, or choose from the custom error types for additional flow control.
|
If an error is returned, scanning stops. The sole exception is when the
|
||||||
|
function returns the special value SkipCheckpoint.
|
||||||
|
|
||||||
```go
|
```go
|
||||||
// continue scanning
|
// continue scanning
|
||||||
|
|
@ -70,13 +73,31 @@ return nil
|
||||||
// continue scanning, skip checkpoint
|
// continue scanning, skip checkpoint
|
||||||
return consumer.SkipCheckpoint
|
return consumer.SkipCheckpoint
|
||||||
|
|
||||||
// stop scanning, return nil
|
|
||||||
return consumer.StopScan
|
|
||||||
|
|
||||||
// stop scanning, return error
|
// stop scanning, return error
|
||||||
return errors.New("my error, exit all scans")
|
return errors.New("my error, exit all scans")
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Use context cancel to signal the scan to exit without error. For example if we wanted to gracefulloy exit the scan on interrupt.
|
||||||
|
|
||||||
|
```go
|
||||||
|
// trap SIGINT, wait to trigger shutdown
|
||||||
|
signals := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(signals, os.Interrupt)
|
||||||
|
|
||||||
|
// context with cancel
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
<-signals
|
||||||
|
cancel() // call cancellation
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := c.Scan(ctx, func(r *consumer.Record) error {
|
||||||
|
fmt.Println(string(r.Data))
|
||||||
|
return nil // continue scanning
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
## Checkpoint
|
## Checkpoint
|
||||||
|
|
||||||
To record the progress of the consumer in the stream we use a checkpoint to store the last sequence number the consumer has read from a particular shard. The boolean value SkipCheckpoint of consumer.ScanError determines if checkpoint will be activated. ScanError is returned by the record processing callback.
|
To record the progress of the consumer in the stream we use a checkpoint to store the last sequence number the consumer has read from a particular shard. The boolean value SkipCheckpoint of consumer.ScanError determines if checkpoint will be activated. ScanError is returned by the record processing callback.
|
||||||
|
|
|
||||||
45
consumer.go
45
consumer.go
|
|
@ -67,7 +67,7 @@ type Consumer struct {
|
||||||
// returned from the AWS Kinesis library.
|
// returned from the AWS Kinesis library.
|
||||||
//
|
//
|
||||||
// If an error is returned, scanning stops. The sole exception is when the
|
// If an error is returned, scanning stops. The sole exception is when the
|
||||||
// function returns the special value SkipCheckpoint or StopScan.
|
// function returns the special value SkipCheckpoint.
|
||||||
type ScanFunc func(*Record) error
|
type ScanFunc func(*Record) error
|
||||||
|
|
||||||
// SkipCheckpoint is used as a return value from ScanFuncs to indicate that
|
// SkipCheckpoint is used as a return value from ScanFuncs to indicate that
|
||||||
|
|
@ -75,11 +75,6 @@ type ScanFunc func(*Record) error
|
||||||
// as an error by any function.
|
// as an error by any function.
|
||||||
var SkipCheckpoint = errors.New("skip checkpoint")
|
var SkipCheckpoint = errors.New("skip checkpoint")
|
||||||
|
|
||||||
// StopScan is used as a return value from ScanFuncs to indicate that
|
|
||||||
// the we should stop scanning the current shard. It is not returned
|
|
||||||
// as an error by any function.
|
|
||||||
var StopScan = errors.New("stop scan")
|
|
||||||
|
|
||||||
// Scan launches a goroutine to process each of the shards in the stream. The ScanFunc
|
// Scan launches a goroutine to process each of the shards in the stream. The ScanFunc
|
||||||
// is passed through to each of the goroutines and called with each message pulled from
|
// is passed through to each of the goroutines and called with each message pulled from
|
||||||
// the stream.
|
// the stream.
|
||||||
|
|
@ -164,24 +159,26 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// call callback func with each record from response
|
// callback func with each record
|
||||||
for _, r := range resp.Records {
|
for _, r := range resp.Records {
|
||||||
lastSeqNum = *r.SequenceNumber
|
select {
|
||||||
c.counter.Add("records", 1)
|
case <-ctx.Done():
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
err := fn(r)
|
||||||
|
|
||||||
if err := fn(r); err != nil {
|
if err != nil && err != SkipCheckpoint {
|
||||||
switch err {
|
|
||||||
case StopScan:
|
|
||||||
return nil
|
|
||||||
case SkipCheckpoint:
|
|
||||||
continue
|
|
||||||
default:
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if err := c.checkpoint.Set(c.streamName, shardID, *r.SequenceNumber); err != nil {
|
if err != SkipCheckpoint {
|
||||||
return err
|
if err := c.checkpoint.Set(c.streamName, shardID, *r.SequenceNumber); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.counter.Add("records", 1)
|
||||||
|
lastSeqNum = *r.SequenceNumber
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -221,9 +218,10 @@ func (c *Consumer) getShardIDs(streamName string) ([]string, error) {
|
||||||
NextToken: resp.NextToken,
|
NextToken: resp.NextToken,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return ss, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Consumer) getShardIterator(streamName, shardID, lastSeqNum string) (*string, error) {
|
func (c *Consumer) getShardIterator(streamName, shardID, seqNum string) (*string, error) {
|
||||||
params := &kinesis.GetShardIteratorInput{
|
params := &kinesis.GetShardIteratorInput{
|
||||||
ShardId: aws.String(shardID),
|
ShardId: aws.String(shardID),
|
||||||
StreamName: aws.String(streamName),
|
StreamName: aws.String(streamName),
|
||||||
|
|
@ -236,9 +234,6 @@ func (c *Consumer) getShardIterator(streamName, shardID, lastSeqNum string) (*st
|
||||||
params.ShardIteratorType = aws.String(c.initialShardIteratorType)
|
params.ShardIteratorType = aws.String(c.initialShardIteratorType)
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := c.client.GetShardIterator(params)
|
res, err := c.client.GetShardIterator(params)
|
||||||
if err != nil {
|
return res.ShardIterator, err
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return resp.ShardIterator, nil
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -11,23 +11,24 @@ import (
|
||||||
"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
|
"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var records = []*kinesis.Record{
|
||||||
|
{
|
||||||
|
Data: []byte("firstData"),
|
||||||
|
SequenceNumber: aws.String("firstSeqNum"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Data: []byte("lastData"),
|
||||||
|
SequenceNumber: aws.String("lastSeqNum"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
func TestNew(t *testing.T) {
|
func TestNew(t *testing.T) {
|
||||||
if _, err := New("myStreamName"); err != nil {
|
if _, err := New("myStreamName"); err != nil {
|
||||||
t.Fatalf("new consumer error: %v", err)
|
t.Fatalf("new consumer error: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConsumer_Scan(t *testing.T) {
|
func TestScan(t *testing.T) {
|
||||||
records := []*kinesis.Record{
|
|
||||||
{
|
|
||||||
Data: []byte("firstData"),
|
|
||||||
SequenceNumber: aws.String("firstSeqNum"),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Data: []byte("lastData"),
|
|
||||||
SequenceNumber: aws.String("lastSeqNum"),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
client := &kinesisClientMock{
|
client := &kinesisClientMock{
|
||||||
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
|
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
|
||||||
return &kinesis.GetShardIteratorOutput{
|
return &kinesis.GetShardIteratorOutput{
|
||||||
|
|
@ -75,11 +76,13 @@ func TestConsumer_Scan(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if resultData != "firstDatalastData" {
|
if resultData != "firstDatalastData" {
|
||||||
t.Errorf("callback error expected %s, got %s", "firstDatalastData", resultData)
|
t.Errorf("callback error expected %s, got %s", "FirstLast", resultData)
|
||||||
}
|
}
|
||||||
|
|
||||||
if fnCallCounter != 2 {
|
if fnCallCounter != 2 {
|
||||||
t.Errorf("the callback function expects %v, got %v", 2, fnCallCounter)
|
t.Errorf("the callback function expects %v, got %v", 2, fnCallCounter)
|
||||||
}
|
}
|
||||||
|
|
||||||
if val := ctr.counter; val != 2 {
|
if val := ctr.counter; val != 2 {
|
||||||
t.Errorf("counter error expected %d, got %d", 2, val)
|
t.Errorf("counter error expected %d, got %d", 2, val)
|
||||||
}
|
}
|
||||||
|
|
@ -90,7 +93,7 @@ func TestConsumer_Scan(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConsumer_Scan_NoShardsAvailable(t *testing.T) {
|
func TestScan_NoShardsAvailable(t *testing.T) {
|
||||||
client := &kinesisClientMock{
|
client := &kinesisClientMock{
|
||||||
listShardsMock: func(input *kinesis.ListShardsInput) (*kinesis.ListShardsOutput, error) {
|
listShardsMock: func(input *kinesis.ListShardsInput) (*kinesis.ListShardsOutput, error) {
|
||||||
return &kinesis.ListShardsOutput{
|
return &kinesis.ListShardsOutput{
|
||||||
|
|
@ -114,17 +117,6 @@ func TestConsumer_Scan_NoShardsAvailable(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestScanShard(t *testing.T) {
|
func TestScanShard(t *testing.T) {
|
||||||
var records = []*kinesis.Record{
|
|
||||||
{
|
|
||||||
Data: []byte("firstData"),
|
|
||||||
SequenceNumber: aws.String("firstSeqNum"),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Data: []byte("lastData"),
|
|
||||||
SequenceNumber: aws.String("lastSeqNum"),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
var client = &kinesisClientMock{
|
var client = &kinesisClientMock{
|
||||||
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
|
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
|
||||||
return &kinesis.GetShardIteratorOutput{
|
return &kinesis.GetShardIteratorOutput{
|
||||||
|
|
@ -182,18 +174,7 @@ func TestScanShard(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestScanShard_StopScan(t *testing.T) {
|
func TestScanShard_Cancellation(t *testing.T) {
|
||||||
var records = []*kinesis.Record{
|
|
||||||
{
|
|
||||||
Data: []byte("firstData"),
|
|
||||||
SequenceNumber: aws.String("firstSeqNum"),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Data: []byte("lastData"),
|
|
||||||
SequenceNumber: aws.String("lastSeqNum"),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
var client = &kinesisClientMock{
|
var client = &kinesisClientMock{
|
||||||
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
|
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
|
||||||
return &kinesis.GetShardIteratorOutput{
|
return &kinesis.GetShardIteratorOutput{
|
||||||
|
|
@ -208,19 +189,23 @@ func TestScanShard_StopScan(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// use cancel func to signal shutdown
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
var res string
|
||||||
|
var fn = func(r *Record) error {
|
||||||
|
res += string(r.Data)
|
||||||
|
cancel() // simulate cancellation while processing first record
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
c, err := New("myStreamName", WithClient(client))
|
c, err := New("myStreamName", WithClient(client))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("new consumer error: %v", err)
|
t.Fatalf("new consumer error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// callback fn appends record data
|
err = c.ScanShard(ctx, "myShard", fn)
|
||||||
var res string
|
if err != nil {
|
||||||
var fn = func(r *Record) error {
|
|
||||||
res += string(r.Data)
|
|
||||||
return StopScan
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := c.ScanShard(context.Background(), "myShard", fn); err != nil {
|
|
||||||
t.Fatalf("scan shard error: %v", err)
|
t.Fatalf("scan shard error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -250,10 +235,11 @@ func TestScanShard_ShardIsClosed(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
var fn = func(r *Record) error {
|
var fn = func(r *Record) error {
|
||||||
return StopScan
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := c.ScanShard(context.Background(), "myShard", fn); err != nil {
|
err = c.ScanShard(context.Background(), "myShard", fn)
|
||||||
|
if err != nil {
|
||||||
t.Fatalf("scan shard error: %v", err)
|
t.Fatalf("scan shard error: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue