Leverage context cancellation for stopping scan
This commit is contained in:
parent
7d5601fbde
commit
5112f448ac
3 changed files with 79 additions and 77 deletions
33
README.md
33
README.md
|
|
@ -55,13 +55,16 @@ func main() {
|
|||
|
||||
## ScanFunc
|
||||
|
||||
The `ScanFunc` receives a Kinesis Record and returns an `error`
|
||||
ScanFunc is the type of the function called for each message read
|
||||
from the stream. The record argument contains the original record
|
||||
returned from the AWS Kinesis library.
|
||||
|
||||
```go
|
||||
type ScanFunc func(*Record) error
|
||||
type ScanFunc func(r *Record) error
|
||||
```
|
||||
|
||||
Return `nil` to continue scanning, or choose from the custom error types for additional flow control.
|
||||
If an error is returned, scanning stops. The sole exception is when the
|
||||
function returns the special value SkipCheckpoint.
|
||||
|
||||
```go
|
||||
// continue scanning
|
||||
|
|
@ -70,13 +73,31 @@ return nil
|
|||
// continue scanning, skip checkpoint
|
||||
return consumer.SkipCheckpoint
|
||||
|
||||
// stop scanning, return nil
|
||||
return consumer.StopScan
|
||||
|
||||
// stop scanning, return error
|
||||
return errors.New("my error, exit all scans")
|
||||
```
|
||||
|
||||
Use context cancel to signal the scan to exit without error. For example if we wanted to gracefulloy exit the scan on interrupt.
|
||||
|
||||
```go
|
||||
// trap SIGINT, wait to trigger shutdown
|
||||
signals := make(chan os.Signal, 1)
|
||||
signal.Notify(signals, os.Interrupt)
|
||||
|
||||
// context with cancel
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
go func() {
|
||||
<-signals
|
||||
cancel() // call cancellation
|
||||
}()
|
||||
|
||||
err := c.Scan(ctx, func(r *consumer.Record) error {
|
||||
fmt.Println(string(r.Data))
|
||||
return nil // continue scanning
|
||||
})
|
||||
```
|
||||
|
||||
## Checkpoint
|
||||
|
||||
To record the progress of the consumer in the stream we use a checkpoint to store the last sequence number the consumer has read from a particular shard. The boolean value SkipCheckpoint of consumer.ScanError determines if checkpoint will be activated. ScanError is returned by the record processing callback.
|
||||
|
|
|
|||
45
consumer.go
45
consumer.go
|
|
@ -67,7 +67,7 @@ type Consumer struct {
|
|||
// returned from the AWS Kinesis library.
|
||||
//
|
||||
// If an error is returned, scanning stops. The sole exception is when the
|
||||
// function returns the special value SkipCheckpoint or StopScan.
|
||||
// function returns the special value SkipCheckpoint.
|
||||
type ScanFunc func(*Record) error
|
||||
|
||||
// SkipCheckpoint is used as a return value from ScanFuncs to indicate that
|
||||
|
|
@ -75,11 +75,6 @@ type ScanFunc func(*Record) error
|
|||
// as an error by any function.
|
||||
var SkipCheckpoint = errors.New("skip checkpoint")
|
||||
|
||||
// StopScan is used as a return value from ScanFuncs to indicate that
|
||||
// the we should stop scanning the current shard. It is not returned
|
||||
// as an error by any function.
|
||||
var StopScan = errors.New("stop scan")
|
||||
|
||||
// Scan launches a goroutine to process each of the shards in the stream. The ScanFunc
|
||||
// is passed through to each of the goroutines and called with each message pulled from
|
||||
// the stream.
|
||||
|
|
@ -164,24 +159,26 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
|
|||
continue
|
||||
}
|
||||
|
||||
// call callback func with each record from response
|
||||
// callback func with each record
|
||||
for _, r := range resp.Records {
|
||||
lastSeqNum = *r.SequenceNumber
|
||||
c.counter.Add("records", 1)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
default:
|
||||
err := fn(r)
|
||||
|
||||
if err := fn(r); err != nil {
|
||||
switch err {
|
||||
case StopScan:
|
||||
return nil
|
||||
case SkipCheckpoint:
|
||||
continue
|
||||
default:
|
||||
if err != nil && err != SkipCheckpoint {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := c.checkpoint.Set(c.streamName, shardID, *r.SequenceNumber); err != nil {
|
||||
return err
|
||||
if err != SkipCheckpoint {
|
||||
if err := c.checkpoint.Set(c.streamName, shardID, *r.SequenceNumber); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
c.counter.Add("records", 1)
|
||||
lastSeqNum = *r.SequenceNumber
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -221,9 +218,10 @@ func (c *Consumer) getShardIDs(streamName string) ([]string, error) {
|
|||
NextToken: resp.NextToken,
|
||||
}
|
||||
}
|
||||
return ss, nil
|
||||
}
|
||||
|
||||
func (c *Consumer) getShardIterator(streamName, shardID, lastSeqNum string) (*string, error) {
|
||||
func (c *Consumer) getShardIterator(streamName, shardID, seqNum string) (*string, error) {
|
||||
params := &kinesis.GetShardIteratorInput{
|
||||
ShardId: aws.String(shardID),
|
||||
StreamName: aws.String(streamName),
|
||||
|
|
@ -236,9 +234,6 @@ func (c *Consumer) getShardIterator(streamName, shardID, lastSeqNum string) (*st
|
|||
params.ShardIteratorType = aws.String(c.initialShardIteratorType)
|
||||
}
|
||||
|
||||
resp, err := c.client.GetShardIterator(params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp.ShardIterator, nil
|
||||
res, err := c.client.GetShardIterator(params)
|
||||
return res.ShardIterator, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,23 +11,24 @@ import (
|
|||
"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
|
||||
)
|
||||
|
||||
var records = []*kinesis.Record{
|
||||
{
|
||||
Data: []byte("firstData"),
|
||||
SequenceNumber: aws.String("firstSeqNum"),
|
||||
},
|
||||
{
|
||||
Data: []byte("lastData"),
|
||||
SequenceNumber: aws.String("lastSeqNum"),
|
||||
},
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
if _, err := New("myStreamName"); err != nil {
|
||||
t.Fatalf("new consumer error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConsumer_Scan(t *testing.T) {
|
||||
records := []*kinesis.Record{
|
||||
{
|
||||
Data: []byte("firstData"),
|
||||
SequenceNumber: aws.String("firstSeqNum"),
|
||||
},
|
||||
{
|
||||
Data: []byte("lastData"),
|
||||
SequenceNumber: aws.String("lastSeqNum"),
|
||||
},
|
||||
}
|
||||
func TestScan(t *testing.T) {
|
||||
client := &kinesisClientMock{
|
||||
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
|
||||
return &kinesis.GetShardIteratorOutput{
|
||||
|
|
@ -75,11 +76,13 @@ func TestConsumer_Scan(t *testing.T) {
|
|||
}
|
||||
|
||||
if resultData != "firstDatalastData" {
|
||||
t.Errorf("callback error expected %s, got %s", "firstDatalastData", resultData)
|
||||
t.Errorf("callback error expected %s, got %s", "FirstLast", resultData)
|
||||
}
|
||||
|
||||
if fnCallCounter != 2 {
|
||||
t.Errorf("the callback function expects %v, got %v", 2, fnCallCounter)
|
||||
}
|
||||
|
||||
if val := ctr.counter; val != 2 {
|
||||
t.Errorf("counter error expected %d, got %d", 2, val)
|
||||
}
|
||||
|
|
@ -90,7 +93,7 @@ func TestConsumer_Scan(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestConsumer_Scan_NoShardsAvailable(t *testing.T) {
|
||||
func TestScan_NoShardsAvailable(t *testing.T) {
|
||||
client := &kinesisClientMock{
|
||||
listShardsMock: func(input *kinesis.ListShardsInput) (*kinesis.ListShardsOutput, error) {
|
||||
return &kinesis.ListShardsOutput{
|
||||
|
|
@ -114,17 +117,6 @@ func TestConsumer_Scan_NoShardsAvailable(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestScanShard(t *testing.T) {
|
||||
var records = []*kinesis.Record{
|
||||
{
|
||||
Data: []byte("firstData"),
|
||||
SequenceNumber: aws.String("firstSeqNum"),
|
||||
},
|
||||
{
|
||||
Data: []byte("lastData"),
|
||||
SequenceNumber: aws.String("lastSeqNum"),
|
||||
},
|
||||
}
|
||||
|
||||
var client = &kinesisClientMock{
|
||||
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
|
||||
return &kinesis.GetShardIteratorOutput{
|
||||
|
|
@ -182,18 +174,7 @@ func TestScanShard(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestScanShard_StopScan(t *testing.T) {
|
||||
var records = []*kinesis.Record{
|
||||
{
|
||||
Data: []byte("firstData"),
|
||||
SequenceNumber: aws.String("firstSeqNum"),
|
||||
},
|
||||
{
|
||||
Data: []byte("lastData"),
|
||||
SequenceNumber: aws.String("lastSeqNum"),
|
||||
},
|
||||
}
|
||||
|
||||
func TestScanShard_Cancellation(t *testing.T) {
|
||||
var client = &kinesisClientMock{
|
||||
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
|
||||
return &kinesis.GetShardIteratorOutput{
|
||||
|
|
@ -208,19 +189,23 @@ func TestScanShard_StopScan(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
// use cancel func to signal shutdown
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
var res string
|
||||
var fn = func(r *Record) error {
|
||||
res += string(r.Data)
|
||||
cancel() // simulate cancellation while processing first record
|
||||
return nil
|
||||
}
|
||||
|
||||
c, err := New("myStreamName", WithClient(client))
|
||||
if err != nil {
|
||||
t.Fatalf("new consumer error: %v", err)
|
||||
}
|
||||
|
||||
// callback fn appends record data
|
||||
var res string
|
||||
var fn = func(r *Record) error {
|
||||
res += string(r.Data)
|
||||
return StopScan
|
||||
}
|
||||
|
||||
if err := c.ScanShard(context.Background(), "myShard", fn); err != nil {
|
||||
err = c.ScanShard(ctx, "myShard", fn)
|
||||
if err != nil {
|
||||
t.Fatalf("scan shard error: %v", err)
|
||||
}
|
||||
|
||||
|
|
@ -250,10 +235,11 @@ func TestScanShard_ShardIsClosed(t *testing.T) {
|
|||
}
|
||||
|
||||
var fn = func(r *Record) error {
|
||||
return StopScan
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := c.ScanShard(context.Background(), "myShard", fn); err != nil {
|
||||
err = c.ScanShard(context.Background(), "myShard", fn)
|
||||
if err != nil {
|
||||
t.Fatalf("scan shard error: %v", err)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue