Break down the big function and Add tests for Scan (#65)

This commit is contained in:
Vincent 2018-09-03 23:59:39 +07:00 committed by Harlow Ward
parent 23811ec99a
commit d3b76346f5
2 changed files with 172 additions and 25 deletions

View file

@ -79,7 +79,11 @@ func New(streamName string, opts ...Option) (*Consumer, error) {
// default client if none provided
if c.client == nil {
c.client = kinesis.New(session.New(aws.NewConfig()))
newSession, err := session.NewSession(aws.NewConfig())
if err != nil {
return nil, err
}
c.client = kinesis.New(newSession)
}
return c, nil
@ -161,7 +165,10 @@ func (c *Consumer) ScanShard(
c.logger.Log("scanning", shardID, lastSeqNum)
// scan pages of shard
return c.scanPagesOfShard(ctx, shardID, lastSeqNum, shardIterator, fn)
}
func (c *Consumer) scanPagesOfShard(ctx context.Context, shardID, lastSeqNum string, shardIterator *string, fn func(*Record) ScanStatus) error {
for {
select {
case <-ctx.Done():
@ -181,28 +188,17 @@ func (c *Consumer) ScanShard(
// loop records of page
for _, r := range resp.Records {
status := fn(r)
if !status.SkipCheckpoint {
lastSeqNum = *r.SequenceNumber
if err := c.checkpoint.Set(c.streamName, shardID, lastSeqNum); err != nil {
return err
}
}
if err := status.Error; err != nil {
isScanStopped, err := c.handleRecord(shardID, r, fn)
if err != nil {
return err
}
c.counter.Add("records", 1)
if status.StopScan {
if isScanStopped {
return nil
}
lastSeqNum = *r.SequenceNumber
}
if resp.NextShardIterator == nil || shardIterator == resp.NextShardIterator {
if isShardClosed(resp.NextShardIterator, shardIterator) {
return nil
}
shardIterator = resp.NextShardIterator
@ -210,6 +206,31 @@ func (c *Consumer) ScanShard(
}
}
func isShardClosed(nextShardIterator, currentShardIterator *string) bool {
return nextShardIterator == nil || currentShardIterator == nextShardIterator
}
func (c *Consumer) handleRecord(shardID string, r *Record, fn func(*Record) ScanStatus) (isScanStopped bool, err error) {
status := fn(r)
if !status.SkipCheckpoint {
if err := c.checkpoint.Set(c.streamName, shardID, *r.SequenceNumber); err != nil {
return false, err
}
}
if err := status.Error; err != nil {
return false, err
}
c.counter.Add("records", 1)
if status.StopScan {
return true, nil
}
return false, nil
}
func (c *Consumer) getShardIDs(streamName string) ([]string, error) {
resp, err := c.client.DescribeStream(
&kinesis.DescribeStreamInput{
@ -220,7 +241,7 @@ func (c *Consumer) getShardIDs(streamName string) ([]string, error) {
return nil, fmt.Errorf("describe stream error: %v", err)
}
ss := []string{}
var ss []string
for _, shard := range resp.StreamDescription.Shards {
ss = append(ss, *shard.ShardId)
}

View file

@ -18,13 +18,134 @@ func TestNew(t *testing.T) {
}
}
func TestScanShard(t *testing.T) {
var records = []*kinesis.Record{
&kinesis.Record{
func TestConsumer_Scan(t *testing.T) {
records := []*kinesis.Record{
{
Data: []byte("firstData"),
SequenceNumber: aws.String("firstSeqNum"),
},
&kinesis.Record{
{
Data: []byte("lastData"),
SequenceNumber: aws.String("lastSeqNum"),
},
}
client := &kinesisClientMock{
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
return &kinesis.GetShardIteratorOutput{
ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
}, nil
},
getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
return &kinesis.GetRecordsOutput{
NextShardIterator: nil,
Records: records,
}, nil
},
describeStreamMock: func(input *kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) {
return &kinesis.DescribeStreamOutput{
StreamDescription: &kinesis.StreamDescription{
Shards: []*kinesis.Shard{
{ShardId: aws.String("myShard")},
},
},
}, nil
},
}
var (
cp = &fakeCheckpoint{cache: map[string]string{}}
ctr = &fakeCounter{}
)
c, err := New("myStreamName",
WithClient(client),
WithCounter(ctr),
WithCheckpoint(cp),
)
if err != nil {
t.Fatalf("new consumer error: %v", err)
}
var resultData string
var fnCallCounter int
var fn = func(r *Record) ScanStatus {
fnCallCounter++
resultData += string(r.Data)
return ScanStatus{}
}
if err := c.Scan(context.Background(), fn); err != nil {
t.Errorf("scan shard error expected nil. got %v", err)
}
if resultData != "firstDatalastData" {
t.Errorf("callback error expected %s, got %s", "firstDatalastData", resultData)
}
if fnCallCounter != 2 {
t.Errorf("the callback function expects %v, got %v", 2, fnCallCounter)
}
if val := ctr.counter; val != 2 {
t.Errorf("counter error expected %d, got %d", 2, val)
}
val, err := cp.Get("myStreamName", "myShard")
if err != nil && val != "lastSeqNum" {
t.Errorf("checkout error expected %s, got %s", "lastSeqNum", val)
}
}
func TestConsumer_Scan_NoShardsAvailable(t *testing.T) {
client := &kinesisClientMock{
describeStreamMock: func(input *kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) {
return &kinesis.DescribeStreamOutput{
StreamDescription: &kinesis.StreamDescription{
Shards: make([]*kinesis.Shard, 0),
},
}, nil
},
}
var (
cp = &fakeCheckpoint{cache: map[string]string{}}
ctr = &fakeCounter{}
)
c, err := New("myStreamName",
WithClient(client),
WithCounter(ctr),
WithCheckpoint(cp),
)
if err != nil {
t.Fatalf("new consumer error: %v", err)
}
var fnCallCounter int
var fn = func(r *Record) ScanStatus {
fnCallCounter++
return ScanStatus{}
}
if err := c.Scan(context.Background(), fn); err == nil {
t.Errorf("scan shard error expected not nil. got %v", err)
}
if fnCallCounter != 0 {
t.Errorf("the callback function expects %v, got %v", 0, fnCallCounter)
}
if val := ctr.counter; val != 0 {
t.Errorf("counter error expected %d, got %d", 0, val)
}
val, err := cp.Get("myStreamName", "myShard")
if err != nil && val != "" {
t.Errorf("checkout error expected %s, got %s", "", val)
}
}
func TestScanShard(t *testing.T) {
var records = []*kinesis.Record{
{
Data: []byte("firstData"),
SequenceNumber: aws.String("firstSeqNum"),
},
{
Data: []byte("lastData"),
SequenceNumber: aws.String("lastSeqNum"),
},
@ -89,11 +210,11 @@ func TestScanShard(t *testing.T) {
func TestScanShard_StopScan(t *testing.T) {
var records = []*kinesis.Record{
&kinesis.Record{
{
Data: []byte("firstData"),
SequenceNumber: aws.String("firstSeqNum"),
},
&kinesis.Record{
{
Data: []byte("lastData"),
SequenceNumber: aws.String("lastSeqNum"),
},
@ -167,6 +288,7 @@ type kinesisClientMock struct {
kinesisiface.KinesisAPI
getShardIteratorMock func(*kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error)
getRecordsMock func(*kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error)
describeStreamMock func(*kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error)
}
func (c *kinesisClientMock) GetRecords(in *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
@ -177,6 +299,10 @@ func (c *kinesisClientMock) GetShardIterator(in *kinesis.GetShardIteratorInput)
return c.getShardIteratorMock(in)
}
func (c *kinesisClientMock) DescribeStream(in *kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) {
return c.describeStreamMock(in)
}
// implementation of checkpoint
type fakeCheckpoint struct {
cache map[string]string