From 4bc414e216cd40978aac9083d8a74c4557a25d9e Mon Sep 17 00:00:00 2001 From: Harlow Ward Date: Thu, 3 Jan 2019 19:34:24 -0800 Subject: [PATCH] wip --- consumer.go | 4 +--- consumer_test.go | 40 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 3 deletions(-) diff --git a/consumer.go b/consumer.go index dd106f4..676073a 100644 --- a/consumer.go +++ b/consumer.go @@ -139,7 +139,6 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e c.logger.Log("scanning", shardID, lastSeqNum) - // loop until for { select { case <-ctx.Done(): @@ -159,14 +158,13 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e continue } - // callback func with each record + // loop over records, call callback func for _, r := range resp.Records { select { case <-ctx.Done(): return nil default: err := fn(r) - if err != nil && err != SkipCheckpoint { return err } diff --git a/consumer_test.go b/consumer_test.go index d53bca5..4712e3c 100644 --- a/consumer_test.go +++ b/consumer_test.go @@ -214,6 +214,46 @@ func TestScanShard_Cancellation(t *testing.T) { } } +func TestScanShard_SkipCheckpoint(t *testing.T) { + var client = &kinesisClientMock{ + getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) { + return &kinesis.GetShardIteratorOutput{ + ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"), + }, nil + }, + getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) { + return &kinesis.GetRecordsOutput{ + NextShardIterator: nil, + Records: records, + }, nil + }, + } + + var cp = &fakeCheckpoint{cache: map[string]string{}} + + c, err := New("myStreamName", WithClient(client), WithCheckpoint(cp)) + if err != nil { + t.Fatalf("new consumer error: %v", err) + } + + var fn = func(r *Record) error { + if aws.StringValue(r.SequenceNumber) == "lastSeqNum" { + return SkipCheckpoint + } + return nil + } + + err = c.ScanShard(context.Background(), "myShard", fn) + if err != nil { + t.Fatalf("scan shard error: %v", err) + } + + val, err := cp.Get("myStreamName", "myShard") + if err != nil && val != "firstSeqNum" { + t.Fatalf("checkout error expected %s, got %s", "firstSeqNum", val) + } +} + func TestScanShard_ShardIsClosed(t *testing.T) { var client = &kinesisClientMock{ getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {