Refactor the consumer tests
A previous PR from @vincent6767 had nicer mock Kinesis client that simplified setting up data for the tests. Mock client pulled from: https://github.com/harlow/kinesis-consumer/pull/64
This commit is contained in:
parent
911282363e
commit
d6ded158bf
4 changed files with 85 additions and 54 deletions
|
|
@ -6,6 +6,7 @@ type Checkpoint interface {
|
||||||
Set(streamName, shardID, sequenceNumber string) error
|
Set(streamName, shardID, sequenceNumber string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// noopCheckpoint implements the checkpoint interface with discard
|
||||||
type noopCheckpoint struct{}
|
type noopCheckpoint struct{}
|
||||||
|
|
||||||
func (n noopCheckpoint) Set(string, string, string) error { return nil }
|
func (n noopCheckpoint) Set(string, string, string) error { return nil }
|
||||||
|
|
|
||||||
135
consumer_test.go
135
consumer_test.go
|
|
@ -3,8 +3,6 @@ package consumer
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
|
||||||
"log"
|
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
|
@ -21,87 +19,119 @@ func TestNew(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestScanShard(t *testing.T) {
|
func TestScanShard(t *testing.T) {
|
||||||
var (
|
var records = []*kinesis.Record{
|
||||||
resultData string
|
&kinesis.Record{
|
||||||
ckp = &fakeCheckpoint{cache: map[string]string{}}
|
Data: []byte("firstData"),
|
||||||
ctr = &fakeCounter{}
|
SequenceNumber: aws.String("firstSeqNum"),
|
||||||
mockSvc = &mockKinesisClient{}
|
},
|
||||||
logger = &noopLogger{
|
&kinesis.Record{
|
||||||
logger: log.New(ioutil.Discard, "", log.LstdFlags),
|
Data: []byte("lastData"),
|
||||||
}
|
SequenceNumber: aws.String("lastSeqNum"),
|
||||||
)
|
},
|
||||||
|
|
||||||
c := &Consumer{
|
|
||||||
streamName: "myStreamName",
|
|
||||||
client: mockSvc,
|
|
||||||
checkpoint: ckp,
|
|
||||||
counter: ctr,
|
|
||||||
logger: logger,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var recordNum = 0
|
var client = &kinesisClientMock{
|
||||||
|
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
|
||||||
|
return &kinesis.GetShardIteratorOutput{
|
||||||
|
ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
|
||||||
|
return &kinesis.GetRecordsOutput{
|
||||||
|
NextShardIterator: nil,
|
||||||
|
Records: records,
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
// callback fn simply appends the record data to result string
|
var (
|
||||||
|
cp = &fakeCheckpoint{cache: map[string]string{}}
|
||||||
|
ctr = &fakeCounter{}
|
||||||
|
)
|
||||||
|
|
||||||
|
c, err := New("myStreamName",
|
||||||
|
WithClient(client),
|
||||||
|
WithCounter(ctr),
|
||||||
|
WithCheckpoint(cp),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("new consumer error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resultData string
|
||||||
|
|
||||||
|
// callback fn appends record data
|
||||||
var fn = func(r *Record) ScanStatus {
|
var fn = func(r *Record) ScanStatus {
|
||||||
resultData += string(r.Data)
|
resultData += string(r.Data)
|
||||||
recordNum++
|
return ScanStatus{}
|
||||||
stopScan := recordNum == 2
|
|
||||||
|
|
||||||
return ScanStatus{
|
|
||||||
StopScan: stopScan,
|
|
||||||
SkipCheckpoint: false,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// scan shard
|
// scan shard
|
||||||
err := c.ScanShard(context.Background(), "myShard", fn)
|
if err := c.ScanShard(context.Background(), "myShard", fn); err != nil {
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("scan shard error: %v", err)
|
t.Fatalf("scan shard error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// runs callback func
|
||||||
|
if resultData != "firstDatalastData" {
|
||||||
|
t.Fatalf("callback error expected %s, got %s", "firstDatalastData", resultData)
|
||||||
|
}
|
||||||
|
|
||||||
// increments counter
|
// increments counter
|
||||||
if val := ctr.counter; val != 2 {
|
if val := ctr.counter; val != 2 {
|
||||||
t.Fatalf("counter error expected %d, got %d", 2, val)
|
t.Fatalf("counter error expected %d, got %d", 2, val)
|
||||||
}
|
}
|
||||||
|
|
||||||
// sets checkpoint
|
// sets checkpoint
|
||||||
val, err := ckp.Get("myStreamName", "myShard")
|
val, err := cp.Get("myStreamName", "myShard")
|
||||||
if err != nil && val != "lastSeqNum" {
|
if err != nil && val != "lastSeqNum" {
|
||||||
t.Fatalf("checkout error expected %s, got %s", "lastSeqNum", val)
|
t.Fatalf("checkout error expected %s, got %s", "lastSeqNum", val)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// calls callback func
|
func TestScanShard_ShardIsClosed(t *testing.T) {
|
||||||
if resultData != "firstDatalastData" {
|
var client = &kinesisClientMock{
|
||||||
t.Fatalf("callback error expected %s, got %s", "firstDatalastData", val)
|
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
|
||||||
|
return &kinesis.GetShardIteratorOutput{
|
||||||
|
ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
|
||||||
|
return &kinesis.GetRecordsOutput{
|
||||||
|
NextShardIterator: nil,
|
||||||
|
Records: make([]*Record, 0),
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
c, err := New("myStreamName", WithClient(client))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("new consumer error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var fn = func(r *Record) ScanStatus {
|
||||||
|
return ScanStatus{}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.ScanShard(context.Background(), "myShard", fn); err != nil {
|
||||||
|
t.Fatalf("scan shard error: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type mockKinesisClient struct {
|
type kinesisClientMock struct {
|
||||||
kinesisiface.KinesisAPI
|
kinesisiface.KinesisAPI
|
||||||
|
getShardIteratorMock func(*kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error)
|
||||||
|
getRecordsMock func(*kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockKinesisClient) GetRecords(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
|
func (c *kinesisClientMock) GetRecords(in *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
|
||||||
|
return c.getRecordsMock(in)
|
||||||
return &kinesis.GetRecordsOutput{
|
|
||||||
Records: []*kinesis.Record{
|
|
||||||
&kinesis.Record{
|
|
||||||
Data: []byte("firstData"),
|
|
||||||
SequenceNumber: aws.String("firstSeqNum"),
|
|
||||||
},
|
|
||||||
&kinesis.Record{
|
|
||||||
Data: []byte("lastData"),
|
|
||||||
SequenceNumber: aws.String("lastSeqNum"),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockKinesisClient) GetShardIterator(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
|
func (c *kinesisClientMock) GetShardIterator(in *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
|
||||||
return &kinesis.GetShardIteratorOutput{
|
return c.getShardIteratorMock(in)
|
||||||
ShardIterator: aws.String("myshard"),
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// implementation of checkpoint
|
||||||
type fakeCheckpoint struct {
|
type fakeCheckpoint struct {
|
||||||
cache map[string]string
|
cache map[string]string
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
|
|
@ -124,6 +154,7 @@ func (fc *fakeCheckpoint) Get(streamName, shardID string) (string, error) {
|
||||||
return fc.cache[key], nil
|
return fc.cache[key], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// implementation of counter
|
||||||
type fakeCounter struct {
|
type fakeCounter struct {
|
||||||
counter int64
|
counter int64
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ type Counter interface {
|
||||||
Add(string, int64)
|
Add(string, int64)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// noopCounter implements counter interface with discard
|
||||||
type noopCounter struct{}
|
type noopCounter struct{}
|
||||||
|
|
||||||
func (n noopCounter) Add(string, int64) {}
|
func (n noopCounter) Add(string, int64) {}
|
||||||
|
|
|
||||||
|
|
@ -9,8 +9,6 @@ type Logger interface {
|
||||||
Log(...interface{})
|
Log(...interface{})
|
||||||
}
|
}
|
||||||
|
|
||||||
type LoggerFunc func(...interface{})
|
|
||||||
|
|
||||||
// noopLogger implements logger interface with discard
|
// noopLogger implements logger interface with discard
|
||||||
type noopLogger struct {
|
type noopLogger struct {
|
||||||
logger *log.Logger
|
logger *log.Logger
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue