Refactor the consumer tests

A previous PR from @vincent6767 had nicer mock Kinesis client that
simplified setting up data for the tests.

Mock client pulled from:
https://github.com/harlow/kinesis-consumer/pull/64
This commit is contained in:
Harlow Ward 2018-07-29 10:13:03 -07:00
parent 911282363e
commit d6ded158bf
4 changed files with 85 additions and 54 deletions

View file

@ -6,6 +6,7 @@ type Checkpoint interface {
Set(streamName, shardID, sequenceNumber string) error Set(streamName, shardID, sequenceNumber string) error
} }
// noopCheckpoint implements the checkpoint interface with discard
type noopCheckpoint struct{} type noopCheckpoint struct{}
func (n noopCheckpoint) Set(string, string, string) error { return nil } func (n noopCheckpoint) Set(string, string, string) error { return nil }

View file

@ -3,8 +3,6 @@ package consumer
import ( import (
"context" "context"
"fmt" "fmt"
"io/ioutil"
"log"
"sync" "sync"
"testing" "testing"
@ -21,87 +19,119 @@ func TestNew(t *testing.T) {
} }
func TestScanShard(t *testing.T) { func TestScanShard(t *testing.T) {
var ( var records = []*kinesis.Record{
resultData string &kinesis.Record{
ckp = &fakeCheckpoint{cache: map[string]string{}} Data: []byte("firstData"),
ctr = &fakeCounter{} SequenceNumber: aws.String("firstSeqNum"),
mockSvc = &mockKinesisClient{} },
logger = &noopLogger{ &kinesis.Record{
logger: log.New(ioutil.Discard, "", log.LstdFlags), Data: []byte("lastData"),
} SequenceNumber: aws.String("lastSeqNum"),
) },
c := &Consumer{
streamName: "myStreamName",
client: mockSvc,
checkpoint: ckp,
counter: ctr,
logger: logger,
} }
var recordNum = 0 var client = &kinesisClientMock{
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
return &kinesis.GetShardIteratorOutput{
ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
}, nil
},
getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
return &kinesis.GetRecordsOutput{
NextShardIterator: nil,
Records: records,
}, nil
},
}
// callback fn simply appends the record data to result string var (
cp = &fakeCheckpoint{cache: map[string]string{}}
ctr = &fakeCounter{}
)
c, err := New("myStreamName",
WithClient(client),
WithCounter(ctr),
WithCheckpoint(cp),
)
if err != nil {
t.Fatalf("new consumer error: %v", err)
}
var resultData string
// callback fn appends record data
var fn = func(r *Record) ScanStatus { var fn = func(r *Record) ScanStatus {
resultData += string(r.Data) resultData += string(r.Data)
recordNum++ return ScanStatus{}
stopScan := recordNum == 2
return ScanStatus{
StopScan: stopScan,
SkipCheckpoint: false,
}
} }
// scan shard // scan shard
err := c.ScanShard(context.Background(), "myShard", fn) if err := c.ScanShard(context.Background(), "myShard", fn); err != nil {
if err != nil {
t.Fatalf("scan shard error: %v", err) t.Fatalf("scan shard error: %v", err)
} }
// runs callback func
if resultData != "firstDatalastData" {
t.Fatalf("callback error expected %s, got %s", "firstDatalastData", resultData)
}
// increments counter // increments counter
if val := ctr.counter; val != 2 { if val := ctr.counter; val != 2 {
t.Fatalf("counter error expected %d, got %d", 2, val) t.Fatalf("counter error expected %d, got %d", 2, val)
} }
// sets checkpoint // sets checkpoint
val, err := ckp.Get("myStreamName", "myShard") val, err := cp.Get("myStreamName", "myShard")
if err != nil && val != "lastSeqNum" { if err != nil && val != "lastSeqNum" {
t.Fatalf("checkout error expected %s, got %s", "lastSeqNum", val) t.Fatalf("checkout error expected %s, got %s", "lastSeqNum", val)
} }
}
// calls callback func func TestScanShard_ShardIsClosed(t *testing.T) {
if resultData != "firstDatalastData" { var client = &kinesisClientMock{
t.Fatalf("callback error expected %s, got %s", "firstDatalastData", val) getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
return &kinesis.GetShardIteratorOutput{
ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
}, nil
},
getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
return &kinesis.GetRecordsOutput{
NextShardIterator: nil,
Records: make([]*Record, 0),
}, nil
},
}
c, err := New("myStreamName", WithClient(client))
if err != nil {
t.Fatalf("new consumer error: %v", err)
}
var fn = func(r *Record) ScanStatus {
return ScanStatus{}
}
if err := c.ScanShard(context.Background(), "myShard", fn); err != nil {
t.Fatalf("scan shard error: %v", err)
} }
} }
type mockKinesisClient struct { type kinesisClientMock struct {
kinesisiface.KinesisAPI kinesisiface.KinesisAPI
getShardIteratorMock func(*kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error)
getRecordsMock func(*kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error)
} }
func (m *mockKinesisClient) GetRecords(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) { func (c *kinesisClientMock) GetRecords(in *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
return c.getRecordsMock(in)
return &kinesis.GetRecordsOutput{
Records: []*kinesis.Record{
&kinesis.Record{
Data: []byte("firstData"),
SequenceNumber: aws.String("firstSeqNum"),
},
&kinesis.Record{
Data: []byte("lastData"),
SequenceNumber: aws.String("lastSeqNum"),
},
},
}, nil
} }
func (m *mockKinesisClient) GetShardIterator(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) { func (c *kinesisClientMock) GetShardIterator(in *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
return &kinesis.GetShardIteratorOutput{ return c.getShardIteratorMock(in)
ShardIterator: aws.String("myshard"),
}, nil
} }
// implementation of checkpoint
type fakeCheckpoint struct { type fakeCheckpoint struct {
cache map[string]string cache map[string]string
mu sync.Mutex mu sync.Mutex
@ -124,6 +154,7 @@ func (fc *fakeCheckpoint) Get(streamName, shardID string) (string, error) {
return fc.cache[key], nil return fc.cache[key], nil
} }
// implementation of counter
type fakeCounter struct { type fakeCounter struct {
counter int64 counter int64
} }

View file

@ -5,6 +5,7 @@ type Counter interface {
Add(string, int64) Add(string, int64)
} }
// noopCounter implements counter interface with discard
type noopCounter struct{} type noopCounter struct{}
func (n noopCounter) Add(string, int64) {} func (n noopCounter) Add(string, int64) {}

View file

@ -9,8 +9,6 @@ type Logger interface {
Log(...interface{}) Log(...interface{})
} }
type LoggerFunc func(...interface{})
// noopLogger implements logger interface with discard // noopLogger implements logger interface with discard
type noopLogger struct { type noopLogger struct {
logger *log.Logger logger *log.Logger