kinesis-consumer/consumer_test.go

134 lines
2.9 KiB
Go
Raw Normal View History

package consumer
import (
"context"
"fmt"
"io/ioutil"
"log"
"sync"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/kinesis"
"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
)
// TestNew checks that constructing a consumer for a stream name with no
// further arguments does not return an error.
func TestNew(t *testing.T) {
	if _, newErr := New("myStreamName"); newErr != nil {
		t.Fatalf("new consumer error: %v", newErr)
	}
}
// TestScanShard runs ScanShard against the mock Kinesis client and the
// in-memory fakes, then asserts three observable effects: the counter was
// incremented once per record, the checkpoint holds the last sequence number,
// and the callback received the records' data in order.
func TestScanShard(t *testing.T) {
	var (
		resultData string
		recordNum  int
		ckp        = &fakeCheckpoint{cache: map[string]string{}}
		ctr        = &fakeCounter{}
		mockSvc    = &mockKinesisClient{}
		logger     = &noopLogger{
			logger: log.New(ioutil.Discard, "", log.LstdFlags),
		}
	)

	c := &Consumer{
		streamName: "myStreamName",
		client:     mockSvc,
		checkpoint: ckp,
		counter:    ctr,
		logger:     logger,
	}

	// callback fn appends the record data to the result string and requests a
	// stop once the second record has been seen, so the scan terminates.
	fn := func(r *Record) ScanStatus {
		resultData += string(r.Data)
		recordNum++
		return ScanStatus{
			StopScan:       recordNum == 2,
			SkipCheckpoint: false,
		}
	}

	// scan shard
	if err := c.ScanShard(context.Background(), "myShard", fn); err != nil {
		t.Fatalf("scan shard error: %v", err)
	}

	// increments counter
	if val := ctr.counter; val != 2 {
		t.Fatalf("counter error expected %d, got %d", 2, val)
	}

	// sets checkpoint; the lookup error and the stored value are checked
	// separately so a wrong value is reported even when Get succeeds.
	val, err := ckp.Get("myStreamName", "myShard")
	if err != nil {
		t.Fatalf("checkpoint get error: %v", err)
	}
	if val != "lastSeqNum" {
		t.Fatalf("checkpoint error expected %s, got %s", "lastSeqNum", val)
	}

	// calls callback func
	if resultData != "firstDatalastData" {
		t.Fatalf("callback error expected %s, got %s", "firstDatalastData", resultData)
	}
}
// mockKinesisClient is a canned Kinesis client for tests. It embeds
// kinesisiface.KinesisAPI to satisfy that interface and overrides only
// GetRecords and GetShardIterator.
type mockKinesisClient struct {
	kinesisiface.KinesisAPI
}

// GetRecords ignores its input and always returns the same two records,
// "firstData"/"firstSeqNum" followed by "lastData"/"lastSeqNum", with a nil
// error.
func (m *mockKinesisClient) GetRecords(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
	records := []*kinesis.Record{
		{
			Data:           []byte("firstData"),
			SequenceNumber: aws.String("firstSeqNum"),
		},
		{
			Data:           []byte("lastData"),
			SequenceNumber: aws.String("lastSeqNum"),
		},
	}
	return &kinesis.GetRecordsOutput{Records: records}, nil
}

// GetShardIterator ignores its input and always returns the fixed iterator
// value "myshard" with a nil error.
func (m *mockKinesisClient) GetShardIterator(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
	out := &kinesis.GetShardIteratorOutput{
		ShardIterator: aws.String("myshard"),
	}
	return out, nil
}
// fakeCheckpoint is an in-memory checkpoint store for tests. Sequence numbers
// are kept in cache under a "<streamName>-<shardID>" key, and mu guards every
// access to the map. The zero value is usable: Set allocates the map on first
// write.
type fakeCheckpoint struct {
	cache map[string]string
	mu    sync.Mutex
}

// Set records sequenceNumber as the checkpoint for the given stream and shard,
// replacing any previous value. It always returns nil.
func (fc *fakeCheckpoint) Set(streamName, shardID, sequenceNumber string) error {
	fc.mu.Lock()
	defer fc.mu.Unlock()

	// Writing to a nil map panics, so allocate lazily for a fakeCheckpoint
	// that was created without an explicit cache.
	if fc.cache == nil {
		fc.cache = make(map[string]string)
	}

	key := fmt.Sprintf("%s-%s", streamName, shardID)
	fc.cache[key] = sequenceNumber
	return nil
}

// Get returns the sequence number stored for the given stream and shard, or
// the empty string when nothing has been stored. It always returns a nil
// error.
func (fc *fakeCheckpoint) Get(streamName, shardID string) (string, error) {
	fc.mu.Lock()
	defer fc.mu.Unlock()

	key := fmt.Sprintf("%s-%s", streamName, shardID)
	return fc.cache[key], nil
}
// fakeCounter is a test double that keeps a running total of every count
// passed to Add.
type fakeCounter struct {
	counter int64
}

// Add adds count to the running total. The streamName argument is not used.
func (c *fakeCounter) Add(streamName string, count int64) {
	c.counter = c.counter + count
}