Introduce Client Interface

Testing the components of the consumer was proving difficult because
the consumer code was so tightly coupled with the Kinesis client
library.

* Extract the concept of Client interface
* Create default client w/ kinesis connection
* Test with fake client to avoid round trip to kinesis
This commit is contained in:
Harlow Ward 2017-11-26 16:00:11 -08:00
parent 058f383e30
commit b875bb56e7
3 changed files with 309 additions and 113 deletions

123
client.go Normal file
View file

@ -0,0 +1,123 @@
package consumer
import (
"context"
"fmt"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/kinesis"
)
// NewClient builds a client whose Kinesis service handle is created from a
// new AWS session using the default SDK configuration.
func NewClient() *client {
	return &client{
		svc: kinesis.New(session.New(aws.NewConfig())),
	}
}
// client is a thin wrapper around the AWS SDK Kinesis service client. Its
// GetShardIDs and GetRecords methods are built on top of svc.
type client struct {
	svc *kinesis.Kinesis // SDK handle every stream call goes through
}
// GetShardIDs returns the IDs of all shards in the given stream.
//
// DescribeStream hands shards back one page at a time, so the call is
// repeated with ExclusiveStartShardId set to the last shard seen until the
// service reports that no more shards remain. A non-nil (possibly empty)
// slice is returned on success; any DescribeStream failure aborts the listing
// and is returned wrapped as a "describe stream error".
func (c *client) GetShardIDs(streamName string) ([]string, error) {
	var (
		ss      = []string{}
		startID *string // nil on the first page
	)

	for {
		resp, err := c.svc.DescribeStream(
			&kinesis.DescribeStreamInput{
				StreamName:            aws.String(streamName),
				ExclusiveStartShardId: startID,
			},
		)
		if err != nil {
			return nil, fmt.Errorf("describe stream error: %v", err)
		}

		desc := resp.StreamDescription
		if desc == nil {
			return ss, nil
		}

		var lastID *string
		for _, shard := range desc.Shards {
			// skip malformed entries instead of dereferencing a nil pointer
			if shard == nil || shard.ShardId == nil {
				continue
			}
			ss = append(ss, *shard.ShardId)
			lastID = shard.ShardId
		}

		// stop when the service is done, or when there is no shard ID to
		// continue from (which would otherwise re-request the same page)
		if !aws.BoolValue(desc.HasMoreShards) || lastID == nil {
			return ss, nil
		}
		startID = lastID
	}
}
// GetRecords streams the records of one shard over a channel.
//
// Reading starts right after lastSeqNum, or at the oldest available record
// when lastSeqNum is empty. The initial shard iterator is fetched
// synchronously; if that fails the error is returned and no goroutine is
// started.
//
// On success a goroutine polls the shard until ctx is cancelled or a shard
// iterator can no longer be obtained. Both returned channels are closed when
// that goroutine exits. At most one error is ever sent on the error channel,
// which is buffered so the send never blocks. A failed GetRecords call is not
// reported: it is retried with a fresh iterator positioned after the last
// record that was delivered.
func (c *client) GetRecords(ctx context.Context, streamName, shardID, lastSeqNum string) (<-chan *Record, <-chan error, error) {
	shardIterator, err := c.getShardIterator(streamName, shardID, lastSeqNum)
	if err != nil {
		return nil, nil, fmt.Errorf("get shard iterator error: %v", err)
	}

	var (
		recc = make(chan *Record, 10000)
		errc = make(chan error, 1)
	)

	go func() {
		defer func() {
			close(recc)
			close(errc)
		}()

		for {
			select {
			case <-ctx.Done():
				return
			default:
				resp, err := c.svc.GetRecords(
					&kinesis.GetRecordsInput{
						ShardIterator: shardIterator,
					},
				)
				if err != nil {
					// the iterator may have expired; resume from the last
					// delivered record with a new one
					shardIterator, err = c.getShardIterator(streamName, shardID, lastSeqNum)
					if err != nil {
						errc <- fmt.Errorf("get shard iterator error: %v", err)
						return
					}
					continue
				}

				for _, r := range resp.Records {
					select {
					case <-ctx.Done():
						return
					case recc <- r:
						// only advance the resume position once the record
						// has actually been handed to the consumer
						if r != nil && r.SequenceNumber != nil {
							lastSeqNum = *r.SequenceNumber
						}
					}
				}

				// Compare iterator values, not pointer addresses: two
				// distinct *string values never compare equal by address
				// even when they hold the same iterator token.
				next := resp.NextShardIterator
				if next == nil || aws.StringValue(shardIterator) == aws.StringValue(next) {
					shardIterator, err = c.getShardIterator(streamName, shardID, lastSeqNum)
					if err != nil {
						errc <- fmt.Errorf("get shard iterator error: %v", err)
						return
					}
				} else {
					shardIterator = next
				}
			}
		}
	}()

	return recc, errc, nil
}
// getShardIterator asks Kinesis for an iterator on the given shard. With an
// empty lastSeqNum the iterator type is TRIM_HORIZON; otherwise it is
// AFTER_SEQUENCE_NUMBER with lastSeqNum as the starting sequence number.
// The SDK error is returned unwrapped.
func (c *client) getShardIterator(streamName, shardID, lastSeqNum string) (*string, error) {
	input := &kinesis.GetShardIteratorInput{
		ShardId:           aws.String(shardID),
		StreamName:        aws.String(streamName),
		ShardIteratorType: aws.String("TRIM_HORIZON"),
	}

	// a known sequence number overrides the start-from-oldest default
	if lastSeqNum != "" {
		input.ShardIteratorType = aws.String("AFTER_SEQUENCE_NUMBER")
		input.StartingSequenceNumber = aws.String(lastSeqNum)
	}

	out, err := c.svc.GetShardIterator(input)
	if err != nil {
		return nil, err
	}

	return out.ShardIterator, nil
}

View file

@ -7,13 +7,18 @@ import (
"log" "log"
"sync" "sync"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/kinesis" "github.com/aws/aws-sdk-go/service/kinesis"
) )
// Record is an alias of record returned from kinesis library
type Record = kinesis.Record type Record = kinesis.Record
// Client interface is used for interacting with kinesis stream
type Client interface {
GetShardIDs(string) ([]string, error)
GetRecords(ctx context.Context, streamName, shardID, lastSeqNum string) (<-chan *Record, <-chan error, error)
}
// Counter interface is used for exposing basic metrics from the scanner // Counter interface is used for exposing basic metrics from the scanner
type Counter interface { type Counter interface {
Add(string, int64) Add(string, int64)
@ -61,39 +66,44 @@ func WithCounter(counter Counter) Option {
} }
} }
// WithClient overrides the default client
func WithClient(client Client) Option {
return func(c *Consumer) error {
c.client = client
return nil
}
}
// New creates a kinesis consumer with default settings. Use Option to override // New creates a kinesis consumer with default settings. Use Option to override
// any of the optional attributes. // any of the optional attributes.
func New(stream string, opts ...Option) (*Consumer, error) { func New(streamName string, opts ...Option) (*Consumer, error) {
if stream == "" { if streamName == "" {
return nil, fmt.Errorf("must provide stream name") return nil, fmt.Errorf("must provide stream name")
} }
// new consumer with no-op checkpoint, counter, and logger
c := &Consumer{ c := &Consumer{
streamName: stream, streamName: streamName,
checkpoint: &noopCheckpoint{}, checkpoint: &noopCheckpoint{},
counter: &noopCounter{}, counter: &noopCounter{},
logger: log.New(ioutil.Discard, "", log.LstdFlags), logger: log.New(ioutil.Discard, "", log.LstdFlags),
client: NewClient(),
} }
// set options // override defaults
for _, opt := range opts { for _, opt := range opts {
if err := opt(c); err != nil { if err := opt(c); err != nil {
return nil, err return nil, err
} }
} }
// provide a default kinesis client
if c.client == nil {
c.client = kinesis.New(session.New(aws.NewConfig()))
}
return c, nil return c, nil
} }
// Consumer wraps the interaction with the Kinesis stream // Consumer wraps the interaction with the Kinesis stream
type Consumer struct { type Consumer struct {
streamName string streamName string
client *kinesis.Kinesis client Client
logger *log.Logger logger *log.Logger
checkpoint Checkpoint checkpoint Checkpoint
counter Counter counter Counter
@ -101,32 +111,27 @@ type Consumer struct {
// Scan scans each of the shards of the stream, calls the callback // Scan scans each of the shards of the stream, calls the callback
// func with each of the kinesis records. // func with each of the kinesis records.
func (c *Consumer) Scan(ctx context.Context, fn func(*kinesis.Record) bool) error { func (c *Consumer) Scan(ctx context.Context, fn func(*Record) bool) error {
ctx, cancel := context.WithCancel(ctx) shardIDs, err := c.client.GetShardIDs(c.streamName)
defer cancel()
// grab the stream details
resp, err := c.client.DescribeStream(
&kinesis.DescribeStreamInput{
StreamName: aws.String(c.streamName),
},
)
if err != nil { if err != nil {
return fmt.Errorf("describe stream error: %v", err) return fmt.Errorf("get shards error: %v", err)
} }
if len(resp.StreamDescription.Shards) == 0 { if len(shardIDs) == 0 {
return fmt.Errorf("no shards available") return fmt.Errorf("no shards available")
} }
ctx, cancel := context.WithCancel(ctx)
defer cancel()
var ( var (
wg sync.WaitGroup wg sync.WaitGroup
errc = make(chan error, 1) errc = make(chan error, 1)
) )
wg.Add(len(resp.StreamDescription.Shards)) wg.Add(len(shardIDs))
// launch goroutine to process each of the shards // process each shard in goroutine
for _, shard := range resp.StreamDescription.Shards { for _, shardID := range shardIDs {
go func(shardID string) { go func(shardID string) {
defer wg.Done() defer wg.Done()
@ -139,9 +144,8 @@ func (c *Consumer) Scan(ctx context.Context, fn func(*kinesis.Record) bool) erro
} }
} }
c.logger.Println("exiting", shardID)
cancel() cancel()
}(*shard.ShardId) }(shardID)
} }
wg.Wait() wg.Wait()
@ -152,100 +156,34 @@ func (c *Consumer) Scan(ctx context.Context, fn func(*kinesis.Record) bool) erro
// ScanShard loops over records on a specific shard, calls the callback func // ScanShard loops over records on a specific shard, calls the callback func
// for each record and checkpoints the progress of scan. // for each record and checkpoints the progress of scan.
// Note: Returning `false` from the callback func will end the scan. // Note: Returning `false` from the callback func will end the scan.
func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn func(*kinesis.Record) bool) error { func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn func(*Record) bool) (err error) {
c.logger.Println("scanning", shardID)
lastSeqNum, err := c.checkpoint.Get(c.streamName, shardID) lastSeqNum, err := c.checkpoint.Get(c.streamName, shardID)
if err != nil { if err != nil {
return fmt.Errorf("get checkpoint error: %v", err) return fmt.Errorf("get checkpoint error: %v", err)
} }
shardIterator, err := c.getShardIterator(shardID, lastSeqNum) c.logger.Println("scanning", shardID, lastSeqNum)
// get records
recc, errc, err := c.client.GetRecords(ctx, c.streamName, shardID, lastSeqNum)
if err != nil { if err != nil {
return fmt.Errorf("get shard iterator error: %v", err) return fmt.Errorf("get records error: %v", err)
} }
for { // loop records
select { for r := range recc {
case <-ctx.Done(): if ok := fn(r); !ok {
return nil break
default: }
resp, err := c.client.GetRecords(
&kinesis.GetRecordsInput{
ShardIterator: shardIterator,
},
)
if err != nil { c.counter.Add("records", 1)
shardIterator, err = c.getShardIterator(shardID, lastSeqNum)
if err != nil {
return fmt.Errorf("get shard iterator error: %v", err)
}
continue
}
if len(resp.Records) > 0 { err := c.checkpoint.Set(c.streamName, shardID, *r.SequenceNumber)
for _, r := range resp.Records { if err != nil {
select { return fmt.Errorf("set checkpoint error: %v", err)
case <-ctx.Done():
return nil
default:
lastSeqNum = *r.SequenceNumber
c.counter.Add("records", 1)
if ok := fn(r); !ok {
if err := c.setCheckpoint(shardID, lastSeqNum); err != nil {
return fmt.Errorf("set checkpoint error: %v", err)
}
return nil
}
}
}
if err := c.setCheckpoint(shardID, lastSeqNum); err != nil {
return fmt.Errorf("set checkpoint error: %v", err)
}
}
if resp.NextShardIterator == nil || shardIterator == resp.NextShardIterator {
shardIterator, err = c.getShardIterator(shardID, lastSeqNum)
if err != nil {
return fmt.Errorf("get shard iterator error: %v", err)
}
} else {
shardIterator = resp.NextShardIterator
}
} }
} }
}
c.logger.Println("exiting", shardID)
func (c *Consumer) setCheckpoint(shardID, lastSeqNum string) error { return <-errc
err := c.checkpoint.Set(c.streamName, shardID, lastSeqNum)
if err != nil {
return err
}
c.logger.Println("checkpoint", shardID)
c.counter.Add("checkpoints", 1)
return nil
}
func (c *Consumer) getShardIterator(shardID, lastSeqNum string) (*string, error) {
params := &kinesis.GetShardIteratorInput{
ShardId: aws.String(shardID),
StreamName: aws.String(c.streamName),
}
if lastSeqNum != "" {
params.ShardIteratorType = aws.String("AFTER_SEQUENCE_NUMBER")
params.StartingSequenceNumber = aws.String(lastSeqNum)
} else {
params.ShardIteratorType = aws.String("TRIM_HORIZON")
}
resp, err := c.client.GetShardIterator(params)
if err != nil {
return nil, err
}
return resp.ShardIterator, nil
} }

135
consumer_test.go Normal file
View file

@ -0,0 +1,135 @@
package consumer
import (
"context"
"fmt"
"io/ioutil"
"log"
"sync"
"testing"
"github.com/aws/aws-sdk-go/aws"
)
// TestNew checks that a consumer can be constructed from just a stream name.
func TestNew(t *testing.T) {
	if _, err := New("myStreamName"); err != nil {
		t.Fatalf("new consumer error: %v", err)
	}
}
// TestScanShard drives ScanShard with a fake client that yields two records
// and verifies that each record is counted, that the checkpoint ends up at
// the last record's sequence number, and that the callback saw both records
// in order.
func TestScanShard(t *testing.T) {
	var (
		ckp    = &fakeCheckpoint{cache: map[string]string{}}
		ctr    = &fakeCounter{}
		client = newFakeClient(
			&Record{
				Data:           []byte("firstData"),
				SequenceNumber: aws.String("firstSeqNum"),
			},
			&Record{
				Data:           []byte("lastData"),
				SequenceNumber: aws.String("lastSeqNum"),
			},
		)
	)

	c := &Consumer{
		streamName: "myStreamName",
		client:     client,
		checkpoint: ckp,
		counter:    ctr,
		logger:     log.New(ioutil.Discard, "", log.LstdFlags),
	}

	// callback fn simply appends the record data to result string
	var (
		resultData string
		fn         = func(r *Record) bool {
			resultData += string(r.Data)
			return true
		}
	)

	// scan shard
	err := c.ScanShard(context.Background(), "myShard", fn)
	if err != nil {
		t.Fatalf("scan shard error: %v", err)
	}

	// increments counter
	if val := ctr.counter; val != 2 {
		t.Fatalf("counter error expected %d, got %d", 2, val)
	}

	// sets checkpoint; the lookup error and the stored value are checked
	// separately so a wrong value fails even when the lookup succeeds
	val, err := ckp.Get("myStreamName", "myShard")
	if err != nil {
		t.Fatalf("get checkpoint error: %v", err)
	}
	if val != "lastSeqNum" {
		t.Fatalf("checkpoint error expected %s, got %s", "lastSeqNum", val)
	}

	// calls callback func
	if resultData != "firstDatalastData" {
		t.Fatalf("callback error expected %s, got %s", "firstDatalastData", resultData)
	}
}
// newFakeClient returns a fakeClient whose record channel is pre-loaded with
// rs and already closed, and whose error channel is closed without any error
// having been sent.
func newFakeClient(rs ...*Record) *fakeClient {
	// buffer sized to hold every record so the sends below never block
	records := make(chan *Record, len(rs))
	for i := range rs {
		records <- rs[i]
	}
	close(records)

	errs := make(chan error)
	close(errs)

	return &fakeClient{
		recc: records,
		errc: errs,
	}
}
// fakeClient is a canned test double exposing the same GetShardIDs and
// GetRecords methods as the real client, without contacting Kinesis.
type fakeClient struct {
	shardIDs []string     // returned verbatim by GetShardIDs
	recc     chan *Record // returned verbatim by GetRecords
	errc     chan error   // returned verbatim by GetRecords
}

// GetShardIDs returns the configured shard IDs; the stream name is ignored.
func (fc *fakeClient) GetShardIDs(string) ([]string, error) {
	return fc.shardIDs, nil
}

// GetRecords returns the fake's record and error channels. All arguments are
// ignored and the returned error is always nil.
func (fc *fakeClient) GetRecords(ctx context.Context, streamName, shardID, lastSeqNum string) (<-chan *Record, <-chan error, error) {
	return fc.recc, fc.errc, nil
}
// fakeCheckpoint is an in-memory checkpoint store for tests. Entries are kept
// in cache under a "<stream>-<shard>" key and every access holds mu.
type fakeCheckpoint struct {
	cache map[string]string
	mu    sync.Mutex
}

// fakeCheckpointKey builds the cache key for a stream/shard pair.
func fakeCheckpointKey(streamName, shardID string) string {
	return fmt.Sprintf("%s-%s", streamName, shardID)
}

// Set stores sequenceNumber for the stream/shard pair. It always returns nil.
func (fc *fakeCheckpoint) Set(streamName, shardID, sequenceNumber string) error {
	k := fakeCheckpointKey(streamName, shardID)

	fc.mu.Lock()
	defer fc.mu.Unlock()

	fc.cache[k] = sequenceNumber
	return nil
}

// Get returns the sequence number stored for the stream/shard pair, or the
// empty string when nothing has been stored. It always returns a nil error.
func (fc *fakeCheckpoint) Get(streamName, shardID string) (string, error) {
	k := fakeCheckpointKey(streamName, shardID)

	fc.mu.Lock()
	defer fc.mu.Unlock()

	seqNum := fc.cache[k]
	return seqNum, nil
}
// fakeCounter is a test counter that keeps one running total of everything
// passed to Add.
type fakeCounter struct {
	counter int64
}

// Add increases the running total by count. The name argument is ignored, so
// all metrics are summed into the same total.
func (fc *fakeCounter) Add(streamName string, count int64) {
	fc.counter = fc.counter + count
}