diff --git a/README.md b/README.md index 0c686b9..038ec58 100644 --- a/README.md +++ b/README.md @@ -17,26 +17,37 @@ Get the package source: The consumer leverages a handler func that accepts a Kinesis record. The `Scan` method will consume all shards concurrently and call the callback func as it receives records from the stream. ```go -import consumer "github.com/harlow/kinesis-consumer" +import( + // ... + consumer "github.com/harlow/kinesis-consumer" + checkpoint "github.com/harlow/kinesis-consumer/checkpoint/redis" +) func main() { log.SetHandler(text.New(os.Stderr)) log.SetLevel(log.DebugLevel) var ( - app = flag.String("app", "", "App name") // name of consumer group + app = flag.String("app", "", "App name") stream = flag.String("stream", "", "Stream name") ) flag.Parse() - c, err := consumer.New(*app, *stream) + // new checkpoint + ck, err := checkpoint.New(*app, *stream) + if err != nil { + log.Fatalf("checkpoint error: %v", err) + } + + // new consumer + c, err := consumer.New(ck, *app, *stream) if err != nil { log.Fatalf("consumer error: %v", err) } - err = c.Scan(context.TODO(), func(r *kinesis.Record) bool { + // scan stream + err = c.Scan(context.TODO(), func(r *consumer.Record) bool { fmt.Println(string(r.Data)) - return true // continue scanning }) if err != nil { @@ -48,17 +59,55 @@ func main() { } ``` -### Configuration +### Checkpoint -The consumer requires the following config: +To record the progress of the consumer in the stream we use a checkpoint to store the last sequence number the consumer has read from a particular shard. -* App Name (used for checkpoints) -* Stream Name (kinesis stream name) +This will allow consumers to re-launch and pick up at the position in the stream where they left off. 
-It also accepts the following optional overrides: +The unique identifier for a consumer is `[appName, streamName, shardID]` + +kinesis-checkpoints + +There are two types of checkpoints: + +### Redis + +The Redis checkpoint requires App Name and Stream Name: + +```go +import checkpoint "github.com/harlow/kinesis-consumer/checkpoint/redis" + +// redis checkpoint +ck, err := checkpoint.New(appName, streamName) +if err != nil { + log.Fatalf("new checkpoint error: %v", err) +} +``` + +### DynamoDB + +The DynamoDB checkpoint requires Table Name, App Name, and Stream Name: + +```go +import checkpoint "github.com/harlow/kinesis-consumer/checkpoint/ddb" + +// ddb checkpoint +ck, err := checkpoint.New(tableName, appName, streamName) +if err != nil { + log.Fatalf("new checkpoint error: %v", err) +} +``` + +To leverage the DDB checkpoint we'll also need to create a table: + +screen shot 2017-11-20 at 9 16 14 am + +### Options + +The consumer allows the following optional overrides: * Kinesis Client -* Checkpoint Storage * Logger ```go @@ -67,46 +116,12 @@ svc := kinesis.New(session.New(aws.NewConfig())) // new consumer with custom client c, err := consumer.New( - appName, + consumer, streamName, consumer.WithClient(svc), ) ``` -### Checkpoint Storage - -To record the progress of the consumer in the stream we store the last sequence number the consumer has read from a particular shard. This will allow consumers to re-launch and pick up at the position in the stream where they left off. 
- -kinesis-checkpoints - - -The default checkpoint uses Redis on localhost; to set a custom Redis URL use ENV vars: - -``` -REDIS_URL=redis.yoursite.com:6379 -``` - -To leverage DynamoDB as the backend for checkpoint we'll need a new table: - -screen shot 2017-11-20 at 9 16 14 am - -Then override the checkpoint config option: - -```go -// ddb checkpoint -ck, err := checkpoint.New(tableName, appName, streamName) -if err != nil { - log.Fatalf("new checkpoint error: %v", err) -} - -// consumer with checkpoint -c, err := consumer.New( - appName, - streamName, - consumer.WithCheckpoint(ck), -) -``` - ### Logging [Apex Log](https://medium.com/@tjholowaychuk/apex-log-e8d9627f4a9a#.5x1uo1767) is used for logging Info. Override the logs format with other [Log Handlers](https://github.com/apex/log/tree/master/_examples). For example using the "json" log handler: diff --git a/checkpoint/ddb/ddb.go b/checkpoint/ddb/ddb.go index d7628d6..84014a0 100644 --- a/checkpoint/ddb/ddb.go +++ b/checkpoint/ddb/ddb.go @@ -78,6 +78,10 @@ func (c *Checkpoint) Get(shardID string) (string, error) { // Set stores a checkpoint for a shard (e.g. sequence number of last record processed by application). // Upon failover, record processing is resumed from this point. 
func (c *Checkpoint) Set(shardID string, sequenceNumber string) error { + if sequenceNumber == "" { + return fmt.Errorf("sequence number should not be empty") + } + item, err := dynamodbattribute.MarshalMap(item{ ConsumerGroup: c.consumerGroupName(), ShardID: shardID, diff --git a/checkpoint/redis/redis.go b/checkpoint/redis/redis.go index c8a3b36..6842788 100644 --- a/checkpoint/redis/redis.go +++ b/checkpoint/redis/redis.go @@ -25,38 +25,39 @@ func New(appName, streamName string) (*Checkpoint, error) { } return &Checkpoint{ - AppName: appName, - StreamName: streamName, + appName: appName, + streamName: streamName, client: client, }, nil } // Checkpoint stores and retreives the last evaluated key from a DDB scan type Checkpoint struct { - AppName string - StreamName string - - client *redis.Client + appName string + streamName string + client *redis.Client } -// Get determines if a checkpoint for a particular Shard exists. -// Typically used to determine whether we should start processing the shard with -// TRIM_HORIZON or AFTER_SEQUENCE_NUMBER (if checkpoint exists). +// Get fetches the checkpoint for a particular Shard. func (c *Checkpoint) Get(shardID string) (string, error) { - return c.client.Get(c.key(shardID)).Result() + val, _ := c.client.Get(c.key(shardID)).Result() + return val, nil } // Set stores a checkpoint for a shard (e.g. sequence number of last record processed by application). // Upon failover, record processing is resumed from this point. func (c *Checkpoint) Set(shardID string, sequenceNumber string) error { + if sequenceNumber == "" { + return fmt.Errorf("sequence number should not be empty") + } err := c.client.Set(c.key(shardID), sequenceNumber, 0).Err() if err != nil { - return fmt.Errorf("redis checkpoint error: %v", err) + return err } return nil } // key generates a unique Redis key for storage of Checkpoint. 
func (c *Checkpoint) key(shardID string) string { - return fmt.Sprintf("%v:checkpoint:%v:%v", c.AppName, c.StreamName, shardID) + return fmt.Sprintf("%v:checkpoint:%v:%v", c.appName, c.streamName, shardID) } diff --git a/checkpoint/redis/redis_test.go b/checkpoint/redis/redis_test.go index 0cadec7..7c49190 100644 --- a/checkpoint/redis/redis_test.go +++ b/checkpoint/redis/redis_test.go @@ -12,33 +12,33 @@ func Test_CheckpointLifecycle(t *testing.T) { client := redis.NewClient(&redis.Options{Addr: defaultAddr}) c := &Checkpoint{ - AppName: "app", - StreamName: "stream", + appName: "app", + streamName: "stream", client: client, } // set checkpoint - c.SetCheckpoint("shard_id", "testSeqNum") - - // checkpoint exists - if val := c.CheckpointExists("shard_id"); val != true { - t.Fatalf("checkpoint exists expected true, got %t", val) - } + c.Set("shard_id", "testSeqNum") // get checkpoint - if val := c.SequenceNumber(); val != "testSeqNum" { + val, err := c.Get("shard_id") + if err != nil { + t.Fatalf("get checkpoint error: %v", err) + } + + if val != "testSeqNum" { t.Fatalf("checkpoint exists expected %s, got %s", "testSeqNum", val) } - client.Del("app:checkpoint:stream:shard_id") + client.Del(c.key("shard_id")) } func Test_key(t *testing.T) { client := redis.NewClient(&redis.Options{Addr: defaultAddr}) c := &Checkpoint{ - AppName: "app", - StreamName: "stream", + appName: "app", + streamName: "stream", client: client, } diff --git a/consumer.go b/consumer.go index 03f2deb..13f9114 100644 --- a/consumer.go +++ b/consumer.go @@ -10,20 +10,13 @@ import ( "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/kinesis" "github.com/harlow/kinesis-consumer/checkpoint" - "github.com/harlow/kinesis-consumer/checkpoint/redis" ) +type Record = kinesis.Record + // Option is used to override defaults when creating a new Consumer type Option func(*Consumer) error -// WithClient the Kinesis client -func WithClient(client *kinesis.Kinesis) Option { - return func(c 
*Consumer) error { - c.svc = client - return nil - } -} - // WithCheckpoint overrides the default checkpoint func WithCheckpoint(checkpoint checkpoint.Checkpoint) Option { return func(c *Consumer) error { @@ -42,18 +35,23 @@ func WithLogger(logger log.Interface) Option { // New creates a kinesis consumer with default settings. Use Option to override // any of the optional attributes. -func New(appName, streamName string, opts ...Option) (*Consumer, error) { - if appName == "" { - return nil, fmt.Errorf("must provide app name to consumer") +func New(checkpoint checkpoint.Checkpoint, app, stream string, opts ...Option) (*Consumer, error) { + if checkpoint == nil { + return nil, fmt.Errorf("must provide checkpoint") } - if streamName == "" { - return nil, fmt.Errorf("must provide stream name to consumer") + if app == "" { + return nil, fmt.Errorf("must provide app name") + } + + if stream == "" { + return nil, fmt.Errorf("must provide stream name") } c := &Consumer{ - appName: appName, - streamName: streamName, + checkpoint: checkpoint, + appName: app, + streamName: stream, } // set options @@ -67,23 +65,14 @@ func New(appName, streamName string, opts ...Option) (*Consumer, error) { if c.logger == nil { c.logger = log.Log.WithFields(log.Fields{ "package": "kinesis-consumer", - "app": appName, - "stream": streamName, + "app": app, + "stream": stream, }) } // provide a default kinesis client - if c.svc == nil { - c.svc = kinesis.New(session.New(aws.NewConfig())) - } - - // provide default Redis checkpoint - if c.checkpoint == nil { - ck, err := redis.New(appName, streamName) - if err != nil { - return nil, err - } - c.checkpoint = ck + if c.client == nil { + c.client = kinesis.New(session.New(aws.NewConfig())) } return c, nil @@ -91,9 +80,9 @@ func New(appName, streamName string, opts ...Option) (*Consumer, error) { // Consumer wraps the interaction with the Kinesis stream type Consumer struct { - appName string + appName string streamName string - svc *kinesis.Kinesis 
+ client *kinesis.Kinesis logger log.Interface checkpoint checkpoint.Checkpoint } @@ -105,7 +94,7 @@ func (c *Consumer) Scan(ctx context.Context, fn func(*kinesis.Record) bool) erro defer cancel() // grab the stream details - resp, err := c.svc.DescribeStream( + resp, err := c.client.DescribeStream( &kinesis.DescribeStreamInput{ StreamName: aws.String(c.streamName), }, @@ -134,12 +123,15 @@ func (c *Consumer) Scan(ctx context.Context, fn func(*kinesis.Record) bool) erro // for each record and checkpoints after each page is processed. // Note: returning `false` from the callback func will end the scan. func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn func(*kinesis.Record) bool) { - var ( - logger = c.logger.WithFields(log.Fields{"shard": shardID}) - sequenceNumber string - ) + var logger = c.logger.WithFields(log.Fields{"shard": shardID}) - shardIterator, err := c.getShardIterator(shardID) + lastSeqNum, err := c.checkpoint.Get(shardID) + if err != nil { + logger.WithError(err).Error("get checkpoint") + return + } + + shardIterator, err := c.getShardIterator(shardID, lastSeqNum) if err != nil { logger.WithError(err).Error("getShardIterator") return @@ -153,14 +145,14 @@ loop: case <-ctx.Done(): break loop default: - resp, err := c.svc.GetRecords( + resp, err := c.client.GetRecords( &kinesis.GetRecordsInput{ ShardIterator: shardIterator, }, ) if err != nil { - shardIterator, err = c.getShardIterator(shardID) + shardIterator, err = c.getShardIterator(shardID, lastSeqNum) if err != nil { logger.WithError(err).Error("getShardIterator") break loop @@ -174,21 +166,21 @@ loop: case <-ctx.Done(): break loop default: - sequenceNumber = *r.SequenceNumber + lastSeqNum = *r.SequenceNumber if ok := fn(r); !ok { break loop } } } - logger.WithField("records", len(resp.Records)).Info("checkpoint") - if err := c.checkpoint.Set(shardID, sequenceNumber); err != nil { + logger.WithField("count", len(resp.Records)).Info("checkpoint") + if err := 
c.checkpoint.Set(shardID, lastSeqNum); err != nil { c.logger.WithError(err).Error("set checkpoint error") } } if resp.NextShardIterator == nil || shardIterator == resp.NextShardIterator { - shardIterator, err = c.getShardIterator(shardID) + shardIterator, err = c.getShardIterator(shardID, lastSeqNum) if err != nil { logger.WithError(err).Error("getShardIterator") break loop @@ -199,32 +191,29 @@ loop: } } - if sequenceNumber != "" { - if err := c.checkpoint.Set(shardID, sequenceNumber); err != nil { - c.logger.WithError(err).Error("set checkpoint error") - } + if lastSeqNum == "" { + return + } + + if err := c.checkpoint.Set(shardID, lastSeqNum); err != nil { + c.logger.WithError(err).Error("set checkpoint error") } } -func (c *Consumer) getShardIterator(shardID string) (*string, error) { +func (c *Consumer) getShardIterator(shardID, lastSeqNum string) (*string, error) { params := &kinesis.GetShardIteratorInput{ ShardId: aws.String(shardID), StreamName: aws.String(c.streamName), } - seqNum, err := c.checkpoint.Get(shardID) - if err != nil { - return nil, err - } - - if seqNum != "" { + if lastSeqNum != "" { params.ShardIteratorType = aws.String("AFTER_SEQUENCE_NUMBER") - params.StartingSequenceNumber = aws.String(seqNum) + params.StartingSequenceNumber = aws.String(lastSeqNum) } else { params.ShardIteratorType = aws.String("TRIM_HORIZON") } - resp, err := c.svc.GetShardIterator(params) + resp, err := c.client.GetShardIterator(params) if err != nil { c.logger.WithError(err).Error("GetShardIterator") return nil, err diff --git a/examples/consumer/main.go b/examples/consumer/main.go index 8c52270..6613feb 100644 --- a/examples/consumer/main.go +++ b/examples/consumer/main.go @@ -8,8 +8,8 @@ import ( "github.com/apex/log" "github.com/apex/log/handlers/text" - "github.com/aws/aws-sdk-go/service/kinesis" consumer "github.com/harlow/kinesis-consumer" + checkpoint "github.com/harlow/kinesis-consumer/checkpoint/redis" ) func main() { @@ -22,12 +22,20 @@ func main() { ) 
flag.Parse() - c, err := consumer.New(*app, *stream) + // new checkpoint + ck, err := checkpoint.New(*app, *stream) + if err != nil { + log.Fatalf("checkpoint error: %v", err) + } + + // new consumer + c, err := consumer.New(ck, *app, *stream) if err != nil { log.Fatalf("consumer error: %v", err) } - err = c.Scan(context.TODO(), func(r *kinesis.Record) bool { + // scan stream + err = c.Scan(context.TODO(), func(r *consumer.Record) bool { fmt.Println(string(r.Data)) return true // continue scanning }) diff --git a/examples/producer/main.go b/examples/producer/main.go index 40d1591..f30efc6 100644 --- a/examples/producer/main.go +++ b/examples/producer/main.go @@ -18,7 +18,7 @@ func main() { log.SetHandler(text.New(os.Stderr)) log.SetLevel(log.DebugLevel) - var streamName = flag.String("s", "", "Stream name") + var streamName = flag.String("stream", "", "Stream name") flag.Parse() // download file with test data