From d6ded158bf704339cd54c44fd96091222bf6a528 Mon Sep 17 00:00:00 2001 From: Harlow Ward Date: Sun, 29 Jul 2018 10:13:03 -0700 Subject: [PATCH] Refactor the consumer tests A previous PR from @vincent6767 had nicer mock Kinesis client that simplified setting up data for the tests. Mock client pulled from: https://github.com/harlow/kinesis-consumer/pull/64 --- checkpoint.go | 1 + consumer_test.go | 135 +++++++++++++++++++++++++++++------------------ counter.go | 1 + logger.go | 2 - 4 files changed, 85 insertions(+), 54 deletions(-) diff --git a/checkpoint.go b/checkpoint.go index 5a22a90..383d4c1 100644 --- a/checkpoint.go +++ b/checkpoint.go @@ -6,6 +6,7 @@ type Checkpoint interface { Set(streamName, shardID, sequenceNumber string) error } +// noopCheckpoint implements the checkpoint interface with discard type noopCheckpoint struct{} func (n noopCheckpoint) Set(string, string, string) error { return nil } diff --git a/consumer_test.go b/consumer_test.go index 07be3e8..8e02caa 100644 --- a/consumer_test.go +++ b/consumer_test.go @@ -3,8 +3,6 @@ package consumer import ( "context" "fmt" - "io/ioutil" - "log" "sync" "testing" @@ -21,87 +19,119 @@ func TestNew(t *testing.T) { } func TestScanShard(t *testing.T) { - var ( - resultData string - ckp = &fakeCheckpoint{cache: map[string]string{}} - ctr = &fakeCounter{} - mockSvc = &mockKinesisClient{} - logger = &noopLogger{ - logger: log.New(ioutil.Discard, "", log.LstdFlags), - } - ) - - c := &Consumer{ - streamName: "myStreamName", - client: mockSvc, - checkpoint: ckp, - counter: ctr, - logger: logger, + var records = []*kinesis.Record{ + &kinesis.Record{ + Data: []byte("firstData"), + SequenceNumber: aws.String("firstSeqNum"), + }, + &kinesis.Record{ + Data: []byte("lastData"), + SequenceNumber: aws.String("lastSeqNum"), + }, } - var recordNum = 0 + var client = &kinesisClientMock{ + getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) { + return &kinesis.GetShardIteratorOutput{ + ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"), + }, nil + }, + getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) { + return &kinesis.GetRecordsOutput{ + NextShardIterator: nil, + Records: records, + }, nil + }, + } - // callback fn simply appends the record data to result string + var ( + cp = &fakeCheckpoint{cache: map[string]string{}} + ctr = &fakeCounter{} + ) + + c, err := New("myStreamName", + WithClient(client), + WithCounter(ctr), + WithCheckpoint(cp), + ) + if err != nil { + t.Fatalf("new consumer error: %v", err) + } + + var resultData string + + // callback fn appends record data var fn = func(r *Record) ScanStatus { resultData += string(r.Data) - recordNum++ - stopScan := recordNum == 2 - - return ScanStatus{ - StopScan: stopScan, - SkipCheckpoint: false, - } + return ScanStatus{} } // scan shard - err := c.ScanShard(context.Background(), "myShard", fn) - if err != nil { + if err := c.ScanShard(context.Background(), "myShard", fn); err != nil { t.Fatalf("scan shard error: %v", err) } + // runs callback func + if resultData != "firstDatalastData" { + t.Fatalf("callback error expected %s, got %s", "firstDatalastData", resultData) + } + // increments counter if val := ctr.counter; val != 2 { t.Fatalf("counter error expected %d, got %d", 2, val) } // sets checkpoint - val, err := ckp.Get("myStreamName", "myShard") + val, err := cp.Get("myStreamName", "myShard") if err != nil && val != "lastSeqNum" { t.Fatalf("checkout error expected %s, got %s", "lastSeqNum", val) } +} - // calls callback func - if resultData != "firstDatalastData" { - t.Fatalf("callback error expected %s, got %s", "firstDatalastData", val) +func TestScanShard_ShardIsClosed(t *testing.T) { + var client = &kinesisClientMock{ + getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) { + return &kinesis.GetShardIteratorOutput{ + ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"), + }, nil + }, + getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) { + return &kinesis.GetRecordsOutput{ + NextShardIterator: nil, + Records: make([]*Record, 0), + }, nil + }, + } + + c, err := New("myStreamName", WithClient(client)) + if err != nil { + t.Fatalf("new consumer error: %v", err) + } + + var fn = func(r *Record) ScanStatus { + return ScanStatus{} + } + + if err := c.ScanShard(context.Background(), "myShard", fn); err != nil { + t.Fatalf("scan shard error: %v", err) } } -type mockKinesisClient struct { +type kinesisClientMock struct { kinesisiface.KinesisAPI + getShardIteratorMock func(*kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) + getRecordsMock func(*kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) } -func (m *mockKinesisClient) GetRecords(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) { - - return &kinesis.GetRecordsOutput{ - Records: []*kinesis.Record{ - &kinesis.Record{ - Data: []byte("firstData"), - SequenceNumber: aws.String("firstSeqNum"), - }, - &kinesis.Record{ - Data: []byte("lastData"), - SequenceNumber: aws.String("lastSeqNum"), - }, - }, - }, nil +func (c *kinesisClientMock) GetRecords(in *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) { + return c.getRecordsMock(in) } -func (m *mockKinesisClient) GetShardIterator(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) { - return &kinesis.GetShardIteratorOutput{ - ShardIterator: aws.String("myshard"), - }, nil +func (c *kinesisClientMock) GetShardIterator(in *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) { + return c.getShardIteratorMock(in) } +// implementation of checkpoint type fakeCheckpoint struct { cache map[string]string mu sync.Mutex @@ -124,6 +154,7 @@ func (fc *fakeCheckpoint) Get(streamName, shardID string) (string, error) { return fc.cache[key], nil } +// implementation of counter type fakeCounter struct { counter int64 } diff --git a/counter.go b/counter.go index 82a15c1..f33a8e5 100644 --- a/counter.go +++ b/counter.go @@ -5,6 +5,7 @@ type Counter interface { Add(string, int64) } +// noopCounter implements counter interface with discard type noopCounter struct{} func (n noopCounter) Add(string, int64) {} diff --git a/logger.go b/logger.go index bd28361..ab90d2a 100644 --- a/logger.go +++ b/logger.go @@ -9,8 +9,6 @@ type Logger interface { Log(...interface{}) } -type LoggerFunc func(...interface{}) - // noopLogger implements logger interface with discard type noopLogger struct { logger *log.Logger