diff --git a/consumer.go b/consumer.go index cfb7c0b..53b9cdf 100644 --- a/consumer.go +++ b/consumer.go @@ -79,7 +79,11 @@ func New(streamName string, opts ...Option) (*Consumer, error) { // default client if none provided if c.client == nil { - c.client = kinesis.New(session.New(aws.NewConfig())) + newSession, err := session.NewSession(aws.NewConfig()) + if err != nil { + return nil, err + } + c.client = kinesis.New(newSession) } return c, nil @@ -161,7 +165,10 @@ func (c *Consumer) ScanShard( c.logger.Log("scanning", shardID, lastSeqNum) - // scan pages of shard + return c.scanPagesOfShard(ctx, shardID, lastSeqNum, shardIterator, fn) +} + +func (c *Consumer) scanPagesOfShard(ctx context.Context, shardID, lastSeqNum string, shardIterator *string, fn func(*Record) ScanStatus) error { for { select { case <-ctx.Done(): @@ -181,28 +188,17 @@ func (c *Consumer) ScanShard( // loop records of page for _, r := range resp.Records { - status := fn(r) - - if !status.SkipCheckpoint { - lastSeqNum = *r.SequenceNumber - - if err := c.checkpoint.Set(c.streamName, shardID, lastSeqNum); err != nil { - return err - } - } - - if err := status.Error; err != nil { + isScanStopped, err := c.handleRecord(shardID, r, fn) + if err != nil { return err } - - c.counter.Add("records", 1) - - if status.StopScan { + if isScanStopped { return nil } + lastSeqNum = *r.SequenceNumber } - if resp.NextShardIterator == nil || shardIterator == resp.NextShardIterator { + if isShardClosed(resp.NextShardIterator, shardIterator) { return nil } shardIterator = resp.NextShardIterator @@ -210,6 +206,31 @@ func (c *Consumer) ScanShard( } } +func isShardClosed(nextShardIterator, currentShardIterator *string) bool { + return nextShardIterator == nil || currentShardIterator == nextShardIterator +} + +func (c *Consumer) handleRecord(shardID string, r *Record, fn func(*Record) ScanStatus) (isScanStopped bool, err error) { + status := fn(r) + + if !status.SkipCheckpoint { + if err := c.checkpoint.Set(c.streamName, shardID, *r.SequenceNumber); err != nil { + return false, err + } + } + + if err := status.Error; err != nil { + return false, err + } + + c.counter.Add("records", 1) + + if status.StopScan { + return true, nil + } + return false, nil +} + func (c *Consumer) getShardIDs(streamName string) ([]string, error) { resp, err := c.client.DescribeStream( &kinesis.DescribeStreamInput{ @@ -220,7 +241,7 @@ func (c *Consumer) getShardIDs(streamName string) ([]string, error) { return nil, fmt.Errorf("describe stream error: %v", err) } - ss := []string{} + var ss []string for _, shard := range resp.StreamDescription.Shards { ss = append(ss, *shard.ShardId) } diff --git a/consumer_test.go b/consumer_test.go index ce066d7..3f17373 100644 --- a/consumer_test.go +++ b/consumer_test.go @@ -18,13 +18,134 @@ func TestNew(t *testing.T) { } } -func TestScanShard(t *testing.T) { - var records = []*kinesis.Record{ - &kinesis.Record{ +func TestConsumer_Scan(t *testing.T) { + records := []*kinesis.Record{ + { Data: []byte("firstData"), SequenceNumber: aws.String("firstSeqNum"), }, - &kinesis.Record{ + { + Data: []byte("lastData"), + SequenceNumber: aws.String("lastSeqNum"), + }, + } + client := &kinesisClientMock{ + getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) { + return &kinesis.GetShardIteratorOutput{ + ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"), + }, nil + }, + getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) { + return &kinesis.GetRecordsOutput{ + NextShardIterator: nil, + Records: records, + }, nil + }, + describeStreamMock: func(input *kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) { + return &kinesis.DescribeStreamOutput{ + StreamDescription: &kinesis.StreamDescription{ + Shards: []*kinesis.Shard{ + {ShardId: aws.String("myShard")}, + }, + }, + }, nil + }, + } + var ( + cp = &fakeCheckpoint{cache: map[string]string{}} + ctr = &fakeCounter{} + ) + + c, err := New("myStreamName", + WithClient(client), + WithCounter(ctr), + WithCheckpoint(cp), + ) + if err != nil { + t.Fatalf("new consumer error: %v", err) + } + + var resultData string + var fnCallCounter int + var fn = func(r *Record) ScanStatus { + fnCallCounter++ + resultData += string(r.Data) + return ScanStatus{} + } + + if err := c.Scan(context.Background(), fn); err != nil { + t.Errorf("scan shard error expected nil. got %v", err) + } + + if resultData != "firstDatalastData" { + t.Errorf("callback error expected %s, got %s", "firstDatalastData", resultData) + } + if fnCallCounter != 2 { + t.Errorf("the callback function expects %v, got %v", 2, fnCallCounter) + } + if val := ctr.counter; val != 2 { + t.Errorf("counter error expected %d, got %d", 2, val) + } + + val, err := cp.Get("myStreamName", "myShard") + if err != nil && val != "lastSeqNum" { + t.Errorf("checkout error expected %s, got %s", "lastSeqNum", val) + } +} + +func TestConsumer_Scan_NoShardsAvailable(t *testing.T) { + client := &kinesisClientMock{ + describeStreamMock: func(input *kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) { + return &kinesis.DescribeStreamOutput{ + StreamDescription: &kinesis.StreamDescription{ + Shards: make([]*kinesis.Shard, 0), + }, + }, nil + }, + } + var ( + cp = &fakeCheckpoint{cache: map[string]string{}} + ctr = &fakeCounter{} + ) + + c, err := New("myStreamName", + WithClient(client), + WithCounter(ctr), + WithCheckpoint(cp), + ) + if err != nil { + t.Fatalf("new consumer error: %v", err) + } + + var fnCallCounter int + var fn = func(r *Record) ScanStatus { + fnCallCounter++ + return ScanStatus{} + } + + if err := c.Scan(context.Background(), fn); err == nil { + t.Errorf("scan shard error expected not nil. got %v", err) + } + + if fnCallCounter != 0 { + t.Errorf("the callback function expects %v, got %v", 0, fnCallCounter) + } + if val := ctr.counter; val != 0 { + t.Errorf("counter error expected %d, got %d", 0, val) + } + val, err := cp.Get("myStreamName", "myShard") + if err != nil && val != "" { + t.Errorf("checkout error expected %s, got %s", "", val) + } +} + +func TestScanShard(t *testing.T) { + var records = []*kinesis.Record{ + { + Data: []byte("firstData"), + SequenceNumber: aws.String("firstSeqNum"), + }, + { Data: []byte("lastData"), SequenceNumber: aws.String("lastSeqNum"), }, @@ -89,11 +210,11 @@ func TestScanShard(t *testing.T) { func TestScanShard_StopScan(t *testing.T) { var records = []*kinesis.Record{ - &kinesis.Record{ + { Data: []byte("firstData"), SequenceNumber: aws.String("firstSeqNum"), }, - &kinesis.Record{ + { Data: []byte("lastData"), SequenceNumber: aws.String("lastSeqNum"), }, @@ -167,6 +288,7 @@ type kinesisClientMock struct { kinesisiface.KinesisAPI getShardIteratorMock func(*kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) getRecordsMock func(*kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) + describeStreamMock func(*kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) } func (c *kinesisClientMock) GetRecords(in *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) { @@ -177,6 +299,10 @@ func (c *kinesisClientMock) GetShardIterator(in *kinesis.GetShardIteratorInput) return c.getShardIteratorMock(in) } +func (c *kinesisClientMock) DescribeStream(in *kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) { + return c.describeStreamMock(in) +} + // implementation of checkpoint type fakeCheckpoint struct { cache map[string]string