Remove the client wrapper (#58)

Having an additional Client has added some confusion (https://github.com/harlow/kinesis-consumer/issues/45) about how to provide a
custom Kinesis client. Allowing `WithClient` to accept a Kinesis client
directly cleans up the interface.

Major changes:

* Remove the Client wrapper; prefer using kinesis client directly
* Change `ScanError` to `ScanStatus` as the return value isn't necessarily an error

Note: these are breaking changes. If you need the last stable release, please see here: https://github.com/harlow/kinesis-consumer/releases/tag/v0.2.0
This commit is contained in:
Harlow Ward 2018-07-28 22:53:33 -07:00 committed by GitHub
parent 049445e259
commit fb98fbe244
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 358 additions and 527 deletions

View file

@ -4,25 +4,26 @@ All notable changes to this project will be documented in this file.
## [Unreleased (`master`)][unreleased] ## [Unreleased (`master`)][unreleased]
** Breaking changes to consumer library **
Major changes: Major changes:
* Use [functional options][options] for config * Remove concept of `Client` it was confusing as it wasn't a direct standin for a Kinesis client.
* Remove intermediate batching of kinesis records * Rename `ScanError` to `ScanStatus` as it's not always an error.
* Call the callback func with each record
* Use dep for vendoring dependencies
* Add DDB as storage layer for checkpoints
Minor changes: Minor changes:
* remove unused buffer and emitter code * Update tests to use Kinesis mock
[unreleased]: https://github.com/harlow/kinesis-consumer/compare/v0.1.0...HEAD ## v0.2.0 - 2018-07-28
[options]: https://dave.cheney.net/2014/10/17/functional-options-for-friendly-apis
This is the last stable release that includes a separate Client. The separate Client has caused confusion and will be removed going forward.
https://github.com/harlow/kinesis-consumer/releases/tag/v0.2.0
## v0.1.0 - 2017-11-20 ## v0.1.0 - 2017-11-20
This is the last stable release of the consumer which aggregated records in `batch` before calling the callback func. This is the last stable release of the consumer which aggregated records in `batch` before calling the callback func.
https://github.com/harlow/kinesis-consumer/releases/tag/v0.1.0 https://github.com/harlow/kinesis-consumer/releases/tag/v0.1.0
[unreleased]: https://github.com/harlow/kinesis-consumer/compare/v0.2.0...HEAD
[options]: https://dave.cheney.net/2014/10/17/functional-options-for-friendly-apis

View file

@ -37,14 +37,14 @@ func main() {
log.Fatalf("consumer error: %v", err) log.Fatalf("consumer error: %v", err)
} }
// start // start scan
err = c.Scan(context.TODO(), func(r *consumer.Record) consumer.ScanError { err = c.Scan(context.TODO(), func(r *consumer.Record) consumer.ScanStatus {
fmt.Println(string(r.Data)) fmt.Println(string(r.Data))
// continue scanning
return consumer.ScanError{ return consumer.ScanStatus{
StopScan: false, // true to stop scan StopScan: false, // true to stop scan
SkipCheckpoint: false, // true to skip checkpoint SkipCheckpoint: false, // true to skip checkpoint
} }
}) })
if err != nil { if err != nil {
log.Fatalf("scan error: %v", err) log.Fatalf("scan error: %v", err)
@ -55,6 +55,24 @@ func main() {
} }
``` ```
## Scan status
The scan func returns a `consumer.ScanStatus`; the struct allows some basic flow control.
```go
// continue scanning
return consumer.ScanStatus{}
// continue scanning, skip saving checkpoint
return consumer.ScanStatus{SkipCheckpoint: true}
// stop scanning, return nil
return consumer.ScanStatus{StopScan: true}
// stop scanning, return error
return consumer.ScanStatus{Error: err}
```
## Checkpoint ## Checkpoint
To record the progress of the consumer in the stream we use a checkpoint to store the last sequence number the consumer has read from a particular shard. The boolean value SkipCheckpoint of consumer.ScanError determines if checkpoint will be activated. ScanError is returned by the record processing callback. To record the progress of the consumer in the stream we use a checkpoint to store the last sequence number the consumer has read from a particular shard. The boolean value SkipCheckpoint of consumer.ScanError determines if checkpoint will be activated. ScanError is returned by the record processing callback.
@ -107,8 +125,9 @@ myDynamoDbClient := dynamodb.New(session.New(aws.NewConfig()))
ck, err := checkpoint.New(*app, *table, checkpoint.WithDynamoClient(myDynamoDbClient)) ck, err := checkpoint.New(*app, *table, checkpoint.WithDynamoClient(myDynamoDbClient))
if err != nil { if err != nil {
log.Fatalf("new checkpoint error: %v", err) log.Fatalf("new checkpoint error: %v", err)
} }
// Or we can provide your own Retryer to customize what triggers a retry inside checkpoint // Or we can provide your own Retryer to customize what triggers a retry inside checkpoint
// See code in examples // See code in examples
// ck, err := checkpoint.New(*app, *table, checkpoint.WithDynamoClient(myDynamoDbClient), checkpoint.WithRetryer(&MyRetryer{})) // ck, err := checkpoint.New(*app, *table, checkpoint.WithDynamoClient(myDynamoDbClient), checkpoint.WithRetryer(&MyRetryer{}))
@ -133,7 +152,7 @@ import checkpoint "github.com/harlow/kinesis-consumer/checkpoint/postgres"
// postgres checkpoint // postgres checkpoint
ck, err := checkpoint.New(app, table, connStr) ck, err := checkpoint.New(app, table, connStr)
if err != nil { if err != nil {
log.Fatalf("new checkpoint error: %v", err) log.Fatalf("new checkpoint error: %v", err)
} }
``` ```
@ -155,7 +174,7 @@ The table name has to be the same that you specify when creating the checkpoint.
The consumer allows the following optional overrides. The consumer allows the following optional overrides.
### Client ### Kinesis Client
Override the Kinesis client if there is any special config needed: Override the Kinesis client if there is any special config needed:
@ -189,6 +208,7 @@ The [expvar package](https://golang.org/pkg/expvar/) will display consumer count
``` ```
### Logging ### Logging
Logging supports the basic built-in logging library or a third party external one, so long as Logging supports the basic built-in logging library or a third party external one, so long as
it implements the Logger interface. it implements the Logger interface.
@ -197,12 +217,12 @@ For example, to use the builtin logging package, we wrap it with myLogger struct
``` ```
// A myLogger provides a minimalistic logger satisfying the Logger interface. // A myLogger provides a minimalistic logger satisfying the Logger interface.
type myLogger struct { type myLogger struct {
logger *log.Logger logger *log.Logger
} }
// Log logs the parameters to the stdlib logger. See log.Println. // Log logs the parameters to the stdlib logger. See log.Println.
func (l *myLogger) Log(args ...interface{}) { func (l *myLogger) Log(args ...interface{}) {
l.logger.Println(args...) l.logger.Println(args...)
} }
``` ```
@ -210,29 +230,32 @@ The package defaults to `ioutil.Discard` so swallow all logs. This can be custom
```go ```go
// logger // logger
log := &myLogger{ logger : log.New(os.Stdout, "consumer-example: ", log.LstdFlags),} log := &myLogger{
logger: log.New(os.Stdout, "consumer-example: ", log.LstdFlags)
}
// consumer // consumer
c, err := consumer.New(streamName, consumer.WithLogger(logger)) c, err := consumer.New(streamName, consumer.WithLogger(logger))
``` ```
To use a more complicated logging library, e.g. apex log To use a more complicated logging library, e.g. apex log
``` ```
type myLogger struct { type myLogger struct {
logger *log.Logger logger *log.Logger
} }
func (l *myLogger) Log(args ...interface{}) { func (l *myLogger) Log(args ...interface{}) {
l.logger.Infof("producer", args...) l.logger.Infof("producer", args...)
} }
func main() { func main() {
log := &myLogger{
log := &myLogger{ logger: alog.Logger{
logger: alog.Logger{ Handler: text.New(os.Stderr),
Handler: text.New(os.Stderr), Level: alog.DebugLevel,
Level: alog.DebugLevel, },
}, }
}
``` ```
## Contributing ## Contributing

12
checkpoint.go Normal file
View file

@ -0,0 +1,12 @@
package consumer
// Checkpoint interface is used to track consumer progress in the stream.
// Implementations persist, per stream and shard, the sequence number of the
// last record that was processed.
type Checkpoint interface {
// Get returns the sequence number last stored for the given stream and
// shard. The consumer resumes after that sequence number; an empty string
// makes it start from the oldest available record.
Get(streamName, shardID string) (string, error)
// Set stores sequenceNumber as the latest progress for the given stream
// and shard.
Set(streamName, shardID, sequenceNumber string) error
}
// noopCheckpoint is a Checkpoint implementation that persists nothing.
type noopCheckpoint struct{}

// Set ignores its arguments and always succeeds.
func (noopCheckpoint) Set(_, _, _ string) error {
	return nil
}

// Get ignores its arguments and always reports an empty sequence number
// with no error.
func (noopCheckpoint) Get(_, _ string) (string, error) {
	return "", nil
}

152
client.go
View file

@ -1,152 +0,0 @@
package consumer
import (
"context"
"fmt"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/kinesis"
"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
)
// ClientOption is a function that overrides a default setting on a
// KinesisClient. NewKinesisClient calls each option with the client being
// built.
type ClientOption func(*KinesisClient)
// WithKinesis returns a ClientOption that makes the KinesisClient use svc
// as its underlying Kinesis service instead of the default one.
func WithKinesis(svc kinesisiface.KinesisAPI) ClientOption {
	useService := func(client *KinesisClient) {
		client.svc = svc
	}
	return useService
}
// WithStartFromLatest returns a ClientOption that makes the client start
// consuming from the most recent event in Kinesis. While this option is
// set, the last checkpointed sequence number is not used to position the
// shard iterator.
func WithStartFromLatest() ClientOption {
	startFromLatest := func(client *KinesisClient) {
		client.fromLatest = true
	}
	return startFromLatest
}
// NewKinesisClient returns a client to interface with a Kinesis stream.
//
// The options are applied in the order given. If none of them supplied a
// Kinesis service (see WithKinesis), a default service is built from a new
// AWS session. The session is created only in that case, so a caller that
// injects its own service is never failed by default session setup.
//
// An error is returned only when the default AWS session cannot be created.
func NewKinesisClient(opts ...ClientOption) (*KinesisClient, error) {
	kc := &KinesisClient{}

	for _, opt := range opts {
		opt(kc)
	}

	// A service was injected; no default session is needed.
	if kc.svc != nil {
		return kc, nil
	}

	newSession, err := session.NewSession(aws.NewConfig())
	if err != nil {
		return nil, err
	}
	kc.svc = kinesis.New(newSession)

	return kc, nil
}
// KinesisClient acts as a wrapper around a Kinesis service client.
type KinesisClient struct {
// svc is the Kinesis API used for DescribeStream, GetShardIterator and
// GetRecords calls.
svc kinesisiface.KinesisAPI
// fromLatest, when true, makes getShardIterator request a LATEST iterator
// regardless of the last sequence number passed in.
fromLatest bool
}
// GetShardIDs returns the IDs of all shards in the given stream.
//
// DescribeStream returns shards in pages; this follows HasMoreShards and
// requests the next page starting after the last shard seen, so streams
// with more shards than fit in one response are listed completely.
//
// It returns an error if any DescribeStream call fails or if a response
// carries no stream description.
func (c *KinesisClient) GetShardIDs(streamName string) ([]string, error) {
	var (
		ss          []string
		lastShardID *string // nil on the first request: start from the first shard
	)

	for {
		resp, err := c.svc.DescribeStream(
			&kinesis.DescribeStreamInput{
				StreamName:            aws.String(streamName),
				ExclusiveStartShardId: lastShardID,
			},
		)
		if err != nil {
			return nil, fmt.Errorf("describe stream error: %v", err)
		}

		desc := resp.StreamDescription
		if desc == nil {
			return nil, fmt.Errorf("describe stream error: missing stream description")
		}

		for _, shard := range desc.Shards {
			// Skip malformed entries instead of panicking on a nil dereference.
			if shard == nil || shard.ShardId == nil {
				continue
			}
			ss = append(ss, *shard.ShardId)
			lastShardID = shard.ShardId
		}

		// Stop when the service reports no further pages, or when a page was
		// empty (no progress is possible, so avoid looping forever).
		if !aws.BoolValue(desc.HasMoreShards) || len(desc.Shards) == 0 {
			return ss, nil
		}
	}
}
// GetRecords streams the records of one shard of the stream.
//
// It obtains a shard iterator positioned by lastSeqNum (see
// getShardIterator) and starts a goroutine that repeatedly calls the Kinesis
// GetRecords API, pushing every record onto the returned record channel.
//
// The goroutine ends, closing both channels, when:
//   - ctx is cancelled (no error is sent),
//   - a replacement shard iterator cannot be obtained after a failed
//     GetRecords call (the error is sent on the error channel),
//   - the shard is closed, i.e. the response has no next shard iterator
//     (an error is sent on the error channel),
//   - the next shard iterator equals the current one (an error is sent).
//
// The third return value is non-nil only when the initial shard iterator
// cannot be obtained; in that case both channels are nil.
func (c *KinesisClient) GetRecords(ctx context.Context, streamName, shardID, lastSeqNum string) (<-chan *Record, <-chan error, error) {
	shardIterator, err := c.getShardIterator(streamName, shardID, lastSeqNum)
	if err != nil {
		return nil, nil, fmt.Errorf("get shard iterator error: %v", err)
	}

	var (
		recc = make(chan *Record, 10000)
		// Buffered so the single terminal error never blocks the goroutine.
		errc = make(chan error, 1)
	)

	go func() {
		defer func() {
			close(recc)
			close(errc)
		}()

		for {
			select {
			case <-ctx.Done():
				return
			default:
			}

			resp, err := c.svc.GetRecords(
				&kinesis.GetRecordsInput{
					ShardIterator: shardIterator,
				},
			)
			if err != nil {
				// The iterator may have expired; get a fresh one positioned
				// after the last record that was delivered.
				shardIterator, err = c.getShardIterator(streamName, shardID, lastSeqNum)
				if err != nil {
					errc <- fmt.Errorf("get shard iterator error: %v", err)
					return
				}
				continue
			}

			for _, r := range resp.Records {
				select {
				case <-ctx.Done():
					return
				case recc <- r:
					// Only advance the resume position when the record
					// actually carries a sequence number.
					if r != nil && r.SequenceNumber != nil {
						lastSeqNum = *r.SequenceNumber
					}
				}
			}

			next := resp.NextShardIterator
			if next == nil {
				// A missing next iterator means the shard has been closed.
				errc <- fmt.Errorf("shard closed: %s", shardID)
				return
			}
			// Compare iterator values, not pointer identity.
			if aws.StringValue(next) == aws.StringValue(shardIterator) {
				errc <- fmt.Errorf("shard iterator did not advance: %s", shardID)
				return
			}
			shardIterator = next
		}
	}()

	return recc, errc, nil
}
// getShardIterator asks Kinesis for an iterator on the given shard.
//
// The iterator type is chosen as follows:
//   - LATEST when the client was configured to start from the latest record,
//   - AFTER_SEQUENCE_NUMBER (starting at lastSeqNum) when lastSeqNum is set,
//   - TRIM_HORIZON otherwise.
//
// Errors from the GetShardIterator call are returned unchanged.
func (c *KinesisClient) getShardIterator(streamName, shardID, lastSeqNum string) (*string, error) {
	input := &kinesis.GetShardIteratorInput{
		ShardId:    aws.String(shardID),
		StreamName: aws.String(streamName),
	}

	switch {
	case c.fromLatest:
		input.ShardIteratorType = aws.String("LATEST")
	case lastSeqNum != "":
		input.ShardIteratorType = aws.String("AFTER_SEQUENCE_NUMBER")
		input.StartingSequenceNumber = aws.String(lastSeqNum)
	default:
		input.ShardIteratorType = aws.String("TRIM_HORIZON")
	}

	out, err := c.svc.GetShardIterator(input)
	if err != nil {
		return nil, err
	}

	return out.ShardIterator, nil
}

View file

@ -1,150 +0,0 @@
package consumer_test
import (
"testing"
"context"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/kinesis"
"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
"github.com/harlow/kinesis-consumer"
)
// TestKinesisClient_GetRecords_SuccessfullyRun checks that GetRecords, given
// a mock that hands out a shard iterator, returns non-nil record and error
// channels and no immediate error.
func TestKinesisClient_GetRecords_SuccessfullyRun(t *testing.T) {
	kinesisClient := &kinesisClientMock{
		getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
			return &kinesis.GetShardIteratorOutput{
				ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
			}, nil
		},
		getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
			return &kinesis.GetRecordsOutput{
				NextShardIterator: nil,
				Records:           make([]*kinesis.Record, 0),
			}, nil
		},
	}
	kinesisClientOpt := consumer.WithKinesis(kinesisClient)

	c, err := consumer.NewKinesisClient(kinesisClientOpt)
	if err != nil {
		t.Fatalf("New kinesis client error: %v", err)
	}

	// Deferred so the producer goroutine is released on every exit path.
	ctx, cancelFunc := context.WithCancel(context.Background())
	defer cancelFunc()

	recordsChan, errorsChan, err := c.GetRecords(ctx, "myStream", "shardId-000000000000", "")

	if recordsChan == nil {
		t.Errorf("records channel expected not nil, got %v", recordsChan)
	}
	if errorsChan == nil {
		// Report the channel under test, not the records channel.
		t.Errorf("errors channel expected not nil, got %v", errorsChan)
	}
	if err != nil {
		t.Errorf("error expected nil, got %v", err)
	}
}
// TestKinesisClient_GetRecords_SuccessfullyRetrievesThreeRecordsAtOnce checks
// that the three records returned by a single mocked GetRecords call arrive
// on the record channel in the same order and as the same pointers.
func TestKinesisClient_GetRecords_SuccessfullyRetrievesThreeRecordsAtOnce(t *testing.T) {
	expectedResults := []*kinesis.Record{
		{
			SequenceNumber: aws.String("49578481031144599192696750682534686652010819674221576195"),
		},
		{
			SequenceNumber: aws.String("49578481031144599192696750682534686652010819674221576196"),
		},
		{
			SequenceNumber: aws.String("49578481031144599192696750682534686652010819674221576197"),
		},
	}

	kinesisClient := &kinesisClientMock{
		getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
			return &kinesis.GetShardIteratorOutput{
				ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
			}, nil
		},
		getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
			return &kinesis.GetRecordsOutput{
				NextShardIterator: nil,
				Records:           expectedResults,
			}, nil
		},
	}
	kinesisClientOpt := consumer.WithKinesis(kinesisClient)

	c, err := consumer.NewKinesisClient(kinesisClientOpt)
	if err != nil {
		t.Fatalf("new kinesis client error: %v", err)
	}

	// Deferred so the producer goroutine is released even when a Fatalf
	// below ends the test early.
	ctx, cancelFunc := context.WithCancel(context.Background())
	defer cancelFunc()

	recordsChan, _, err := c.GetRecords(ctx, "TestStream", "shardId-000000000000", "")

	if recordsChan == nil {
		t.Fatalf("records channel expected not nil, got %v", recordsChan)
	}
	if err != nil {
		t.Fatalf("error expected nil, got %v", err)
	}

	var results []*consumer.Record
	results = append(results, <-recordsChan, <-recordsChan, <-recordsChan)

	if len(results) != 3 {
		t.Errorf("number of records expected 3, got %v", len(results))
	}
	for i, r := range results {
		if r != expectedResults[i] {
			t.Errorf("record expected %v, got %v", expectedResults[i], r)
		}
	}
}
// TestKinesisClient_GetRecords_ShardIsClosed checks that when the mocked
// GetRecords response carries no next shard iterator, an error is delivered
// on the error channel.
func TestKinesisClient_GetRecords_ShardIsClosed(t *testing.T) {
	kinesisClient := &kinesisClientMock{
		getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
			return &kinesis.GetShardIteratorOutput{
				ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
			}, nil
		},
		getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
			return &kinesis.GetRecordsOutput{
				NextShardIterator: nil,
				Records:           make([]*consumer.Record, 0),
			}, nil
		},
	}
	kinesisClientOpt := consumer.WithKinesis(kinesisClient)

	c, err := consumer.NewKinesisClient(kinesisClientOpt)
	if err != nil {
		t.Fatalf("new kinesis client error: %v", err)
	}

	// Deferred so the producer goroutine is released even when a Fatalf
	// below ends the test early.
	ctx, cancelFunc := context.WithCancel(context.Background())
	defer cancelFunc()

	_, errorsChan, err := c.GetRecords(ctx, "TestStream", "shardId-000000000000", "")

	if errorsChan == nil {
		t.Fatalf("errors channel expected not nil, got %v", errorsChan)
	}
	// The call itself must succeed; the failure is reported asynchronously.
	if err != nil {
		t.Fatalf("error expected nil, got %v", err)
	}

	err = <-errorsChan
	if err == nil {
		t.Errorf("error expected, got %v", err)
	}
}
// kinesisClientMock is a test double for the Kinesis API. It embeds
// kinesisiface.KinesisAPI to satisfy the interface and routes the two calls
// used by these tests to caller-supplied functions.
type kinesisClientMock struct {
	kinesisiface.KinesisAPI

	// getShardIteratorMock is invoked by GetShardIterator.
	getShardIteratorMock func(*kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error)
	// getRecordsMock is invoked by GetRecords.
	getRecordsMock func(*kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error)
}

// GetRecords forwards the request to getRecordsMock.
func (m *kinesisClientMock) GetRecords(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
	out, err := m.getRecordsMock(input)
	return out, err
}

// GetShardIterator forwards the request to getShardIteratorMock.
func (m *kinesisClientMock) GetShardIterator(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
	out, err := m.getShardIteratorMock(input)
	return out, err
}

View file

@ -3,83 +3,58 @@ package consumer
import ( import (
"context" "context"
"fmt" "fmt"
"io/ioutil"
"log"
"sync" "sync"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/kinesis" "github.com/aws/aws-sdk-go/service/kinesis"
"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
) )
// ScanError signals the consumer if we should continue scanning for next record
// and whether to checkpoint.
type ScanError struct {
Error error
StopScan bool
SkipCheckpoint bool
}
// Record is an alias of record returned from kinesis library // Record is an alias of record returned from kinesis library
type Record = kinesis.Record type Record = kinesis.Record
// Client interface is used for interacting with kinesis stream
type Client interface {
GetShardIDs(string) ([]string, error)
GetRecords(ctx context.Context, streamName, shardID, lastSeqNum string) (<-chan *Record, <-chan error, error)
}
// Counter interface is used for exposing basic metrics from the scanner
type Counter interface {
Add(string, int64)
}
type noopCounter struct{}
func (n noopCounter) Add(string, int64) {}
// Checkpoint interface used track consumer progress in the stream
type Checkpoint interface {
Get(streamName, shardID string) (string, error)
Set(streamName, shardID, sequenceNumber string) error
}
type noopCheckpoint struct{}
func (n noopCheckpoint) Set(string, string, string) error { return nil }
func (n noopCheckpoint) Get(string, string) (string, error) { return "", nil }
// Option is used to override defaults when creating a new Consumer // Option is used to override defaults when creating a new Consumer
type Option func(*Consumer) error type Option func(*Consumer)
// WithCheckpoint overrides the default checkpoint // WithCheckpoint overrides the default checkpoint
func WithCheckpoint(checkpoint Checkpoint) Option { func WithCheckpoint(checkpoint Checkpoint) Option {
return func(c *Consumer) error { return func(c *Consumer) {
c.checkpoint = checkpoint c.checkpoint = checkpoint
return nil
} }
} }
// WithLogger overrides the default logger // WithLogger overrides the default logger
func WithLogger(logger Logger) Option { func WithLogger(logger Logger) Option {
return func(c *Consumer) error { return func(c *Consumer) {
c.logger = logger c.logger = logger
return nil
} }
} }
// WithCounter overrides the default counter // WithCounter overrides the default counter
func WithCounter(counter Counter) Option { func WithCounter(counter Counter) Option {
return func(c *Consumer) error { return func(c *Consumer) {
c.counter = counter c.counter = counter
return nil
} }
} }
// WithClient overrides the default client // WithClient overrides the default client
func WithClient(client Client) Option { func WithClient(client kinesisiface.KinesisAPI) Option {
return func(c *Consumer) error { return func(c *Consumer) {
c.client = client c.client = client
return nil
} }
} }
// ScanStatus signals the consumer if we should continue scanning for next record
// and whether to checkpoint.
type ScanStatus struct {
Error error
StopScan bool
SkipCheckpoint bool
}
// New creates a kinesis consumer with default settings. Use Option to override // New creates a kinesis consumer with default settings. Use Option to override
// any of the optional attributes. // any of the optional attributes.
func New(streamName string, opts ...Option) (*Consumer, error) { func New(streamName string, opts ...Option) (*Consumer, error) {
@ -87,25 +62,24 @@ func New(streamName string, opts ...Option) (*Consumer, error) {
return nil, fmt.Errorf("must provide stream name") return nil, fmt.Errorf("must provide stream name")
} }
kc, err := NewKinesisClient()
if err != nil {
return nil, err
}
// new consumer with no-op checkpoint, counter, and logger // new consumer with no-op checkpoint, counter, and logger
c := &Consumer{ c := &Consumer{
streamName: streamName, streamName: streamName,
checkpoint: &noopCheckpoint{}, checkpoint: &noopCheckpoint{},
counter: &noopCounter{}, counter: &noopCounter{},
logger: NewDefaultLogger(), logger: &noopLogger{
client: kc, logger: log.New(ioutil.Discard, "", log.LstdFlags),
},
} }
// override defaults // override defaults
for _, opt := range opts { for _, opt := range opts {
if err := opt(c); err != nil { opt(c)
return nil, err }
}
// default client if none provided
if c.client == nil {
c.client = kinesis.New(session.New(aws.NewConfig()))
} }
return c, nil return c, nil
@ -114,7 +88,7 @@ func New(streamName string, opts ...Option) (*Consumer, error) {
// Consumer wraps the interaction with the Kinesis stream // Consumer wraps the interaction with the Kinesis stream
type Consumer struct { type Consumer struct {
streamName string streamName string
client Client client kinesisiface.KinesisAPI
logger Logger logger Logger
checkpoint Checkpoint checkpoint Checkpoint
counter Counter counter Counter
@ -122,8 +96,12 @@ type Consumer struct {
// Scan scans each of the shards of the stream, calls the callback // Scan scans each of the shards of the stream, calls the callback
// func with each of the kinesis records. // func with each of the kinesis records.
func (c *Consumer) Scan(ctx context.Context, fn func(*Record) ScanError) error { func (c *Consumer) Scan(ctx context.Context, fn func(*Record) ScanStatus) error {
shardIDs, err := c.client.GetShardIDs(c.streamName) ctx, cancel := context.WithCancel(ctx)
defer cancel()
// get shard ids
shardIDs, err := c.getShardIDs(c.streamName)
if err != nil { if err != nil {
return fmt.Errorf("get shards error: %v", err) return fmt.Errorf("get shards error: %v", err)
} }
@ -132,16 +110,13 @@ func (c *Consumer) Scan(ctx context.Context, fn func(*Record) ScanError) error {
return fmt.Errorf("no shards available") return fmt.Errorf("no shards available")
} }
ctx, cancel := context.WithCancel(ctx)
defer cancel()
var ( var (
wg sync.WaitGroup wg sync.WaitGroup
errc = make(chan error, 1) errc = make(chan error, 1)
) )
wg.Add(len(shardIDs)) wg.Add(len(shardIDs))
// process each shard in goroutine // process each shard in a separate goroutine
for _, shardID := range shardIDs { for _, shardID := range shardIDs {
go func(shardID string) { go func(shardID string) {
defer wg.Done() defer wg.Done()
@ -161,47 +136,113 @@ func (c *Consumer) Scan(ctx context.Context, fn func(*Record) ScanError) error {
wg.Wait() wg.Wait()
close(errc) close(errc)
return <-errc return <-errc
} }
// ScanShard loops over records on a specific shard, calls the callback func // ScanShard loops over records on a specific shard, calls the callback func
// for each record and checkpoints the progress of scan. // for each record and checkpoints the progress of scan.
// Note: Returning `false` from the callback func will end the scan. func (c *Consumer) ScanShard(
func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn func(*Record) ScanError) (err error) { ctx context.Context,
shardID string,
fn func(*Record) ScanStatus,
) error {
// get checkpoint
lastSeqNum, err := c.checkpoint.Get(c.streamName, shardID) lastSeqNum, err := c.checkpoint.Get(c.streamName, shardID)
if err != nil { if err != nil {
return fmt.Errorf("get checkpoint error: %v", err) return fmt.Errorf("get checkpoint error: %v", err)
} }
// get shard iterator
shardIterator, err := c.getShardIterator(c.streamName, shardID, lastSeqNum)
if err != nil {
return fmt.Errorf("get shard iterator error: %v", err)
}
c.logger.Log("scanning", shardID, lastSeqNum) c.logger.Log("scanning", shardID, lastSeqNum)
// get records // scan pages of shard
recc, errc, err := c.client.GetRecords(ctx, c.streamName, shardID, lastSeqNum) for {
if err != nil { select {
return fmt.Errorf("get records error: %v", err) case <-ctx.Done():
} return nil
// loop records default:
for r := range recc { resp, err := c.client.GetRecords(&kinesis.GetRecordsInput{
scanError := fn(r) ShardIterator: shardIterator,
})
// Skip invalid state
if scanError.StopScan && scanError.SkipCheckpoint {
continue
}
if scanError.StopScan {
break
}
if !scanError.SkipCheckpoint {
c.counter.Add("records", 1)
err := c.checkpoint.Set(c.streamName, shardID, *r.SequenceNumber)
if err != nil { if err != nil {
return fmt.Errorf("set checkpoint error: %v", err) shardIterator, err = c.getShardIterator(c.streamName, shardID, lastSeqNum)
if err != nil {
return fmt.Errorf("get shard iterator error: %v", err)
}
continue
} }
// loop records of page
for _, r := range resp.Records {
status := fn(r)
if !status.SkipCheckpoint {
lastSeqNum = *r.SequenceNumber
if err := c.checkpoint.Set(c.streamName, shardID, lastSeqNum); err != nil {
return err
}
}
if err := status.Error; err != nil {
return err
}
c.counter.Add("records", 1)
if status.StopScan {
return nil
}
}
if resp.NextShardIterator == nil || shardIterator == resp.NextShardIterator {
return nil
}
shardIterator = resp.NextShardIterator
} }
} }
}
c.logger.Log("exiting", shardID)
return <-errc func (c *Consumer) getShardIDs(streamName string) ([]string, error) {
resp, err := c.client.DescribeStream(
&kinesis.DescribeStreamInput{
StreamName: aws.String(streamName),
},
)
if err != nil {
return nil, fmt.Errorf("describe stream error: %v", err)
}
ss := []string{}
for _, shard := range resp.StreamDescription.Shards {
ss = append(ss, *shard.ShardId)
}
return ss, nil
}
func (c *Consumer) getShardIterator(streamName, shardID, lastSeqNum string) (*string, error) {
params := &kinesis.GetShardIteratorInput{
ShardId: aws.String(shardID),
StreamName: aws.String(streamName),
}
if lastSeqNum != "" {
params.ShardIteratorType = aws.String("AFTER_SEQUENCE_NUMBER")
params.StartingSequenceNumber = aws.String(lastSeqNum)
} else {
params.ShardIteratorType = aws.String("TRIM_HORIZON")
}
resp, err := c.client.GetShardIterator(params)
if err != nil {
return nil, err
}
return resp.ShardIterator, nil
} }

View file

@ -3,12 +3,14 @@ package consumer
import ( import (
"context" "context"
"fmt" "fmt"
"io/ioutil"
"log"
"sync" "sync"
"testing" "testing"
"errors"
"github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/kinesis"
"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
) )
func TestNew(t *testing.T) { func TestNew(t *testing.T) {
@ -20,41 +22,36 @@ func TestNew(t *testing.T) {
func TestScanShard(t *testing.T) { func TestScanShard(t *testing.T) {
var ( var (
ckp = &fakeCheckpoint{cache: map[string]string{}} resultData string
ctr = &fakeCounter{} ckp = &fakeCheckpoint{cache: map[string]string{}}
client = newFakeClient( ctr = &fakeCounter{}
&Record{ mockSvc = &mockKinesisClient{}
Data: []byte("firstData"), logger = &noopLogger{
SequenceNumber: aws.String("firstSeqNum"), logger: log.New(ioutil.Discard, "", log.LstdFlags),
}, }
&Record{
Data: []byte("lastData"),
SequenceNumber: aws.String("lastSeqNum"),
},
)
) )
c := &Consumer{ c := &Consumer{
streamName: "myStreamName", streamName: "myStreamName",
client: client, client: mockSvc,
checkpoint: ckp, checkpoint: ckp,
counter: ctr, counter: ctr,
logger: NewDefaultLogger(), logger: logger,
} }
var recordNum = 0
// callback fn simply appends the record data to result string // callback fn simply appends the record data to result string
var ( var fn = func(r *Record) ScanStatus {
resultData string resultData += string(r.Data)
fn = func(r *Record) ScanError { recordNum++
resultData += string(r.Data) stopScan := recordNum == 2
err := errors.New("some error happened")
return ScanError{ return ScanStatus{
Error: err, StopScan: stopScan,
StopScan: false, SkipCheckpoint: false,
SkipCheckpoint: false,
}
} }
) }
// scan shard // scan shard
err := c.ScanShard(context.Background(), "myShard", fn) err := c.ScanShard(context.Background(), "myShard", fn)
@ -79,34 +76,30 @@ func TestScanShard(t *testing.T) {
} }
} }
func newFakeClient(rs ...*Record) *fakeClient { type mockKinesisClient struct {
fc := &fakeClient{ kinesisiface.KinesisAPI
recc: make(chan *Record, len(rs)),
errc: make(chan error),
}
for _, r := range rs {
fc.recc <- r
}
close(fc.errc)
close(fc.recc)
return fc
} }
type fakeClient struct { func (m *mockKinesisClient) GetRecords(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
shardIDs []string
recc chan *Record return &kinesis.GetRecordsOutput{
errc chan error Records: []*kinesis.Record{
&kinesis.Record{
Data: []byte("firstData"),
SequenceNumber: aws.String("firstSeqNum"),
},
&kinesis.Record{
Data: []byte("lastData"),
SequenceNumber: aws.String("lastSeqNum"),
},
},
}, nil
} }
func (fc *fakeClient) GetShardIDs(string) ([]string, error) { func (m *mockKinesisClient) GetShardIterator(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
return fc.shardIDs, nil return &kinesis.GetShardIteratorOutput{
} ShardIterator: aws.String("myshard"),
}, nil
func (fc *fakeClient) GetRecords(ctx context.Context, streamName, shardID, lastSeqNum string) (<-chan *Record, <-chan error, error) {
return fc.recc, fc.errc, nil
} }
type fakeCheckpoint struct { type fakeCheckpoint struct {

10
counter.go Normal file
View file

@ -0,0 +1,10 @@
package consumer
// Counter interface is used for exposing basic metrics from the scanner.
type Counter interface {
// Add is called with a metric name and an increment; the scanner calls
// Add("records", 1) for processed records.
Add(string, int64)
}
// noopCounter is a Counter implementation that records nothing.
type noopCounter struct{}

// Add ignores its arguments and does nothing.
func (noopCounter) Add(_ string, _ int64) {
}

View file

@ -2,7 +2,6 @@ package main
import ( import (
"context" "context"
"errors"
"expvar" "expvar"
"flag" "flag"
"fmt" "fmt"
@ -29,7 +28,7 @@ import (
func init() { func init() {
sock, err := net.Listen("tcp", "localhost:8080") sock, err := net.Listen("tcp", "localhost:8080")
if err != nil { if err != nil {
log.Println("net listen error: %v", err) log.Printf("net listen error: %v", err)
} }
go func() { go func() {
fmt.Println("Metrics available at http://localhost:8080/debug/vars") fmt.Println("Metrics available at http://localhost:8080/debug/vars")
@ -66,24 +65,26 @@ func main() {
// Following will overwrite the default dynamodb client // Following will overwrite the default dynamodb client
// Older versions of aws sdk does not picking up aws config properly. // Older versions of aws sdk does not picking up aws config properly.
// You probably need to update aws sdk verison. Tested the following with 1.13.59 // You probably need to update aws sdk verison. Tested the following with 1.13.59
myDynamoDbClient := dynamodb.New(session.New(aws.NewConfig()), &aws.Config{ myDynamoDbClient := dynamodb.New(
Region: aws.String("us-west-2"), session.New(aws.NewConfig()), &aws.Config{
}) Region: aws.String("us-west-2"),
},
)
// ddb checkpoint // ddb checkpoint
ck, err := checkpoint.New(*app, *table, checkpoint.WithDynamoClient(myDynamoDbClient), checkpoint.WithRetryer(&MyRetryer{})) ck, err := checkpoint.New(*app, *table, checkpoint.WithDynamoClient(myDynamoDbClient), checkpoint.WithRetryer(&MyRetryer{}))
if err != nil { if err != nil {
log.Log("checkpoint error: %v", err) log.Log("checkpoint error: %v", err)
} }
var (
counter = expvar.NewMap("counters") var counter = expvar.NewMap("counters")
)
// The following 2 lines will overwrite the default kinesis client // The following 2 lines will overwrite the default kinesis client
myKinesisClient := kinesis.New(session.New(aws.NewConfig()), &aws.Config{ ksis := kinesis.New(
Region: aws.String("us-west-2"), session.New(aws.NewConfig()), &aws.Config{
}) Region: aws.String("us-west-2"),
newKclient := consumer.NewKinesisClient(consumer.WithKinesis(myKinesisClient)) },
)
// consumer // consumer
c, err := consumer.New( c, err := consumer.New(
@ -91,7 +92,7 @@ func main() {
consumer.WithCheckpoint(ck), consumer.WithCheckpoint(ck),
consumer.WithLogger(log), consumer.WithLogger(log),
consumer.WithCounter(counter), consumer.WithCounter(counter),
consumer.WithClient(newKclient), consumer.WithClient(ksis),
) )
if err != nil { if err != nil {
log.Log("consumer error: %v", err) log.Log("consumer error: %v", err)
@ -110,15 +111,11 @@ func main() {
}() }()
// scan stream // scan stream
err = c.Scan(ctx, func(r *consumer.Record) consumer.ScanError { err = c.Scan(ctx, func(r *consumer.Record) consumer.ScanStatus {
fmt.Println(string(r.Data)) fmt.Println(string(r.Data))
err := errors.New("some error happened")
// continue scanning // continue scanning
return consumer.ScanError{ return consumer.ScanStatus{}
Error: err,
StopScan: true,
SkipCheckpoint: false,
}
}) })
if err != nil { if err != nil {
log.Log("scan error: %v", err) log.Log("scan error: %v", err)
@ -129,10 +126,12 @@ func main() {
} }
} }
// MyRetryer used for checkpointing
type MyRetryer struct { type MyRetryer struct {
checkpoint.Retryer checkpoint.Retryer
} }
// ShouldRetry implements custom logic for when a checkpoint should retry
func (r *MyRetryer) ShouldRetry(err error) bool { func (r *MyRetryer) ShouldRetry(err error) bool {
if awsErr, ok := err.(awserr.Error); ok { if awsErr, ok := err.(awserr.Error); ok {
switch awsErr.Code() { switch awsErr.Code() {

View file

@ -2,7 +2,6 @@ package main
import ( import (
"context" "context"
"errors"
"expvar" "expvar"
"flag" "flag"
"fmt" "fmt"
@ -29,24 +28,19 @@ func main() {
log.Fatalf("checkpoint error: %v", err) log.Fatalf("checkpoint error: %v", err)
} }
var ( var counter = expvar.NewMap("counters")
counter = expvar.NewMap("counters")
)
newKclient := consumer.NewKinesisClient()
// consumer // consumer
c, err := consumer.New( c, err := consumer.New(
*stream, *stream,
consumer.WithCheckpoint(ck), consumer.WithCheckpoint(ck),
consumer.WithCounter(counter), consumer.WithCounter(counter),
consumer.WithClient(newKclient),
) )
if err != nil { if err != nil {
log.Fatalf("consumer error: %v", err) log.Fatalf("consumer error: %v", err)
} }
// use cancel \func to signal shutdown // use cancel func to signal shutdown
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
// trap SIGINT, wait to trigger shutdown // trap SIGINT, wait to trigger shutdown
@ -59,15 +53,11 @@ func main() {
}() }()
// scan stream // scan stream
err = c.Scan(ctx, func(r *consumer.Record) consumer.ScanError { err = c.Scan(ctx, func(r *consumer.Record) consumer.ScanStatus {
fmt.Println(string(r.Data)) fmt.Println(string(r.Data))
err := errors.New("some error happened")
// continue scanning // continue scanning
return consumer.ScanError{ return consumer.ScanStatus{}
Error: err,
StopScan: false,
SkipCheckpoint: false,
}
}) })
if err != nil { if err != nil {

View file

@ -0,0 +1,18 @@
# Consumer
Read records from the Kinesis stream
### Environment Variables
Export the required environment variables for connecting to the Kinesis stream and to Redis, which is used for checkpointing:
```
export AWS_ACCESS_KEY=
export AWS_REGION=
export AWS_SECRET_KEY=
export REDIS_URL=
```
### Run the consumer
$ go run main.go --app appName --stream streamName

View file

@ -0,0 +1,58 @@
package main
import (
"context"
"flag"
"fmt"
"log"
"os"
"os/signal"
consumer "github.com/harlow/kinesis-consumer"
checkpoint "github.com/harlow/kinesis-consumer/checkpoint/redis"
)
// main wires a Redis-backed checkpoint into a Kinesis consumer and prints the
// payload of every record it scans. The app name and stream name are taken
// from the --app and --stream command-line flags.
func main() {
	appName := flag.String("app", "", "App name")
	streamName := flag.String("stream", "", "Stream name")
	flag.Parse()

	// Checkpoint store from the checkpoint/redis package, keyed by app name.
	store, err := checkpoint.New(*appName)
	if err != nil {
		log.Fatalf("checkpoint error: %v", err)
	}

	// Consumer for the requested stream, using the checkpoint created above.
	kc, err := consumer.New(*streamName, consumer.WithCheckpoint(store))
	if err != nil {
		log.Fatalf("consumer error: %v", err)
	}

	// The cancel func is used to signal shutdown to the scan below.
	ctx, cancel := context.WithCancel(context.Background())

	// Wait for an interrupt (SIGINT) in the background, then cancel ctx.
	interrupts := make(chan os.Signal, 1)
	signal.Notify(interrupts, os.Interrupt)
	go func() {
		<-interrupts
		cancel()
	}()

	// Callback invoked for each record: print the data and return an empty
	// ScanStatus so that scanning continues.
	printRecord := func(r *consumer.Record) consumer.ScanStatus {
		fmt.Println(string(r.Data))
		return consumer.ScanStatus{}
	}

	// Scan the stream; any error it returns is fatal.
	if err := kc.Scan(ctx, printRecord); err != nil {
		log.Fatalf("scan error: %v", err)
	}
}

View file

@ -4,24 +4,20 @@ import (
"bufio" "bufio"
"flag" "flag"
"fmt" "fmt"
"log"
"os" "os"
"time" "time"
"github.com/apex/log"
"github.com/apex/log/handlers/text"
"github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/kinesis" "github.com/aws/aws-sdk-go/service/kinesis"
) )
var svc = kinesis.New(session.New(), &aws.Config{ var svc = kinesis.New(session.New(), &aws.Config{
Region: aws.String("us-west-2"), Region: aws.String("us-west-1"),
}) })
func main() { func main() {
log.SetHandler(text.New(os.Stderr))
log.SetLevel(log.DebugLevel)
var streamName = flag.String("stream", "", "Stream name") var streamName = flag.String("stream", "", "Stream name")
flag.Parse() flag.Parse()
@ -60,7 +56,7 @@ func putRecords(streamName *string, records []*kinesis.PutRecordsRequestEntry) {
Records: records, Records: records,
}) })
if err != nil { if err != nil {
log.Fatal("error putting records") log.Fatalf("error putting records: %v", err)
} }
fmt.Print(".") fmt.Print(".")
} }

View file

@ -1,7 +1,6 @@
package consumer package consumer
import ( import (
"io/ioutil"
"log" "log"
) )
@ -12,19 +11,12 @@ type Logger interface {
type LoggerFunc func(...interface{}) type LoggerFunc func(...interface{})
// NewDefaultLogger returns a Logger which discards messages. // noopLogger implements logger interface with discard
func NewDefaultLogger() Logger { type noopLogger struct {
return &defaultLogger{
logger: log.New(ioutil.Discard, "", log.LstdFlags),
}
}
// A defaultLogger provides a logging instance when none is provided.
type defaultLogger struct {
logger *log.Logger logger *log.Logger
} }
// Log using stdlib logger. See log.Println. // Log using stdlib logger. See log.Println.
func (l defaultLogger) Log(args ...interface{}) { func (l noopLogger) Log(args ...interface{}) {
l.logger.Println(args...) l.logger.Println(args...)
} }