This commit is contained in:
Harlow Ward 2019-01-03 19:34:24 -08:00
parent 5112f448ac
commit 4bc414e216
2 changed files with 41 additions and 3 deletions

View file

@ -139,7 +139,6 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
c.logger.Log("scanning", shardID, lastSeqNum) c.logger.Log("scanning", shardID, lastSeqNum)
// loop until
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
@ -159,14 +158,13 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
continue continue
} }
// callback func with each record // loop over records, call callback func
for _, r := range resp.Records { for _, r := range resp.Records {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return nil return nil
default: default:
err := fn(r) err := fn(r)
if err != nil && err != SkipCheckpoint { if err != nil && err != SkipCheckpoint {
return err return err
} }

View file

@ -214,6 +214,46 @@ func TestScanShard_Cancellation(t *testing.T) {
} }
} }
func TestScanShard_SkipCheckpoint(t *testing.T) {
var client = &kinesisClientMock{
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
return &kinesis.GetShardIteratorOutput{
ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
}, nil
},
getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
return &kinesis.GetRecordsOutput{
NextShardIterator: nil,
Records: records,
}, nil
},
}
var cp = &fakeCheckpoint{cache: map[string]string{}}
c, err := New("myStreamName", WithClient(client), WithCheckpoint(cp))
if err != nil {
t.Fatalf("new consumer error: %v", err)
}
var fn = func(r *Record) error {
if aws.StringValue(r.SequenceNumber) == "lastSeqNum" {
return SkipCheckpoint
}
return nil
}
err = c.ScanShard(context.Background(), "myShard", fn)
if err != nil {
t.Fatalf("scan shard error: %v", err)
}
val, err := cp.Get("myStreamName", "myShard")
if err != nil && val != "firstSeqNum" {
t.Fatalf("checkout error expected %s, got %s", "firstSeqNum", val)
}
}
func TestScanShard_ShardIsClosed(t *testing.T) { func TestScanShard_ShardIsClosed(t *testing.T) {
var client = &kinesisClientMock{ var client = &kinesisClientMock{
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) { getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {