Add unit tests for the GetRecords function (#64)
A closed shard returns the last sequence number and no error. The current implementation leads to an infinite loop if the shard is closed. NextShardIterator checking is enough. That's why I remove the getShardIterator call
This commit is contained in:
parent
a1239221d8
commit
049445e259
2 changed files with 153 additions and 9 deletions
12
client.go
12
client.go
|
|
@ -108,7 +108,6 @@ func (c *KinesisClient) GetRecords(ctx context.Context, streamName, shardID, las
|
|||
}
|
||||
continue
|
||||
}
|
||||
|
||||
for _, r := range resp.Records {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
|
|
@ -117,16 +116,11 @@ func (c *KinesisClient) GetRecords(ctx context.Context, streamName, shardID, las
|
|||
lastSeqNum = *r.SequenceNumber
|
||||
}
|
||||
}
|
||||
|
||||
if resp.NextShardIterator == nil || shardIterator == resp.NextShardIterator {
|
||||
shardIterator, err = c.getShardIterator(streamName, shardID, lastSeqNum)
|
||||
if err != nil {
|
||||
errc <- fmt.Errorf("get shard iterator error: %v", err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
shardIterator = resp.NextShardIterator
|
||||
errc <- fmt.Errorf("get shard iterator error: %v", err)
|
||||
return
|
||||
}
|
||||
shardIterator = resp.NextShardIterator
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
|
|
|||
150
client_test.go
Normal file
150
client_test.go
Normal file
|
|
@ -0,0 +1,150 @@
|
|||
package consumer_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"context"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/service/kinesis"
|
||||
"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
|
||||
"github.com/harlow/kinesis-consumer"
|
||||
)
|
||||
|
||||
func TestKinesisClient_GetRecords_SuccessfullyRun(t *testing.T) {
|
||||
kinesisClient := &kinesisClientMock{
|
||||
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
|
||||
return &kinesis.GetShardIteratorOutput{
|
||||
ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
|
||||
}, nil
|
||||
},
|
||||
getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
|
||||
return &kinesis.GetRecordsOutput{
|
||||
NextShardIterator: nil,
|
||||
Records: make([]*kinesis.Record, 0),
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
kinesisClientOpt := consumer.WithKinesis(kinesisClient)
|
||||
c, err := consumer.NewKinesisClient(kinesisClientOpt)
|
||||
if err != nil {
|
||||
t.Fatalf("New kinesis client error: %v", err)
|
||||
}
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
recordsChan, errorsChan, err := c.GetRecords(ctx, "myStream", "shardId-000000000000", "")
|
||||
|
||||
if recordsChan == nil {
|
||||
t.Errorf("records channel expected not nil, got %v", recordsChan)
|
||||
}
|
||||
if errorsChan == nil {
|
||||
t.Errorf("errors channel expected not nil, got %v", recordsChan)
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("error expected nil, got %v", err)
|
||||
}
|
||||
|
||||
cancelFunc()
|
||||
}
|
||||
|
||||
func TestKinesisClient_GetRecords_SuccessfullyRetrievesThreeRecordsAtOnce(t *testing.T) {
|
||||
expectedResults := []*kinesis.Record{
|
||||
{
|
||||
SequenceNumber: aws.String("49578481031144599192696750682534686652010819674221576195"),
|
||||
},
|
||||
{
|
||||
SequenceNumber: aws.String("49578481031144599192696750682534686652010819674221576196"),
|
||||
},
|
||||
{
|
||||
SequenceNumber: aws.String("49578481031144599192696750682534686652010819674221576197"),
|
||||
}}
|
||||
kinesisClient := &kinesisClientMock{
|
||||
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
|
||||
return &kinesis.GetShardIteratorOutput{
|
||||
ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
|
||||
}, nil
|
||||
},
|
||||
getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
|
||||
return &kinesis.GetRecordsOutput{
|
||||
NextShardIterator: nil,
|
||||
Records: expectedResults,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
kinesisClientOpt := consumer.WithKinesis(kinesisClient)
|
||||
c, err := consumer.NewKinesisClient(kinesisClientOpt)
|
||||
if err != nil {
|
||||
t.Fatalf("new kinesis client error: %v", err)
|
||||
}
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
recordsChan, _, err := c.GetRecords(ctx, "TestStream", "shardId-000000000000", "")
|
||||
|
||||
if recordsChan == nil {
|
||||
t.Fatalf("records channel expected not nil, got %v", recordsChan)
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("error expected nil, got %v", err)
|
||||
}
|
||||
var results []*consumer.Record
|
||||
results = append(results, <-recordsChan, <-recordsChan, <-recordsChan)
|
||||
if len(results) != 3 {
|
||||
t.Errorf("number of records expected 3, got %v", len(results))
|
||||
}
|
||||
for i, r := range results {
|
||||
if r != expectedResults[i] {
|
||||
t.Errorf("record expected %v, got %v", expectedResults[i], r)
|
||||
}
|
||||
}
|
||||
|
||||
cancelFunc()
|
||||
}
|
||||
|
||||
func TestKinesisClient_GetRecords_ShardIsClosed(t *testing.T) {
|
||||
kinesisClient := &kinesisClientMock{
|
||||
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
|
||||
return &kinesis.GetShardIteratorOutput{
|
||||
ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
|
||||
}, nil
|
||||
},
|
||||
getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
|
||||
return &kinesis.GetRecordsOutput{
|
||||
NextShardIterator: nil,
|
||||
Records: make([]*consumer.Record, 0),
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
kinesisClientOpt := consumer.WithKinesis(kinesisClient)
|
||||
c, err := consumer.NewKinesisClient(kinesisClientOpt)
|
||||
if err != nil {
|
||||
t.Fatalf("new kinesis client error: %v", err)
|
||||
}
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
_, errorsChan, err := c.GetRecords(ctx, "TestStream", "shardId-000000000000", "")
|
||||
|
||||
if errorsChan == nil {
|
||||
t.Fatalf("errors channel expected equals not nil, got %v", errorsChan)
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("error expected, got %v", err)
|
||||
}
|
||||
|
||||
err = <-errorsChan
|
||||
if err == nil {
|
||||
t.Errorf("error expected, got %v", err)
|
||||
}
|
||||
|
||||
cancelFunc()
|
||||
}
|
||||
|
||||
type kinesisClientMock struct {
|
||||
kinesisiface.KinesisAPI
|
||||
getShardIteratorMock func(*kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error)
|
||||
getRecordsMock func(*kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error)
|
||||
}
|
||||
|
||||
func (c *kinesisClientMock) GetRecords(in *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
|
||||
return c.getRecordsMock(in)
|
||||
}
|
||||
|
||||
func (c *kinesisClientMock) GetShardIterator(in *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
|
||||
return c.getShardIteratorMock(in)
|
||||
}
|
||||
Loading…
Reference in a new issue