Introduce Client Interface
Testing the components of the consumer where proving difficult because the consumer code was so tightly coupled with the Kinesis client library. * Extract the concept of Client interface * Create default client w/ kinesis connection * Test with fake client to avoid round trip to kinesis
This commit is contained in:
parent
058f383e30
commit
b875bb56e7
3 changed files with 309 additions and 113 deletions
123
client.go
Normal file
123
client.go
Normal file
|
|
@ -0,0 +1,123 @@
|
||||||
|
package consumer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/aws/aws-sdk-go/aws"
|
||||||
|
"github.com/aws/aws-sdk-go/aws/session"
|
||||||
|
"github.com/aws/aws-sdk-go/service/kinesis"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewClient returns a new client with kinesis client
|
||||||
|
func NewClient() *client {
|
||||||
|
svc := kinesis.New(session.New(aws.NewConfig()))
|
||||||
|
return &client{svc}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Client acts as wrapper around Kinesis client
|
||||||
|
type client struct {
|
||||||
|
svc *kinesis.Kinesis
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetShardIDs returns shard ids in a given stream
|
||||||
|
func (c *client) GetShardIDs(streamName string) ([]string, error) {
|
||||||
|
resp, err := c.svc.DescribeStream(
|
||||||
|
&kinesis.DescribeStreamInput{
|
||||||
|
StreamName: aws.String(streamName),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("describe stream error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ss := []string{}
|
||||||
|
for _, shard := range resp.StreamDescription.Shards {
|
||||||
|
ss = append(ss, *shard.ShardId)
|
||||||
|
}
|
||||||
|
return ss, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRecords returns a chan Record from a Shard of the Stream
|
||||||
|
func (c *client) GetRecords(ctx context.Context, streamName, shardID, lastSeqNum string) (<-chan *Record, <-chan error, error) {
|
||||||
|
shardIterator, err := c.getShardIterator(streamName, shardID, lastSeqNum)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("get shard iterator error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
recc = make(chan *Record, 10000)
|
||||||
|
errc = make(chan error, 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer func() {
|
||||||
|
close(recc)
|
||||||
|
close(errc)
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
resp, err := c.svc.GetRecords(
|
||||||
|
&kinesis.GetRecordsInput{
|
||||||
|
ShardIterator: shardIterator,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
shardIterator, err = c.getShardIterator(streamName, shardID, lastSeqNum)
|
||||||
|
if err != nil {
|
||||||
|
errc <- fmt.Errorf("get shard iterator error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, r := range resp.Records {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case recc <- r:
|
||||||
|
lastSeqNum = *r.SequenceNumber
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.NextShardIterator == nil || shardIterator == resp.NextShardIterator {
|
||||||
|
shardIterator, err = c.getShardIterator(streamName, shardID, lastSeqNum)
|
||||||
|
if err != nil {
|
||||||
|
errc <- fmt.Errorf("get shard iterator error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
shardIterator = resp.NextShardIterator
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return recc, errc, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *client) getShardIterator(streamName, shardID, lastSeqNum string) (*string, error) {
|
||||||
|
params := &kinesis.GetShardIteratorInput{
|
||||||
|
ShardId: aws.String(shardID),
|
||||||
|
StreamName: aws.String(streamName),
|
||||||
|
}
|
||||||
|
|
||||||
|
if lastSeqNum != "" {
|
||||||
|
params.ShardIteratorType = aws.String("AFTER_SEQUENCE_NUMBER")
|
||||||
|
params.StartingSequenceNumber = aws.String(lastSeqNum)
|
||||||
|
} else {
|
||||||
|
params.ShardIteratorType = aws.String("TRIM_HORIZON")
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := c.svc.GetShardIterator(params)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp.ShardIterator, nil
|
||||||
|
}
|
||||||
164
consumer.go
164
consumer.go
|
|
@ -7,13 +7,18 @@ import (
|
||||||
"log"
|
"log"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/aws/aws-sdk-go/aws"
|
|
||||||
"github.com/aws/aws-sdk-go/aws/session"
|
|
||||||
"github.com/aws/aws-sdk-go/service/kinesis"
|
"github.com/aws/aws-sdk-go/service/kinesis"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Record is an alias of record returned from kinesis library
|
||||||
type Record = kinesis.Record
|
type Record = kinesis.Record
|
||||||
|
|
||||||
|
// Client interface is used for interacting with kinesis stream
|
||||||
|
type Client interface {
|
||||||
|
GetShardIDs(string) ([]string, error)
|
||||||
|
GetRecords(ctx context.Context, streamName, shardID, lastSeqNum string) (<-chan *Record, <-chan error, error)
|
||||||
|
}
|
||||||
|
|
||||||
// Counter interface is used for exposing basic metrics from the scanner
|
// Counter interface is used for exposing basic metrics from the scanner
|
||||||
type Counter interface {
|
type Counter interface {
|
||||||
Add(string, int64)
|
Add(string, int64)
|
||||||
|
|
@ -61,39 +66,44 @@ func WithCounter(counter Counter) Option {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithClient overrides the default client
|
||||||
|
func WithClient(client Client) Option {
|
||||||
|
return func(c *Consumer) error {
|
||||||
|
c.client = client
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// New creates a kinesis consumer with default settings. Use Option to override
|
// New creates a kinesis consumer with default settings. Use Option to override
|
||||||
// any of the optional attributes.
|
// any of the optional attributes.
|
||||||
func New(stream string, opts ...Option) (*Consumer, error) {
|
func New(streamName string, opts ...Option) (*Consumer, error) {
|
||||||
if stream == "" {
|
if streamName == "" {
|
||||||
return nil, fmt.Errorf("must provide stream name")
|
return nil, fmt.Errorf("must provide stream name")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// new consumer with no-op checkpoint, counter, and logger
|
||||||
c := &Consumer{
|
c := &Consumer{
|
||||||
streamName: stream,
|
streamName: streamName,
|
||||||
checkpoint: &noopCheckpoint{},
|
checkpoint: &noopCheckpoint{},
|
||||||
counter: &noopCounter{},
|
counter: &noopCounter{},
|
||||||
logger: log.New(ioutil.Discard, "", log.LstdFlags),
|
logger: log.New(ioutil.Discard, "", log.LstdFlags),
|
||||||
|
client: NewClient(),
|
||||||
}
|
}
|
||||||
|
|
||||||
// set options
|
// override defaults
|
||||||
for _, opt := range opts {
|
for _, opt := range opts {
|
||||||
if err := opt(c); err != nil {
|
if err := opt(c); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// provide a default kinesis client
|
|
||||||
if c.client == nil {
|
|
||||||
c.client = kinesis.New(session.New(aws.NewConfig()))
|
|
||||||
}
|
|
||||||
|
|
||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Consumer wraps the interaction with the Kinesis stream
|
// Consumer wraps the interaction with the Kinesis stream
|
||||||
type Consumer struct {
|
type Consumer struct {
|
||||||
streamName string
|
streamName string
|
||||||
client *kinesis.Kinesis
|
client Client
|
||||||
logger *log.Logger
|
logger *log.Logger
|
||||||
checkpoint Checkpoint
|
checkpoint Checkpoint
|
||||||
counter Counter
|
counter Counter
|
||||||
|
|
@ -101,32 +111,27 @@ type Consumer struct {
|
||||||
|
|
||||||
// Scan scans each of the shards of the stream, calls the callback
|
// Scan scans each of the shards of the stream, calls the callback
|
||||||
// func with each of the kinesis records.
|
// func with each of the kinesis records.
|
||||||
func (c *Consumer) Scan(ctx context.Context, fn func(*kinesis.Record) bool) error {
|
func (c *Consumer) Scan(ctx context.Context, fn func(*Record) bool) error {
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
shardIDs, err := c.client.GetShardIDs(c.streamName)
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
// grab the stream details
|
|
||||||
resp, err := c.client.DescribeStream(
|
|
||||||
&kinesis.DescribeStreamInput{
|
|
||||||
StreamName: aws.String(c.streamName),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("describe stream error: %v", err)
|
return fmt.Errorf("get shards error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(resp.StreamDescription.Shards) == 0 {
|
if len(shardIDs) == 0 {
|
||||||
return fmt.Errorf("no shards available")
|
return fmt.Errorf("no shards available")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
var (
|
var (
|
||||||
wg sync.WaitGroup
|
wg sync.WaitGroup
|
||||||
errc = make(chan error, 1)
|
errc = make(chan error, 1)
|
||||||
)
|
)
|
||||||
wg.Add(len(resp.StreamDescription.Shards))
|
wg.Add(len(shardIDs))
|
||||||
|
|
||||||
// launch goroutine to process each of the shards
|
// process each shard in goroutine
|
||||||
for _, shard := range resp.StreamDescription.Shards {
|
for _, shardID := range shardIDs {
|
||||||
go func(shardID string) {
|
go func(shardID string) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
|
|
||||||
|
|
@ -139,9 +144,8 @@ func (c *Consumer) Scan(ctx context.Context, fn func(*kinesis.Record) bool) erro
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
c.logger.Println("exiting", shardID)
|
|
||||||
cancel()
|
cancel()
|
||||||
}(*shard.ShardId)
|
}(shardID)
|
||||||
}
|
}
|
||||||
|
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
|
@ -152,100 +156,34 @@ func (c *Consumer) Scan(ctx context.Context, fn func(*kinesis.Record) bool) erro
|
||||||
// ScanShard loops over records on a specific shard, calls the callback func
|
// ScanShard loops over records on a specific shard, calls the callback func
|
||||||
// for each record and checkpoints the progress of scan.
|
// for each record and checkpoints the progress of scan.
|
||||||
// Note: Returning `false` from the callback func will end the scan.
|
// Note: Returning `false` from the callback func will end the scan.
|
||||||
func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn func(*kinesis.Record) bool) error {
|
func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn func(*Record) bool) (err error) {
|
||||||
c.logger.Println("scanning", shardID)
|
|
||||||
|
|
||||||
lastSeqNum, err := c.checkpoint.Get(c.streamName, shardID)
|
lastSeqNum, err := c.checkpoint.Get(c.streamName, shardID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("get checkpoint error: %v", err)
|
return fmt.Errorf("get checkpoint error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
shardIterator, err := c.getShardIterator(shardID, lastSeqNum)
|
c.logger.Println("scanning", shardID, lastSeqNum)
|
||||||
|
|
||||||
|
// get records
|
||||||
|
recc, errc, err := c.client.GetRecords(ctx, c.streamName, shardID, lastSeqNum)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("get shard iterator error: %v", err)
|
return fmt.Errorf("get records error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for {
|
// loop records
|
||||||
select {
|
for r := range recc {
|
||||||
case <-ctx.Done():
|
if ok := fn(r); !ok {
|
||||||
return nil
|
break
|
||||||
default:
|
}
|
||||||
resp, err := c.client.GetRecords(
|
|
||||||
&kinesis.GetRecordsInput{
|
|
||||||
ShardIterator: shardIterator,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
if err != nil {
|
c.counter.Add("records", 1)
|
||||||
shardIterator, err = c.getShardIterator(shardID, lastSeqNum)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("get shard iterator error: %v", err)
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(resp.Records) > 0 {
|
err := c.checkpoint.Set(c.streamName, shardID, *r.SequenceNumber)
|
||||||
for _, r := range resp.Records {
|
if err != nil {
|
||||||
select {
|
return fmt.Errorf("set checkpoint error: %v", err)
|
||||||
case <-ctx.Done():
|
|
||||||
return nil
|
|
||||||
default:
|
|
||||||
lastSeqNum = *r.SequenceNumber
|
|
||||||
c.counter.Add("records", 1)
|
|
||||||
|
|
||||||
if ok := fn(r); !ok {
|
|
||||||
if err := c.setCheckpoint(shardID, lastSeqNum); err != nil {
|
|
||||||
return fmt.Errorf("set checkpoint error: %v", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := c.setCheckpoint(shardID, lastSeqNum); err != nil {
|
|
||||||
return fmt.Errorf("set checkpoint error: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.NextShardIterator == nil || shardIterator == resp.NextShardIterator {
|
|
||||||
shardIterator, err = c.getShardIterator(shardID, lastSeqNum)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("get shard iterator error: %v", err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
shardIterator = resp.NextShardIterator
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
c.logger.Println("exiting", shardID)
|
||||||
func (c *Consumer) setCheckpoint(shardID, lastSeqNum string) error {
|
return <-errc
|
||||||
err := c.checkpoint.Set(c.streamName, shardID, lastSeqNum)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
c.logger.Println("checkpoint", shardID)
|
|
||||||
c.counter.Add("checkpoints", 1)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Consumer) getShardIterator(shardID, lastSeqNum string) (*string, error) {
|
|
||||||
params := &kinesis.GetShardIteratorInput{
|
|
||||||
ShardId: aws.String(shardID),
|
|
||||||
StreamName: aws.String(c.streamName),
|
|
||||||
}
|
|
||||||
|
|
||||||
if lastSeqNum != "" {
|
|
||||||
params.ShardIteratorType = aws.String("AFTER_SEQUENCE_NUMBER")
|
|
||||||
params.StartingSequenceNumber = aws.String(lastSeqNum)
|
|
||||||
} else {
|
|
||||||
params.ShardIteratorType = aws.String("TRIM_HORIZON")
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := c.client.GetShardIterator(params)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return resp.ShardIterator, nil
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
135
consumer_test.go
Normal file
135
consumer_test.go
Normal file
|
|
@ -0,0 +1,135 @@
|
||||||
|
package consumer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"log"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/aws/aws-sdk-go/aws"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNew(t *testing.T) {
|
||||||
|
_, err := New("myStreamName")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("new consumer error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestScanShard(t *testing.T) {
|
||||||
|
var (
|
||||||
|
ckp = &fakeCheckpoint{cache: map[string]string{}}
|
||||||
|
ctr = &fakeCounter{}
|
||||||
|
client = newFakeClient(
|
||||||
|
&Record{
|
||||||
|
Data: []byte("firstData"),
|
||||||
|
SequenceNumber: aws.String("firstSeqNum"),
|
||||||
|
},
|
||||||
|
&Record{
|
||||||
|
Data: []byte("lastData"),
|
||||||
|
SequenceNumber: aws.String("lastSeqNum"),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
c := &Consumer{
|
||||||
|
streamName: "myStreamName",
|
||||||
|
client: client,
|
||||||
|
checkpoint: ckp,
|
||||||
|
counter: ctr,
|
||||||
|
logger: log.New(ioutil.Discard, "", log.LstdFlags),
|
||||||
|
}
|
||||||
|
|
||||||
|
// callback fn simply appends the record data to result string
|
||||||
|
var (
|
||||||
|
resultData string
|
||||||
|
fn = func(r *Record) bool {
|
||||||
|
resultData += string(r.Data)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
// scan shard
|
||||||
|
err := c.ScanShard(context.Background(), "myShard", fn)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("scan shard error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// increments counter
|
||||||
|
if val := ctr.counter; val != 2 {
|
||||||
|
t.Fatalf("counter error expected %d, got %d", 2, val)
|
||||||
|
}
|
||||||
|
|
||||||
|
// sets checkpoint
|
||||||
|
val, err := ckp.Get("myStreamName", "myShard")
|
||||||
|
if err != nil && val != "lastSeqNum" {
|
||||||
|
t.Fatalf("checkout error expected %s, got %s", "lastSeqNum", val)
|
||||||
|
}
|
||||||
|
|
||||||
|
// calls callback func
|
||||||
|
if resultData != "firstDatalastData" {
|
||||||
|
t.Fatalf("callback error expected %s, got %s", "firstDatalastData", val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFakeClient(rs ...*Record) *fakeClient {
|
||||||
|
fc := &fakeClient{
|
||||||
|
recc: make(chan *Record, len(rs)),
|
||||||
|
errc: make(chan error),
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, r := range rs {
|
||||||
|
fc.recc <- r
|
||||||
|
}
|
||||||
|
|
||||||
|
close(fc.errc)
|
||||||
|
close(fc.recc)
|
||||||
|
|
||||||
|
return fc
|
||||||
|
}
|
||||||
|
|
||||||
|
type fakeClient struct {
|
||||||
|
shardIDs []string
|
||||||
|
recc chan *Record
|
||||||
|
errc chan error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fc *fakeClient) GetShardIDs(string) ([]string, error) {
|
||||||
|
return fc.shardIDs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fc *fakeClient) GetRecords(ctx context.Context, streamName, shardID, lastSeqNum string) (<-chan *Record, <-chan error, error) {
|
||||||
|
return fc.recc, fc.errc, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type fakeCheckpoint struct {
|
||||||
|
cache map[string]string
|
||||||
|
mu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fc *fakeCheckpoint) Set(streamName, shardID, sequenceNumber string) error {
|
||||||
|
fc.mu.Lock()
|
||||||
|
defer fc.mu.Unlock()
|
||||||
|
|
||||||
|
key := fmt.Sprintf("%s-%s", streamName, shardID)
|
||||||
|
fc.cache[key] = sequenceNumber
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fc *fakeCheckpoint) Get(streamName, shardID string) (string, error) {
|
||||||
|
fc.mu.Lock()
|
||||||
|
defer fc.mu.Unlock()
|
||||||
|
|
||||||
|
key := fmt.Sprintf("%s-%s", streamName, shardID)
|
||||||
|
return fc.cache[key], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type fakeCounter struct {
|
||||||
|
counter int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fc *fakeCounter) Add(streamName string, count int64) {
|
||||||
|
fc.counter += count
|
||||||
|
}
|
||||||
Loading…
Reference in a new issue