2017-11-27 00:00:11 +00:00
|
|
|
package consumer
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
"context"
|
|
|
|
|
"fmt"
|
|
|
|
|
"sync"
|
|
|
|
|
"testing"
|
|
|
|
|
|
|
|
|
|
"github.com/aws/aws-sdk-go/aws"
|
2018-11-07 23:45:13 +00:00
|
|
|
"github.com/aws/aws-sdk-go/aws/request"
|
2018-07-29 05:53:33 +00:00
|
|
|
"github.com/aws/aws-sdk-go/service/kinesis"
|
|
|
|
|
"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
|
2017-11-27 00:00:11 +00:00
|
|
|
)
|
|
|
|
|
|
|
|
|
|
// TestNew checks that a consumer can be constructed from a stream name alone,
// with no functional options supplied.
func TestNew(t *testing.T) {
	if _, err := New("myStreamName"); err != nil {
		t.Fatalf("new consumer error: %v", err)
	}
}
|
|
|
|
|
|
2018-09-03 16:59:39 +00:00
|
|
|
// TestConsumer_Scan runs Scan against a mocked stream that DescribeStream
// reports as having a single shard ("myShard"), whose GetRecords page holds two
// records and no NextShardIterator. It verifies that:
//   - the callback received both records, in order;
//   - the callback ran exactly twice;
//   - the counter was incremented to 2;
//   - the checkpoint for the shard ended at the last record's sequence number.
func TestConsumer_Scan(t *testing.T) {
	ctx := context.TODO()
	records := []*kinesis.Record{
		{
			Data:           []byte("firstData"),
			SequenceNumber: aws.String("firstSeqNum"),
		},
		{
			Data:           []byte("lastData"),
			SequenceNumber: aws.String("lastSeqNum"),
		},
	}
	client := &kinesisClientMock{
		getShardIteratorMock: func(a aws.Context, input *kinesis.GetShardIteratorInput, o ...request.Option) (*kinesis.GetShardIteratorOutput, error) {
			return &kinesis.GetShardIteratorOutput{
				ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
			}, nil
		},
		getRecordsMock: func(a aws.Context, input *kinesis.GetRecordsInput, o ...request.Option) (*kinesis.GetRecordsOutput, error) {
			// NOTE(review): a nil NextShardIterator is presumably treated by the
			// consumer as the end of the shard, so only this one page is read.
			return &kinesis.GetRecordsOutput{
				NextShardIterator: nil,
				Records:           records,
			}, nil
		},
		describeStreamMock: func(a aws.Context, input *kinesis.DescribeStreamInput, o ...request.Option) (*kinesis.DescribeStreamOutput, error) {
			return &kinesis.DescribeStreamOutput{
				StreamDescription: &kinesis.StreamDescription{
					Shards: []*kinesis.Shard{
						{ShardId: aws.String("myShard")},
					},
				},
			}, nil
		},
	}
	var (
		cp  = &fakeCheckpoint{cache: map[string]string{}}
		ctr = &fakeCounter{}
	)

	c, err := New("myStreamName",
		WithClient(client),
		WithCounter(ctr),
		WithCheckpoint(cp),
	)
	if err != nil {
		t.Fatalf("new consumer error: %v", err)
	}

	var resultData string
	var fnCallCounter int
	var fn = func(r *Record) ScanStatus {
		fnCallCounter++
		resultData += string(r.Data)
		return ScanStatus{}
	}

	if err := c.Scan(ctx, fn); err != nil {
		t.Errorf("scan shard error expected nil. got %v", err)
	}

	if resultData != "firstDatalastData" {
		t.Errorf("callback error expected %s, got %s", "firstDatalastData", resultData)
	}
	if fnCallCounter != 2 {
		t.Errorf("the callback function expects %v, got %v", 2, fnCallCounter)
	}
	if val := ctr.counter; val != 2 {
		t.Errorf("counter error expected %d, got %d", 2, val)
	}

	// Check the error and the stored value separately. Joining them with && would
	// make the assertion unreachable whenever Get succeeds.
	val, err := cp.Get(ctx, "myStreamName", "myShard")
	if err != nil {
		t.Errorf("checkpoint get error: %v", err)
	}
	if val != "lastSeqNum" {
		t.Errorf("checkpoint error expected %s, got %s", "lastSeqNum", val)
	}
}
|
|
|
|
|
|
|
|
|
|
// TestConsumer_Scan_NoShardsAvailable checks how Scan handles a stream that
// DescribeStream reports with an empty shard list. It verifies that:
//   - Scan returns a non-nil error;
//   - the callback is never invoked;
//   - the counter stays at zero;
//   - the checkpoint for "myShard" is still empty.
func TestConsumer_Scan_NoShardsAvailable(t *testing.T) {
	ctx := context.TODO()
	client := &kinesisClientMock{
		describeStreamMock: func(a aws.Context, input *kinesis.DescribeStreamInput, o ...request.Option) (*kinesis.DescribeStreamOutput, error) {
			return &kinesis.DescribeStreamOutput{
				StreamDescription: &kinesis.StreamDescription{
					Shards: make([]*kinesis.Shard, 0),
				},
			}, nil
		},
	}
	var (
		cp  = &fakeCheckpoint{cache: map[string]string{}}
		ctr = &fakeCounter{}
	)

	c, err := New("myStreamName",
		WithClient(client),
		WithCounter(ctr),
		WithCheckpoint(cp),
	)
	if err != nil {
		t.Fatalf("new consumer error: %v", err)
	}

	var fnCallCounter int
	var fn = func(r *Record) ScanStatus {
		fnCallCounter++
		return ScanStatus{}
	}

	if err := c.Scan(ctx, fn); err == nil {
		t.Errorf("scan shard error expected not nil. got %v", err)
	}

	if fnCallCounter != 0 {
		t.Errorf("the callback function expects %v, got %v", 0, fnCallCounter)
	}
	if val := ctr.counter; val != 0 {
		t.Errorf("counter error expected %d, got %d", 0, val)
	}

	// Check the error and the stored value separately. Joining them with && would
	// make the assertion unreachable whenever Get succeeds.
	val, err := cp.Get(ctx, "myStreamName", "myShard")
	if err != nil {
		t.Errorf("checkpoint get error: %v", err)
	}
	if val != "" {
		t.Errorf("checkpoint error expected %q, got %q", "", val)
	}
}
|
|
|
|
|
|
2017-11-27 00:00:11 +00:00
|
|
|
// TestScanShard scans a single shard whose mocked GetRecords page holds two
// records and no NextShardIterator. It verifies that:
//   - the callback saw both records, in order;
//   - the counter reached 2;
//   - the checkpoint ended at the last record's sequence number.
func TestScanShard(t *testing.T) {
	ctx := context.TODO()
	var records = []*kinesis.Record{
		{
			Data:           []byte("firstData"),
			SequenceNumber: aws.String("firstSeqNum"),
		},
		{
			Data:           []byte("lastData"),
			SequenceNumber: aws.String("lastSeqNum"),
		},
	}

	var client = &kinesisClientMock{
		getShardIteratorMock: func(a aws.Context, input *kinesis.GetShardIteratorInput, o ...request.Option) (*kinesis.GetShardIteratorOutput, error) {
			return &kinesis.GetShardIteratorOutput{
				ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
			}, nil
		},
		getRecordsMock: func(a aws.Context, input *kinesis.GetRecordsInput, o ...request.Option) (*kinesis.GetRecordsOutput, error) {
			return &kinesis.GetRecordsOutput{
				NextShardIterator: nil,
				Records:           records,
			}, nil
		},
	}

	var (
		cp  = &fakeCheckpoint{cache: map[string]string{}}
		ctr = &fakeCounter{}
	)

	c, err := New("myStreamName",
		WithClient(client),
		WithCounter(ctr),
		WithCheckpoint(cp),
	)
	if err != nil {
		t.Fatalf("new consumer error: %v", err)
	}

	// callback fn appends record data
	var resultData string
	var fn = func(r *Record) ScanStatus {
		resultData += string(r.Data)
		return ScanStatus{}
	}

	// scan shard
	if err := c.ScanShard(ctx, "myShard", fn); err != nil {
		t.Fatalf("scan shard error: %v", err)
	}

	// runs callback func
	if resultData != "firstDatalastData" {
		t.Fatalf("callback error expected %s, got %s", "firstDatalastData", resultData)
	}

	// increments counter
	if val := ctr.counter; val != 2 {
		t.Fatalf("counter error expected %d, got %d", 2, val)
	}

	// sets checkpoint. Check the error and the value separately: joining them
	// with && would make the assertion unreachable whenever Get succeeds.
	val, err := cp.Get(ctx, "myStreamName", "myShard")
	if err != nil {
		t.Fatalf("checkpoint get error: %v", err)
	}
	if val != "lastSeqNum" {
		t.Fatalf("checkpoint error expected %s, got %s", "lastSeqNum", val)
	}
}
|
2017-11-27 00:00:11 +00:00
|
|
|
|
2018-07-29 17:27:01 +00:00
|
|
|
// TestScanShard_StopScan checks that ScanShard stops after the first record
// when the callback returns ScanStatus{StopScan: true}. The mocked page holds
// two records, so only the first one's data should have been collected.
func TestScanShard_StopScan(t *testing.T) {
	ctx := context.TODO()

	first := &kinesis.Record{
		Data:           []byte("firstData"),
		SequenceNumber: aws.String("firstSeqNum"),
	}
	last := &kinesis.Record{
		Data:           []byte("lastData"),
		SequenceNumber: aws.String("lastSeqNum"),
	}
	page := []*kinesis.Record{first, last}

	client := &kinesisClientMock{
		getShardIteratorMock: func(_ aws.Context, _ *kinesis.GetShardIteratorInput, _ ...request.Option) (*kinesis.GetShardIteratorOutput, error) {
			out := &kinesis.GetShardIteratorOutput{
				ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
			}
			return out, nil
		},
		getRecordsMock: func(_ aws.Context, _ *kinesis.GetRecordsInput, _ ...request.Option) (*kinesis.GetRecordsOutput, error) {
			out := &kinesis.GetRecordsOutput{
				NextShardIterator: nil,
				Records:           page,
			}
			return out, nil
		},
	}

	consumer, err := New("myStreamName", WithClient(client))
	if err != nil {
		t.Fatalf("new consumer error: %v", err)
	}

	// Collect each record's payload, and ask to stop after every record.
	var collected string
	stopAfterEach := func(rec *Record) ScanStatus {
		collected += string(rec.Data)
		return ScanStatus{StopScan: true}
	}

	if err := consumer.ScanShard(ctx, "myShard", stopAfterEach); err != nil {
		t.Fatalf("scan shard error: %v", err)
	}

	if collected != "firstData" {
		t.Fatalf("callback error expected %s, got %s", "firstData", collected)
	}
}
|
|
|
|
|
|
2018-07-29 17:13:03 +00:00
|
|
|
func TestScanShard_ShardIsClosed(t *testing.T) {
|
2018-11-07 23:45:13 +00:00
|
|
|
ctx := context.TODO()
|
2018-07-29 17:13:03 +00:00
|
|
|
var client = &kinesisClientMock{
|
2018-11-07 23:45:13 +00:00
|
|
|
getShardIteratorMock: func(a aws.Context, input *kinesis.GetShardIteratorInput, o ...request.Option) (*kinesis.GetShardIteratorOutput, error) {
|
2018-07-29 17:13:03 +00:00
|
|
|
return &kinesis.GetShardIteratorOutput{
|
|
|
|
|
ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
|
|
|
|
|
}, nil
|
|
|
|
|
},
|
2018-11-07 23:45:13 +00:00
|
|
|
getRecordsMock: func(a aws.Context, input *kinesis.GetRecordsInput, o ...request.Option) (*kinesis.GetRecordsOutput, error) {
|
2018-07-29 17:13:03 +00:00
|
|
|
return &kinesis.GetRecordsOutput{
|
|
|
|
|
NextShardIterator: nil,
|
|
|
|
|
Records: make([]*Record, 0),
|
|
|
|
|
}, nil
|
|
|
|
|
},
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
c, err := New("myStreamName", WithClient(client))
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatalf("new consumer error: %v", err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
var fn = func(r *Record) ScanStatus {
|
|
|
|
|
return ScanStatus{}
|
|
|
|
|
}
|
|
|
|
|
|
2018-11-07 23:45:13 +00:00
|
|
|
if err := c.ScanShard(ctx, "myShard", fn); err != nil {
|
2018-07-29 17:13:03 +00:00
|
|
|
t.Fatalf("scan shard error: %v", err)
|
2017-11-27 00:00:11 +00:00
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2018-07-29 17:13:03 +00:00
|
|
|
type kinesisClientMock struct {
|
2018-07-29 05:53:33 +00:00
|
|
|
kinesisiface.KinesisAPI
|
2018-11-07 23:45:13 +00:00
|
|
|
getShardIteratorMock func(aws.Context, *kinesis.GetShardIteratorInput, ...request.Option) (*kinesis.GetShardIteratorOutput, error)
|
|
|
|
|
getRecordsMock func(aws.Context, *kinesis.GetRecordsInput, ...request.Option) (*kinesis.GetRecordsOutput, error)
|
|
|
|
|
describeStreamMock func(aws.Context, *kinesis.DescribeStreamInput, ...request.Option) (*kinesis.DescribeStreamOutput, error)
|
2017-11-27 00:00:11 +00:00
|
|
|
}
|
|
|
|
|
|
2018-11-07 23:45:13 +00:00
|
|
|
func (c *kinesisClientMock) GetRecordsWithContext(a aws.Context, in *kinesis.GetRecordsInput, o ...request.Option) (*kinesis.GetRecordsOutput, error) {
|
|
|
|
|
return c.getRecordsMock(a, in, o...)
|
2017-11-27 00:00:11 +00:00
|
|
|
}
|
|
|
|
|
|
2018-11-07 23:45:13 +00:00
|
|
|
func (c *kinesisClientMock) GetShardIteratorWithContext(a aws.Context, in *kinesis.GetShardIteratorInput, o ...request.Option) (*kinesis.GetShardIteratorOutput, error) {
|
|
|
|
|
return c.getShardIteratorMock(a, in, o...)
|
2017-11-27 00:00:11 +00:00
|
|
|
}
|
|
|
|
|
|
2018-11-07 23:45:13 +00:00
|
|
|
func (c *kinesisClientMock) DescribeStreamWithContext(a aws.Context, in *kinesis.DescribeStreamInput, o ...request.Option) (*kinesis.DescribeStreamOutput, error) {
|
|
|
|
|
return c.describeStreamMock(a, in, o...)
|
2018-09-03 16:59:39 +00:00
|
|
|
}
|
|
|
|
|
|
2018-07-29 17:13:03 +00:00
|
|
|
// implementation of checkpoint
|
2017-11-27 00:00:11 +00:00
|
|
|
type fakeCheckpoint struct {
|
|
|
|
|
cache map[string]string
|
|
|
|
|
mu sync.Mutex
|
|
|
|
|
}
|
|
|
|
|
|
2018-11-07 23:45:13 +00:00
|
|
|
// Set records sequenceNumber as the checkpoint for the given stream and shard,
// overwriting any earlier value. It always returns nil; ctx is unused.
func (fc *fakeCheckpoint) Set(ctx context.Context, streamName, shardID, sequenceNumber string) error {
	key := fmt.Sprintf("%s-%s", streamName, shardID)

	fc.mu.Lock()
	defer fc.mu.Unlock()
	fc.cache[key] = sequenceNumber

	return nil
}
|
|
|
|
|
|
2018-11-07 23:45:13 +00:00
|
|
|
// Get returns the checkpoint stored for the given stream and shard, or "" if
// none has been set. The error is always nil; ctx is unused.
func (fc *fakeCheckpoint) Get(ctx context.Context, streamName, shardID string) (string, error) {
	key := fmt.Sprintf("%s-%s", streamName, shardID)

	fc.mu.Lock()
	seq := fc.cache[key]
	fc.mu.Unlock()

	return seq, nil
}
|
|
|
|
|
|
2018-07-29 17:13:03 +00:00
|
|
|
// implementation of counter
|
2017-11-27 00:00:11 +00:00
|
|
|
type fakeCounter struct {
|
|
|
|
|
counter int64
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (fc *fakeCounter) Add(streamName string, count int64) {
|
|
|
|
|
fc.counter += count
|
|
|
|
|
}
|