Break down the big function and Add tests for Scan (#65)
This commit is contained in:
parent
23811ec99a
commit
d3b76346f5
2 changed files with 172 additions and 25 deletions
59
consumer.go
59
consumer.go
|
|
@ -79,7 +79,11 @@ func New(streamName string, opts ...Option) (*Consumer, error) {
|
||||||
|
|
||||||
// default client if none provided
|
// default client if none provided
|
||||||
if c.client == nil {
|
if c.client == nil {
|
||||||
c.client = kinesis.New(session.New(aws.NewConfig()))
|
newSession, err := session.NewSession(aws.NewConfig())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
c.client = kinesis.New(newSession)
|
||||||
}
|
}
|
||||||
|
|
||||||
return c, nil
|
return c, nil
|
||||||
|
|
@ -161,7 +165,10 @@ func (c *Consumer) ScanShard(
|
||||||
|
|
||||||
c.logger.Log("scanning", shardID, lastSeqNum)
|
c.logger.Log("scanning", shardID, lastSeqNum)
|
||||||
|
|
||||||
// scan pages of shard
|
return c.scanPagesOfShard(ctx, shardID, lastSeqNum, shardIterator, fn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Consumer) scanPagesOfShard(ctx context.Context, shardID, lastSeqNum string, shardIterator *string, fn func(*Record) ScanStatus) error {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
|
|
@ -181,28 +188,17 @@ func (c *Consumer) ScanShard(
|
||||||
|
|
||||||
// loop records of page
|
// loop records of page
|
||||||
for _, r := range resp.Records {
|
for _, r := range resp.Records {
|
||||||
status := fn(r)
|
isScanStopped, err := c.handleRecord(shardID, r, fn)
|
||||||
|
if err != nil {
|
||||||
if !status.SkipCheckpoint {
|
|
||||||
lastSeqNum = *r.SequenceNumber
|
|
||||||
|
|
||||||
if err := c.checkpoint.Set(c.streamName, shardID, lastSeqNum); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
if isScanStopped {
|
||||||
|
|
||||||
if err := status.Error; err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
c.counter.Add("records", 1)
|
|
||||||
|
|
||||||
if status.StopScan {
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
lastSeqNum = *r.SequenceNumber
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.NextShardIterator == nil || shardIterator == resp.NextShardIterator {
|
if isShardClosed(resp.NextShardIterator, shardIterator) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
shardIterator = resp.NextShardIterator
|
shardIterator = resp.NextShardIterator
|
||||||
|
|
@ -210,6 +206,31 @@ func (c *Consumer) ScanShard(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isShardClosed(nextShardIterator, currentShardIterator *string) bool {
|
||||||
|
return nextShardIterator == nil || currentShardIterator == nextShardIterator
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Consumer) handleRecord(shardID string, r *Record, fn func(*Record) ScanStatus) (isScanStopped bool, err error) {
|
||||||
|
status := fn(r)
|
||||||
|
|
||||||
|
if !status.SkipCheckpoint {
|
||||||
|
if err := c.checkpoint.Set(c.streamName, shardID, *r.SequenceNumber); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := status.Error; err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
c.counter.Add("records", 1)
|
||||||
|
|
||||||
|
if status.StopScan {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Consumer) getShardIDs(streamName string) ([]string, error) {
|
func (c *Consumer) getShardIDs(streamName string) ([]string, error) {
|
||||||
resp, err := c.client.DescribeStream(
|
resp, err := c.client.DescribeStream(
|
||||||
&kinesis.DescribeStreamInput{
|
&kinesis.DescribeStreamInput{
|
||||||
|
|
@ -220,7 +241,7 @@ func (c *Consumer) getShardIDs(streamName string) ([]string, error) {
|
||||||
return nil, fmt.Errorf("describe stream error: %v", err)
|
return nil, fmt.Errorf("describe stream error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ss := []string{}
|
var ss []string
|
||||||
for _, shard := range resp.StreamDescription.Shards {
|
for _, shard := range resp.StreamDescription.Shards {
|
||||||
ss = append(ss, *shard.ShardId)
|
ss = append(ss, *shard.ShardId)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
138
consumer_test.go
138
consumer_test.go
|
|
@ -18,13 +18,134 @@ func TestNew(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestScanShard(t *testing.T) {
|
func TestConsumer_Scan(t *testing.T) {
|
||||||
var records = []*kinesis.Record{
|
records := []*kinesis.Record{
|
||||||
&kinesis.Record{
|
{
|
||||||
Data: []byte("firstData"),
|
Data: []byte("firstData"),
|
||||||
SequenceNumber: aws.String("firstSeqNum"),
|
SequenceNumber: aws.String("firstSeqNum"),
|
||||||
},
|
},
|
||||||
&kinesis.Record{
|
{
|
||||||
|
Data: []byte("lastData"),
|
||||||
|
SequenceNumber: aws.String("lastSeqNum"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
client := &kinesisClientMock{
|
||||||
|
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
|
||||||
|
return &kinesis.GetShardIteratorOutput{
|
||||||
|
ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
|
||||||
|
return &kinesis.GetRecordsOutput{
|
||||||
|
NextShardIterator: nil,
|
||||||
|
Records: records,
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
describeStreamMock: func(input *kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) {
|
||||||
|
return &kinesis.DescribeStreamOutput{
|
||||||
|
StreamDescription: &kinesis.StreamDescription{
|
||||||
|
Shards: []*kinesis.Shard{
|
||||||
|
{ShardId: aws.String("myShard")},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
var (
|
||||||
|
cp = &fakeCheckpoint{cache: map[string]string{}}
|
||||||
|
ctr = &fakeCounter{}
|
||||||
|
)
|
||||||
|
|
||||||
|
c, err := New("myStreamName",
|
||||||
|
WithClient(client),
|
||||||
|
WithCounter(ctr),
|
||||||
|
WithCheckpoint(cp),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("new consumer error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resultData string
|
||||||
|
var fnCallCounter int
|
||||||
|
var fn = func(r *Record) ScanStatus {
|
||||||
|
fnCallCounter++
|
||||||
|
resultData += string(r.Data)
|
||||||
|
return ScanStatus{}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.Scan(context.Background(), fn); err != nil {
|
||||||
|
t.Errorf("scan shard error expected nil. got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resultData != "firstDatalastData" {
|
||||||
|
t.Errorf("callback error expected %s, got %s", "firstDatalastData", resultData)
|
||||||
|
}
|
||||||
|
if fnCallCounter != 2 {
|
||||||
|
t.Errorf("the callback function expects %v, got %v", 2, fnCallCounter)
|
||||||
|
}
|
||||||
|
if val := ctr.counter; val != 2 {
|
||||||
|
t.Errorf("counter error expected %d, got %d", 2, val)
|
||||||
|
}
|
||||||
|
|
||||||
|
val, err := cp.Get("myStreamName", "myShard")
|
||||||
|
if err != nil && val != "lastSeqNum" {
|
||||||
|
t.Errorf("checkout error expected %s, got %s", "lastSeqNum", val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConsumer_Scan_NoShardsAvailable(t *testing.T) {
|
||||||
|
client := &kinesisClientMock{
|
||||||
|
describeStreamMock: func(input *kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) {
|
||||||
|
return &kinesis.DescribeStreamOutput{
|
||||||
|
StreamDescription: &kinesis.StreamDescription{
|
||||||
|
Shards: make([]*kinesis.Shard, 0),
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
var (
|
||||||
|
cp = &fakeCheckpoint{cache: map[string]string{}}
|
||||||
|
ctr = &fakeCounter{}
|
||||||
|
)
|
||||||
|
|
||||||
|
c, err := New("myStreamName",
|
||||||
|
WithClient(client),
|
||||||
|
WithCounter(ctr),
|
||||||
|
WithCheckpoint(cp),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("new consumer error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var fnCallCounter int
|
||||||
|
var fn = func(r *Record) ScanStatus {
|
||||||
|
fnCallCounter++
|
||||||
|
return ScanStatus{}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.Scan(context.Background(), fn); err == nil {
|
||||||
|
t.Errorf("scan shard error expected not nil. got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if fnCallCounter != 0 {
|
||||||
|
t.Errorf("the callback function expects %v, got %v", 0, fnCallCounter)
|
||||||
|
}
|
||||||
|
if val := ctr.counter; val != 0 {
|
||||||
|
t.Errorf("counter error expected %d, got %d", 0, val)
|
||||||
|
}
|
||||||
|
val, err := cp.Get("myStreamName", "myShard")
|
||||||
|
if err != nil && val != "" {
|
||||||
|
t.Errorf("checkout error expected %s, got %s", "", val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestScanShard(t *testing.T) {
|
||||||
|
var records = []*kinesis.Record{
|
||||||
|
{
|
||||||
|
Data: []byte("firstData"),
|
||||||
|
SequenceNumber: aws.String("firstSeqNum"),
|
||||||
|
},
|
||||||
|
{
|
||||||
Data: []byte("lastData"),
|
Data: []byte("lastData"),
|
||||||
SequenceNumber: aws.String("lastSeqNum"),
|
SequenceNumber: aws.String("lastSeqNum"),
|
||||||
},
|
},
|
||||||
|
|
@ -89,11 +210,11 @@ func TestScanShard(t *testing.T) {
|
||||||
|
|
||||||
func TestScanShard_StopScan(t *testing.T) {
|
func TestScanShard_StopScan(t *testing.T) {
|
||||||
var records = []*kinesis.Record{
|
var records = []*kinesis.Record{
|
||||||
&kinesis.Record{
|
{
|
||||||
Data: []byte("firstData"),
|
Data: []byte("firstData"),
|
||||||
SequenceNumber: aws.String("firstSeqNum"),
|
SequenceNumber: aws.String("firstSeqNum"),
|
||||||
},
|
},
|
||||||
&kinesis.Record{
|
{
|
||||||
Data: []byte("lastData"),
|
Data: []byte("lastData"),
|
||||||
SequenceNumber: aws.String("lastSeqNum"),
|
SequenceNumber: aws.String("lastSeqNum"),
|
||||||
},
|
},
|
||||||
|
|
@ -167,6 +288,7 @@ type kinesisClientMock struct {
|
||||||
kinesisiface.KinesisAPI
|
kinesisiface.KinesisAPI
|
||||||
getShardIteratorMock func(*kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error)
|
getShardIteratorMock func(*kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error)
|
||||||
getRecordsMock func(*kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error)
|
getRecordsMock func(*kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error)
|
||||||
|
describeStreamMock func(*kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *kinesisClientMock) GetRecords(in *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
|
func (c *kinesisClientMock) GetRecords(in *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
|
||||||
|
|
@ -177,6 +299,10 @@ func (c *kinesisClientMock) GetShardIterator(in *kinesis.GetShardIteratorInput)
|
||||||
return c.getShardIteratorMock(in)
|
return c.getShardIteratorMock(in)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *kinesisClientMock) DescribeStream(in *kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) {
|
||||||
|
return c.describeStreamMock(in)
|
||||||
|
}
|
||||||
|
|
||||||
// implementation of checkpoint
|
// implementation of checkpoint
|
||||||
type fakeCheckpoint struct {
|
type fakeCheckpoint struct {
|
||||||
cache map[string]string
|
cache map[string]string
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue