kinesis-consumer/consumer_test.go

340 lines
8.4 KiB
Go
Raw Normal View History

package consumer
import (
"context"
"fmt"
"sync"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/kinesis"
"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
)
var records = []*kinesis.Record{
{
Data: []byte("firstData"),
SequenceNumber: aws.String("firstSeqNum"),
},
{
Data: []byte("lastData"),
SequenceNumber: aws.String("lastSeqNum"),
},
}
func TestNew(t *testing.T) {
if _, err := New("myStreamName"); err != nil {
t.Fatalf("new consumer error: %v", err)
}
}
func TestScan(t *testing.T) {
client := &kinesisClientMock{
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
return &kinesis.GetShardIteratorOutput{
ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
}, nil
},
getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
return &kinesis.GetRecordsOutput{
NextShardIterator: nil,
Records: records,
}, nil
},
listShardsMock: func(input *kinesis.ListShardsInput) (*kinesis.ListShardsOutput, error) {
return &kinesis.ListShardsOutput{
Shards: []*kinesis.Shard{
{ShardId: aws.String("myShard")},
},
}, nil
},
}
var (
cp = &fakeCheckpoint{cache: map[string]string{}}
ctr = &fakeCounter{}
)
c, err := New("myStreamName",
WithClient(client),
WithCounter(ctr),
WithStorage(cp),
)
if err != nil {
t.Fatalf("new consumer error: %v", err)
}
var (
ctx, cancel = context.WithCancel(context.Background())
res string
)
var fn = func(r *Record) error {
res += string(r.Data)
if string(r.Data) == "lastData" {
cancel()
}
return nil
}
if err := c.Scan(ctx, fn); err != nil {
t.Errorf("scan returned unexpected error %v", err)
}
if res != "firstDatalastData" {
t.Errorf("callback error expected %s, got %s", "firstDatalastData", res)
}
if val := ctr.Get(); val != 2 {
t.Errorf("counter error expected %d, got %d", 2, val)
}
val, err := cp.GetCheckpoint("myStreamName", "myShard")
if err != nil && val != "lastSeqNum" {
t.Errorf("checkout error expected %s, got %s", "lastSeqNum", val)
}
}
func TestScanShard(t *testing.T) {
var client = &kinesisClientMock{
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
return &kinesis.GetShardIteratorOutput{
ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
}, nil
},
getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
return &kinesis.GetRecordsOutput{
NextShardIterator: nil,
Records: records,
}, nil
},
}
var (
cp = &fakeCheckpoint{cache: map[string]string{}}
ctr = &fakeCounter{}
)
c, err := New("myStreamName",
WithClient(client),
WithCounter(ctr),
WithStorage(cp),
)
if err != nil {
t.Fatalf("new consumer error: %v", err)
}
// callback fn appends record data
var (
ctx, cancel = context.WithCancel(context.Background())
res string
)
var fn = func(r *Record) error {
res += string(r.Data)
if string(r.Data) == "lastData" {
cancel()
}
return nil
}
if err := c.ScanShard(ctx, "myShard", fn); err != nil {
t.Errorf("scan returned unexpected error %v", err)
}
// runs callback func
if res != "firstDatalastData" {
t.Fatalf("callback error expected %s, got %s", "firstDatalastData", res)
}
// increments counter
if val := ctr.Get(); val != 2 {
t.Fatalf("counter error expected %d, got %d", 2, val)
}
// sets checkpoint
val, err := cp.GetCheckpoint("myStreamName", "myShard")
if err != nil && val != "lastSeqNum" {
t.Fatalf("checkout error expected %s, got %s", "lastSeqNum", val)
}
}
func TestScanShard_Cancellation(t *testing.T) {
var client = &kinesisClientMock{
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
return &kinesis.GetShardIteratorOutput{
ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
}, nil
2018-07-29 17:27:01 +00:00
},
getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
return &kinesis.GetRecordsOutput{
NextShardIterator: nil,
Records: records,
}, nil
2018-07-29 17:27:01 +00:00
},
}
// use cancel func to signal shutdown
ctx, cancel := context.WithCancel(context.Background())
var res string
var fn = func(r *Record) error {
res += string(r.Data)
cancel() // simulate cancellation while processing first record
return nil
}
c, err := New("myStreamName", WithClient(client))
if err != nil {
t.Fatalf("new consumer error: %v", err)
}
err = c.ScanShard(ctx, "myShard", fn)
if err != nil {
t.Fatalf("scan shard error: %v", err)
}
if res != "firstData" {
t.Fatalf("callback error expected %s, got %s", "firstData", res)
}
}
func TestScanShard_SkipCheckpoint(t *testing.T) {
2018-07-29 17:27:01 +00:00
var client = &kinesisClientMock{
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
return &kinesis.GetShardIteratorOutput{
ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
}, nil
},
getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
return &kinesis.GetRecordsOutput{
NextShardIterator: nil,
Records: records,
}, nil
},
}
var cp = &fakeCheckpoint{cache: map[string]string{}}
c, err := New("myStreamName", WithClient(client), WithStorage(cp))
2018-07-29 17:27:01 +00:00
if err != nil {
t.Fatalf("new consumer error: %v", err)
}
var ctx, cancel = context.WithCancel(context.Background())
var fn = func(r *Record) error {
if aws.StringValue(r.SequenceNumber) == "lastSeqNum" {
cancel()
return ErrSkipCheckpoint
}
return nil
2018-07-29 17:27:01 +00:00
}
err = c.ScanShard(ctx, "myShard", fn)
if err != nil {
2018-07-29 17:27:01 +00:00
t.Fatalf("scan shard error: %v", err)
}
val, err := cp.GetCheckpoint("myStreamName", "myShard")
if err != nil && val != "firstSeqNum" {
t.Fatalf("checkout error expected %s, got %s", "firstSeqNum", val)
2018-07-29 17:27:01 +00:00
}
}
func TestScanShard_ShardIsClosed(t *testing.T) {
var client = &kinesisClientMock{
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
return &kinesis.GetShardIteratorOutput{
ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
}, nil
},
getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
return &kinesis.GetRecordsOutput{
NextShardIterator: nil,
Records: make([]*Record, 0),
}, nil
},
}
c, err := New("myStreamName", WithClient(client))
if err != nil {
t.Fatalf("new consumer error: %v", err)
}
var fn = func(r *Record) error {
return nil
}
err = c.ScanShard(context.Background(), "myShard", fn)
if err != nil {
t.Fatalf("scan shard error: %v", err)
}
}
type kinesisClientMock struct {
kinesisiface.KinesisAPI
getShardIteratorMock func(*kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error)
getRecordsMock func(*kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error)
listShardsMock func(*kinesis.ListShardsInput) (*kinesis.ListShardsOutput, error)
}
func (c *kinesisClientMock) ListShards(in *kinesis.ListShardsInput) (*kinesis.ListShardsOutput, error) {
return c.listShardsMock(in)
}
func (c *kinesisClientMock) GetRecords(in *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
return c.getRecordsMock(in)
}
func (c *kinesisClientMock) GetShardIterator(in *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
return c.getShardIteratorMock(in)
}
// implementation of checkpoint
type fakeCheckpoint struct {
cache map[string]string
mu sync.Mutex
}
func (fc *fakeCheckpoint) SetCheckpoint(streamName, shardID, sequenceNumber string) error {
fc.mu.Lock()
defer fc.mu.Unlock()
key := fmt.Sprintf("%s-%s", streamName, shardID)
fc.cache[key] = sequenceNumber
return nil
}
func (fc *fakeCheckpoint) GetCheckpoint(streamName, shardID string) (string, error) {
fc.mu.Lock()
defer fc.mu.Unlock()
key := fmt.Sprintf("%s-%s", streamName, shardID)
return fc.cache[key], nil
}
// implementation of counter
type fakeCounter struct {
counter int64
mu sync.Mutex
}
func (fc *fakeCounter) Get() int64 {
fc.mu.Lock()
defer fc.mu.Unlock()
return fc.counter
}
func (fc *fakeCounter) Add(streamName string, count int64) {
fc.mu.Lock()
defer fc.mu.Unlock()
fc.counter += count
}