diff --git a/client.go b/client.go
new file mode 100644
index 0000000..f46760b
--- /dev/null
+++ b/client.go
@@ -0,0 +1,138 @@
+package consumer
+
+import (
+	"context"
+	"fmt"
+
+	"github.com/aws/aws-sdk-go/aws"
+	"github.com/aws/aws-sdk-go/aws/session"
+	"github.com/aws/aws-sdk-go/service/kinesis"
+)
+
+// NewClient returns a client wrapping a Kinesis service client that is
+// created with session.New(aws.NewConfig()).
+func NewClient() *client {
+	svc := kinesis.New(session.New(aws.NewConfig()))
+	return &client{svc}
+}
+
+// client acts as a wrapper around the Kinesis client.
+type client struct {
+	svc *kinesis.Kinesis
+}
+
+// GetShardIDs returns the ids of the shards listed in the DescribeStream
+// response for the given stream.
+func (c *client) GetShardIDs(streamName string) ([]string, error) {
+	resp, err := c.svc.DescribeStream(
+		&kinesis.DescribeStreamInput{
+			StreamName: aws.String(streamName),
+		},
+	)
+	if err != nil {
+		return nil, fmt.Errorf("describe stream error: %v", err)
+	}
+
+	ss := []string{}
+	for _, shard := range resp.StreamDescription.Shards {
+		ss = append(ss, *shard.ShardId)
+	}
+	return ss, nil
+}
+
+// GetRecords returns a chan Record from a Shard of the Stream.
+//
+// A goroutine reads the shard, starting after lastSeqNum (or at the trim
+// horizon when lastSeqNum is empty), and sends each record on the first
+// channel. It stops when ctx is done or when a new shard iterator cannot be
+// obtained; that error is sent on the second channel. Both channels are
+// closed when the goroutine exits. The returned error reports a failure to
+// get the initial shard iterator.
+func (c *client) GetRecords(ctx context.Context, streamName, shardID, lastSeqNum string) (<-chan *Record, <-chan error, error) {
+	shardIterator, err := c.getShardIterator(streamName, shardID, lastSeqNum)
+	if err != nil {
+		return nil, nil, fmt.Errorf("get shard iterator error: %v", err)
+	}
+
+	var (
+		recc = make(chan *Record, 10000)
+		errc = make(chan error, 1)
+	)
+
+	go func() {
+		defer func() {
+			close(recc)
+			close(errc)
+		}()
+
+		for {
+			select {
+			case <-ctx.Done():
+				return
+			default:
+				resp, err := c.svc.GetRecords(
+					&kinesis.GetRecordsInput{
+						ShardIterator: shardIterator,
+					},
+				)
+
+				if err != nil {
+					// the read failed: retry with a fresh iterator that
+					// resumes after the last record handed to recc
+					shardIterator, err = c.getShardIterator(streamName, shardID, lastSeqNum)
+					if err != nil {
+						errc <- fmt.Errorf("get shard iterator error: %v", err)
+						return
+					}
+					continue
+				}
+
+				for _, r := range resp.Records {
+					select {
+					case <-ctx.Done():
+						return
+					case recc <- r:
+						lastSeqNum = *r.SequenceNumber
+					}
+				}
+
+				// compare the iterator values, not the *string pointers:
+				// distinct pointers differ even when they hold the same value
+				if resp.NextShardIterator == nil || aws.StringValue(shardIterator) == aws.StringValue(resp.NextShardIterator) {
+					shardIterator, err = c.getShardIterator(streamName, shardID, lastSeqNum)
+					if err != nil {
+						errc <- fmt.Errorf("get shard iterator error: %v", err)
+						return
+					}
+				} else {
+					shardIterator = resp.NextShardIterator
+				}
+			}
+		}
+	}()
+
+	return recc, errc, nil
+}
+
+// getShardIterator requests a shard iterator positioned after lastSeqNum,
+// or at TRIM_HORIZON when lastSeqNum is empty.
+func (c *client) getShardIterator(streamName, shardID, lastSeqNum string) (*string, error) {
+	params := &kinesis.GetShardIteratorInput{
+		ShardId:    aws.String(shardID),
+		StreamName: aws.String(streamName),
+	}
+
+	if lastSeqNum != "" {
+		params.ShardIteratorType = aws.String("AFTER_SEQUENCE_NUMBER")
+		params.StartingSequenceNumber = aws.String(lastSeqNum)
+	} else {
+		params.ShardIteratorType = aws.String("TRIM_HORIZON")
+	}
+
+	resp, err := c.svc.GetShardIterator(params)
+	if err != nil {
+		return nil, err
+	}
+
+	return resp.ShardIterator, nil
+}
diff --git a/consumer.go b/consumer.go
index f662220..5c96424 100644
--- a/consumer.go
+++ b/consumer.go
@@ -7,13 +7,18 @@ import (
 	"log"
 	"sync"
 
-	"github.com/aws/aws-sdk-go/aws"
-	"github.com/aws/aws-sdk-go/aws/session"
 	"github.com/aws/aws-sdk-go/service/kinesis"
 )
 
+// Record is an alias of record returned from kinesis library
 type Record = kinesis.Record
 
+// Client interface is used for interacting with kinesis stream
+type Client interface {
+	GetShardIDs(string) ([]string, error)
+	GetRecords(ctx context.Context, streamName, shardID, lastSeqNum string) (<-chan *Record, <-chan error, error)
+}
+
 // Counter interface is used for exposing basic metrics from the scanner
 type Counter interface {
 	Add(string, int64)
@@ -61,39 +66,44 @@ func WithCounter(counter Counter) Option {
 	}
 }
 
+// WithClient overrides the default client
+func WithClient(client Client) Option {
+	return func(c *Consumer) error {
+		c.client = client
+		return nil
+	}
+}
+
 // New creates a kinesis consumer with default settings. Use Option to override
 // any of the optional attributes.
-func New(stream string, opts ...Option) (*Consumer, error) {
-	if stream == "" {
+func New(streamName string, opts ...Option) (*Consumer, error) {
+	if streamName == "" {
 		return nil, fmt.Errorf("must provide stream name")
 	}
 
+	// new consumer with no-op checkpoint, counter, and logger
 	c := &Consumer{
-		streamName: stream,
+		streamName: streamName,
 		checkpoint: &noopCheckpoint{},
 		counter:    &noopCounter{},
 		logger:     log.New(ioutil.Discard, "", log.LstdFlags),
+		client:     NewClient(),
 	}
 
-	// set options
+	// override defaults
 	for _, opt := range opts {
 		if err := opt(c); err != nil {
 			return nil, err
 		}
 	}
 
-	// provide a default kinesis client
-	if c.client == nil {
-		c.client = kinesis.New(session.New(aws.NewConfig()))
-	}
-
 	return c, nil
 }
 
 // Consumer wraps the interaction with the Kinesis stream
 type Consumer struct {
 	streamName string
-	client     *kinesis.Kinesis
+	client     Client
 	logger     *log.Logger
 	checkpoint Checkpoint
 	counter    Counter
@@ -101,32 +111,27 @@ type Consumer struct {
 
 // Scan scans each of the shards of the stream, calls the callback
 // func with each of the kinesis records.
-func (c *Consumer) Scan(ctx context.Context, fn func(*kinesis.Record) bool) error {
-	ctx, cancel := context.WithCancel(ctx)
-	defer cancel()
-
-	// grab the stream details
-	resp, err := c.client.DescribeStream(
-		&kinesis.DescribeStreamInput{
-			StreamName: aws.String(c.streamName),
-		},
-	)
+func (c *Consumer) Scan(ctx context.Context, fn func(*Record) bool) error {
+	shardIDs, err := c.client.GetShardIDs(c.streamName)
 	if err != nil {
-		return fmt.Errorf("describe stream error: %v", err)
+		return fmt.Errorf("get shards error: %v", err)
 	}
 
-	if len(resp.StreamDescription.Shards) == 0 {
+	if len(shardIDs) == 0 {
 		return fmt.Errorf("no shards available")
 	}
 
+	ctx, cancel := context.WithCancel(ctx)
+	defer cancel()
+
 	var (
 		wg   sync.WaitGroup
 		errc = make(chan error, 1)
 	)
-	wg.Add(len(resp.StreamDescription.Shards))
+	wg.Add(len(shardIDs))
 
-	// launch goroutine to process each of the shards
-	for _, shard := range resp.StreamDescription.Shards {
+	// process each shard in goroutine
+	for _, shardID := range shardIDs {
 		go func(shardID string) {
 			defer wg.Done()
 
@@ -139,9 +144,8 @@ func (c *Consumer) Scan(ctx context.Context, fn func(*kinesis.Record) bool) erro
 				}
 			}
 
-			c.logger.Println("exiting", shardID)
 			cancel()
-		}(*shard.ShardId)
+		}(shardID)
 	}
 
 	wg.Wait()
@@ -152,100 +156,44 @@ func (c *Consumer) Scan(ctx context.Context, fn func(*kinesis.Record) bool) erro
 // ScanShard loops over records on a specific shard, calls the callback func
 // for each record and checkpoints the progress of scan.
 // Note: Returning `false` from the callback func will end the scan.
-func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn func(*kinesis.Record) bool) error {
-	c.logger.Println("scanning", shardID)
-
+//
+// The context handed to the client is cancelled when ScanShard returns.
+func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn func(*Record) bool) (err error) {
 	lastSeqNum, err := c.checkpoint.Get(c.streamName, shardID)
 	if err != nil {
 		return fmt.Errorf("get checkpoint error: %v", err)
 	}
 
-	shardIterator, err := c.getShardIterator(shardID, lastSeqNum)
+	c.logger.Println("scanning", shardID, lastSeqNum)
+
+	// cancel the records stream on every exit path; the default client only
+	// stops producing and closes its channels once its context is done
+	ctx, cancel := context.WithCancel(ctx)
+	defer cancel()
+
+	// get records
+	recc, errc, err := c.client.GetRecords(ctx, c.streamName, shardID, lastSeqNum)
 	if err != nil {
-		return fmt.Errorf("get shard iterator error: %v", err)
+		return fmt.Errorf("get records error: %v", err)
 	}
 
-	for {
-		select {
-		case <-ctx.Done():
-			return nil
-		default:
-			resp, err := c.client.GetRecords(
-				&kinesis.GetRecordsInput{
-					ShardIterator: shardIterator,
-				},
-			)
+	// loop records
+	for r := range recc {
+		if ok := fn(r); !ok {
+			// stop the producer before waiting on errc below, otherwise
+			// the receive blocks until the caller's context is cancelled
+			cancel()
+			break
+		}
 
-			if err != nil {
-				shardIterator, err = c.getShardIterator(shardID, lastSeqNum)
-				if err != nil {
-					return fmt.Errorf("get shard iterator error: %v", err)
-				}
-				continue
-			}
+		c.counter.Add("records", 1)
 
-			if len(resp.Records) > 0 {
-				for _, r := range resp.Records {
-					select {
-					case <-ctx.Done():
-						return nil
-					default:
-						lastSeqNum = *r.SequenceNumber
-						c.counter.Add("records", 1)
-
-						if ok := fn(r); !ok {
-							if err := c.setCheckpoint(shardID, lastSeqNum); err != nil {
-								return fmt.Errorf("set checkpoint error: %v", err)
-							}
-							return nil
-						}
-					}
-				}
-
-				if err := c.setCheckpoint(shardID, lastSeqNum); err != nil {
-					return fmt.Errorf("set checkpoint error: %v", err)
-				}
-			}
-
-			if resp.NextShardIterator == nil || shardIterator == resp.NextShardIterator {
-				shardIterator, err = c.getShardIterator(shardID, lastSeqNum)
-				if err != nil {
-					return fmt.Errorf("get shard iterator error: %v", err)
-				}
-			} else {
-				shardIterator = resp.NextShardIterator
-			}
+		err := c.checkpoint.Set(c.streamName, shardID, *r.SequenceNumber)
+		if err != nil {
+			return fmt.Errorf("set checkpoint error: %v", err)
 		}
 	}
-}
-
-func (c *Consumer) setCheckpoint(shardID, lastSeqNum string) error {
-	err := c.checkpoint.Set(c.streamName, shardID, lastSeqNum)
-	if err != nil {
-		return err
-	}
-	c.logger.Println("checkpoint", shardID)
-	c.counter.Add("checkpoints", 1)
-	return nil
-}
-
-func (c *Consumer) getShardIterator(shardID, lastSeqNum string) (*string, error) {
-	params := &kinesis.GetShardIteratorInput{
-		ShardId:    aws.String(shardID),
-		StreamName: aws.String(c.streamName),
-	}
-
-	if lastSeqNum != "" {
-		params.ShardIteratorType = aws.String("AFTER_SEQUENCE_NUMBER")
-		params.StartingSequenceNumber = aws.String(lastSeqNum)
-	} else {
-		params.ShardIteratorType = aws.String("TRIM_HORIZON")
-	}
-
-	resp, err := c.client.GetShardIterator(params)
-	if err != nil {
-		return nil, err
-	}
-
-	return resp.ShardIterator, nil
+
+	c.logger.Println("exiting", shardID)
+	return <-errc
 }
diff --git a/consumer_test.go b/consumer_test.go
new file mode 100644
index 0000000..d3feb2c
--- /dev/null
+++ b/consumer_test.go
@@ -0,0 +1,143 @@
+package consumer
+
+import (
+	"context"
+	"fmt"
+	"io/ioutil"
+	"log"
+	"sync"
+	"testing"
+
+	"github.com/aws/aws-sdk-go/aws"
+)
+
+// TestNew verifies that a consumer can be created with only a stream name.
+func TestNew(t *testing.T) {
+	_, err := New("myStreamName")
+	if err != nil {
+		t.Fatalf("new consumer error: %v", err)
+	}
+}
+
+// TestScanShard feeds two records through a fake client and checks that
+// ScanShard counts them, checkpoints the last one and invokes the callback.
+func TestScanShard(t *testing.T) {
+	var (
+		ckp    = &fakeCheckpoint{cache: map[string]string{}}
+		ctr    = &fakeCounter{}
+		client = newFakeClient(
+			&Record{
+				Data:           []byte("firstData"),
+				SequenceNumber: aws.String("firstSeqNum"),
+			},
+			&Record{
+				Data:           []byte("lastData"),
+				SequenceNumber: aws.String("lastSeqNum"),
+			},
+		)
+	)
+
+	c := &Consumer{
+		streamName: "myStreamName",
+		client:     client,
+		checkpoint: ckp,
+		counter:    ctr,
+		logger:     log.New(ioutil.Discard, "", log.LstdFlags),
+	}
+
+	// callback fn simply appends the record data to result string
+	var (
+		resultData string
+		fn         = func(r *Record) bool {
+			resultData += string(r.Data)
+			return true
+		}
+	)
+
+	// scan shard
+	err := c.ScanShard(context.Background(), "myShard", fn)
+	if err != nil {
+		t.Fatalf("scan shard error: %v", err)
+	}
+
+	// increments counter
+	if val := ctr.counter; val != 2 {
+		t.Fatalf("counter error expected %d, got %d", 2, val)
+	}
+
+	// sets checkpoint
+	val, err := ckp.Get("myStreamName", "myShard")
+	if err != nil || val != "lastSeqNum" {
+		t.Fatalf("checkpoint error expected %s, got %s", "lastSeqNum", val)
+	}
+
+	// calls callback func
+	if resultData != "firstDatalastData" {
+		t.Fatalf("callback error expected %s, got %s", "firstDatalastData", resultData)
+	}
+}
+
+// newFakeClient returns a fakeClient whose record channel is pre-loaded with
+// rs and already closed, so ranging over it ends after the last record.
+func newFakeClient(rs ...*Record) *fakeClient {
+	fc := &fakeClient{
+		recc: make(chan *Record, len(rs)),
+		errc: make(chan error),
+	}
+
+	for _, r := range rs {
+		fc.recc <- r
+	}
+
+	close(fc.errc)
+	close(fc.recc)
+
+	return fc
+}
+
+// fakeClient is a Client that serves canned shard ids and channels.
+type fakeClient struct {
+	shardIDs []string
+	recc     chan *Record
+	errc     chan error
+}
+
+func (fc *fakeClient) GetShardIDs(string) ([]string, error) {
+	return fc.shardIDs, nil
+}
+
+func (fc *fakeClient) GetRecords(ctx context.Context, streamName, shardID, lastSeqNum string) (<-chan *Record, <-chan error, error) {
+	return fc.recc, fc.errc, nil
+}
+
+// fakeCheckpoint is an in-memory Checkpoint keyed by stream and shard.
+type fakeCheckpoint struct {
+	cache map[string]string
+	mu    sync.Mutex
+}
+
+func (fc *fakeCheckpoint) Set(streamName, shardID, sequenceNumber string) error {
+	fc.mu.Lock()
+	defer fc.mu.Unlock()
+
+	key := fmt.Sprintf("%s-%s", streamName, shardID)
+	fc.cache[key] = sequenceNumber
+	return nil
+}
+
+func (fc *fakeCheckpoint) Get(streamName, shardID string) (string, error) {
+	fc.mu.Lock()
+	defer fc.mu.Unlock()
+
+	key := fmt.Sprintf("%s-%s", streamName, shardID)
+	return fc.cache[key], nil
+}
+
+// fakeCounter is a Counter that sums every count it is given.
+type fakeCounter struct {
+	counter int64
+}
+
+func (fc *fakeCounter) Add(streamName string, count int64) {
+	fc.counter += count
+}