Make the Checkpoint a required input for Consumer

The Checkpoint functionality is an important part of the library and
previously it wasn't obvious that the Consumer was defaulting to Redis
for this functionality.

* Add Checkpoint as required param for new consumer
This commit is contained in:
Harlow Ward 2017-11-21 08:58:16 -08:00
parent 154595b9a3
commit 4d6a85e901
7 changed files with 147 additions and 130 deletions

107
README.md
View file

@ -17,26 +17,37 @@ Get the package source:
The consumer leverages a handler func that accepts a Kinesis record. The `Scan` method will consume all shards concurrently and call the callback func as it receives records from the stream.
```go
import consumer "github.com/harlow/kinesis-consumer"
import(
// ...
consumer "github.com/harlow/kinesis-consumer"
checkpoint "github.com/harlow/kinesis-consumer/checkpoint/redis"
)
func main() {
log.SetHandler(text.New(os.Stderr))
log.SetLevel(log.DebugLevel)
var (
app = flag.String("app", "", "App name") // name of consumer group
app = flag.String("app", "", "App name")
stream = flag.String("stream", "", "Stream name")
)
flag.Parse()
c, err := consumer.New(*app, *stream)
// new checkpoint
ck, err := checkpoint.New(*app, *stream)
if err != nil {
log.Fatalf("checkpoint error: %v", err)
}
// new consumer
c, err := consumer.New(ck, *app, *stream)
if err != nil {
log.Fatalf("consumer error: %v", err)
}
err = c.Scan(context.TODO(), func(r *kinesis.Record) bool {
// scan stream
err = c.Scan(context.TODO(), func(r *consumer.Record) bool {
fmt.Println(string(r.Data))
return true // continue scanning
})
if err != nil {
@ -48,17 +59,55 @@ func main() {
}
```
### Configuration
### Checkpoint
The consumer requires the following config:
To record the progress of the consumer in the stream we use a checkpoint to store the last sequence number the consumer has read from a particular shard.
* App Name (used for checkpoints)
* Stream Name (kinesis stream name)
This will allow consumers to re-launch and pick up at the position in the stream where they left off.
It also accepts the following optional overrides:
The unique identifier for a consumer is `[appName, streamName, shardID]`
<img width="687" alt="kinesis-checkpoints" src="https://user-images.githubusercontent.com/739782/33036582-b6f3c4b4-cde3-11e7-9334-c4bfbe34d984.png">
There are two types of checkpoints:
### Redis
The Redis checkpoint requires App Name and Stream Name:
```go
import checkpoint "github.com/harlow/kinesis-consumer/checkpoint/redis"
// redis checkpoint
ck, err := checkpoint.New(appName, streamName)
if err != nil {
log.Fatalf("new checkpoint error: %v", err)
}
```
### DynamoDB
The DynamoDB checkpoint requires Table Name, App Name, and Stream Name:
```go
import checkpoint "github.com/harlow/kinesis-consumer/checkpoint/ddb"
// ddb checkpoint
ck, err := checkpoint.New(tableName, appName, streamName)
if err != nil {
log.Fatalf("new checkpoint error: %v", err)
}
```
To leverage the DDB checkpoint we'll also need to create a table:
<img width="659" alt="screen shot 2017-11-20 at 9 16 14 am" src="https://user-images.githubusercontent.com/739782/33033316-db85f848-cdd8-11e7-941a-0a87d8ace479.png">
### Options
The consumer allows the following optional overrides:
* Kinesis Client
* Checkpoint Storage
* Logger
```go
@ -67,46 +116,12 @@ svc := kinesis.New(session.New(aws.NewConfig()))
// new consumer with custom client
c, err := consumer.New(
appName,
consumer,
streamName,
consumer.WithClient(svc),
)
```
### Checkpoint Storage
To record the progress of the consumer in the stream we store the last sequence number the consumer has read from a particular shard. This will allow consumers to re-launch and pick up at the position in the stream where they left off.
<img width="687" alt="kinesis-checkpoints" src="https://user-images.githubusercontent.com/739782/33036582-b6f3c4b4-cde3-11e7-9334-c4bfbe34d984.png">
The default checkpoint uses Redis on localhost; to set a custom Redis URL use ENV vars:
```
REDIS_URL=redis.yoursite.com:6379
```
To leverage DynamoDB as the backend for checkpoint we'll need a new table:
<img width="659" alt="screen shot 2017-11-20 at 9 16 14 am" src="https://user-images.githubusercontent.com/739782/33033316-db85f848-cdd8-11e7-941a-0a87d8ace479.png">
Then override the checkpoint config option:
```go
// ddb checkpoint
ck, err := checkpoint.New(tableName, appName, streamName)
if err != nil {
log.Fatalf("new checkpoint error: %v", err)
}
// consumer with checkpoint
c, err := consumer.New(
appName,
streamName,
consumer.WithCheckpoint(ck),
)
```
### Logging
[Apex Log](https://medium.com/@tjholowaychuk/apex-log-e8d9627f4a9a#.5x1uo1767) is used for logging Info. Override the logs format with other [Log Handlers](https://github.com/apex/log/tree/master/_examples). For example using the "json" log handler:

View file

@ -78,6 +78,10 @@ func (c *Checkpoint) Get(shardID string) (string, error) {
// Set stores a checkpoint for a shard (e.g. sequence number of last record processed by application).
// Upon failover, record processing is resumed from this point.
func (c *Checkpoint) Set(shardID string, sequenceNumber string) error {
if sequenceNumber == "" {
return fmt.Errorf("sequence number should not be empty")
}
item, err := dynamodbattribute.MarshalMap(item{
ConsumerGroup: c.consumerGroupName(),
ShardID: shardID,

View file

@ -25,38 +25,39 @@ func New(appName, streamName string) (*Checkpoint, error) {
}
return &Checkpoint{
AppName: appName,
StreamName: streamName,
appName: appName,
streamName: streamName,
client: client,
}, nil
}
// Checkpoint stores and retrieves the last evaluated key from a DDB scan
type Checkpoint struct {
AppName string
StreamName string
appName string
streamName string
client *redis.Client
}
// Get determines if a checkpoint for a particular Shard exists.
// Typically used to determine whether we should start processing the shard with
// TRIM_HORIZON or AFTER_SEQUENCE_NUMBER (if checkpoint exists).
// Get fetches the checkpoint for a particular Shard.
func (c *Checkpoint) Get(shardID string) (string, error) {
return c.client.Get(c.key(shardID)).Result()
val, _ := c.client.Get(c.key(shardID)).Result()
return val, nil
}
// Set stores a checkpoint for a shard (e.g. sequence number of last record processed by application).
// Upon failover, record processing is resumed from this point.
func (c *Checkpoint) Set(shardID string, sequenceNumber string) error {
if sequenceNumber == "" {
return fmt.Errorf("sequence number should not be empty")
}
err := c.client.Set(c.key(shardID), sequenceNumber, 0).Err()
if err != nil {
return fmt.Errorf("redis checkpoint error: %v", err)
return err
}
return nil
}
// key generates a unique Redis key for storage of Checkpoint.
func (c *Checkpoint) key(shardID string) string {
return fmt.Sprintf("%v:checkpoint:%v:%v", c.AppName, c.StreamName, shardID)
return fmt.Sprintf("%v:checkpoint:%v:%v", c.appName, c.streamName, shardID)
}

View file

@ -12,33 +12,33 @@ func Test_CheckpointLifecycle(t *testing.T) {
client := redis.NewClient(&redis.Options{Addr: defaultAddr})
c := &Checkpoint{
AppName: "app",
StreamName: "stream",
appName: "app",
streamName: "stream",
client: client,
}
// set checkpoint
c.SetCheckpoint("shard_id", "testSeqNum")
// checkpoint exists
if val := c.CheckpointExists("shard_id"); val != true {
t.Fatalf("checkpoint exists expected true, got %t", val)
}
c.Set("shard_id", "testSeqNum")
// get checkpoint
if val := c.SequenceNumber(); val != "testSeqNum" {
val, err := c.Get("shard_id")
if err != nil {
t.Fatalf("get checkpoint error: %v", err)
}
if val != "testSeqNum" {
t.Fatalf("checkpoint exists expected %s, got %s", "testSeqNum", val)
}
client.Del("app:checkpoint:stream:shard_id")
client.Del(c.key("shard_id"))
}
func Test_key(t *testing.T) {
client := redis.NewClient(&redis.Options{Addr: defaultAddr})
c := &Checkpoint{
AppName: "app",
StreamName: "stream",
appName: "app",
streamName: "stream",
client: client,
}

View file

@ -10,20 +10,13 @@ import (
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/kinesis"
"github.com/harlow/kinesis-consumer/checkpoint"
"github.com/harlow/kinesis-consumer/checkpoint/redis"
)
type Record = kinesis.Record
// Option is used to override defaults when creating a new Consumer
type Option func(*Consumer) error
// WithClient the Kinesis client
func WithClient(client *kinesis.Kinesis) Option {
return func(c *Consumer) error {
c.svc = client
return nil
}
}
// WithCheckpoint overrides the default checkpoint
func WithCheckpoint(checkpoint checkpoint.Checkpoint) Option {
return func(c *Consumer) error {
@ -42,18 +35,23 @@ func WithLogger(logger log.Interface) Option {
// New creates a kinesis consumer with default settings. Use Option to override
// any of the optional attributes.
func New(appName, streamName string, opts ...Option) (*Consumer, error) {
if appName == "" {
return nil, fmt.Errorf("must provide app name to consumer")
func New(checkpoint checkpoint.Checkpoint, app, stream string, opts ...Option) (*Consumer, error) {
if checkpoint == nil {
return nil, fmt.Errorf("must provide checkpoint")
}
if streamName == "" {
return nil, fmt.Errorf("must provide stream name to consumer")
if app == "" {
return nil, fmt.Errorf("must provide app name")
}
if stream == "" {
return nil, fmt.Errorf("must provide stream name")
}
c := &Consumer{
appName: appName,
streamName: streamName,
checkpoint: checkpoint,
appName: app,
streamName: stream,
}
// set options
@ -67,23 +65,14 @@ func New(appName, streamName string, opts ...Option) (*Consumer, error) {
if c.logger == nil {
c.logger = log.Log.WithFields(log.Fields{
"package": "kinesis-consumer",
"app": appName,
"stream": streamName,
"app": app,
"stream": stream,
})
}
// provide a default kinesis client
if c.svc == nil {
c.svc = kinesis.New(session.New(aws.NewConfig()))
}
// provide default Redis checkpoint
if c.checkpoint == nil {
ck, err := redis.New(appName, streamName)
if err != nil {
return nil, err
}
c.checkpoint = ck
if c.client == nil {
c.client = kinesis.New(session.New(aws.NewConfig()))
}
return c, nil
@ -93,7 +82,7 @@ func New(appName, streamName string, opts ...Option) (*Consumer, error) {
type Consumer struct {
appName string
streamName string
svc *kinesis.Kinesis
client *kinesis.Kinesis
logger log.Interface
checkpoint checkpoint.Checkpoint
}
@ -105,7 +94,7 @@ func (c *Consumer) Scan(ctx context.Context, fn func(*kinesis.Record) bool) erro
defer cancel()
// grab the stream details
resp, err := c.svc.DescribeStream(
resp, err := c.client.DescribeStream(
&kinesis.DescribeStreamInput{
StreamName: aws.String(c.streamName),
},
@ -134,12 +123,15 @@ func (c *Consumer) Scan(ctx context.Context, fn func(*kinesis.Record) bool) erro
// for each record and checkpoints after each page is processed.
// Note: returning `false` from the callback func will end the scan.
func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn func(*kinesis.Record) bool) {
var (
logger = c.logger.WithFields(log.Fields{"shard": shardID})
sequenceNumber string
)
var logger = c.logger.WithFields(log.Fields{"shard": shardID})
shardIterator, err := c.getShardIterator(shardID)
lastSeqNum, err := c.checkpoint.Get(shardID)
if err != nil {
logger.WithError(err).Error("get checkpoint")
return
}
shardIterator, err := c.getShardIterator(shardID, lastSeqNum)
if err != nil {
logger.WithError(err).Error("getShardIterator")
return
@ -153,14 +145,14 @@ loop:
case <-ctx.Done():
break loop
default:
resp, err := c.svc.GetRecords(
resp, err := c.client.GetRecords(
&kinesis.GetRecordsInput{
ShardIterator: shardIterator,
},
)
if err != nil {
shardIterator, err = c.getShardIterator(shardID)
shardIterator, err = c.getShardIterator(shardID, lastSeqNum)
if err != nil {
logger.WithError(err).Error("getShardIterator")
break loop
@ -174,21 +166,21 @@ loop:
case <-ctx.Done():
break loop
default:
sequenceNumber = *r.SequenceNumber
lastSeqNum = *r.SequenceNumber
if ok := fn(r); !ok {
break loop
}
}
}
logger.WithField("records", len(resp.Records)).Info("checkpoint")
if err := c.checkpoint.Set(shardID, sequenceNumber); err != nil {
logger.WithField("count", len(resp.Records)).Info("checkpoint")
if err := c.checkpoint.Set(shardID, lastSeqNum); err != nil {
c.logger.WithError(err).Error("set checkpoint error")
}
}
if resp.NextShardIterator == nil || shardIterator == resp.NextShardIterator {
shardIterator, err = c.getShardIterator(shardID)
shardIterator, err = c.getShardIterator(shardID, lastSeqNum)
if err != nil {
logger.WithError(err).Error("getShardIterator")
break loop
@ -199,32 +191,29 @@ loop:
}
}
if sequenceNumber != "" {
if err := c.checkpoint.Set(shardID, sequenceNumber); err != nil {
c.logger.WithError(err).Error("set checkpoint error")
if lastSeqNum == "" {
return
}
if err := c.checkpoint.Set(shardID, lastSeqNum); err != nil {
c.logger.WithError(err).Error("set checkpoint error")
}
}
func (c *Consumer) getShardIterator(shardID string) (*string, error) {
func (c *Consumer) getShardIterator(shardID, lastSeqNum string) (*string, error) {
params := &kinesis.GetShardIteratorInput{
ShardId: aws.String(shardID),
StreamName: aws.String(c.streamName),
}
seqNum, err := c.checkpoint.Get(shardID)
if err != nil {
return nil, err
}
if seqNum != "" {
if lastSeqNum != "" {
params.ShardIteratorType = aws.String("AFTER_SEQUENCE_NUMBER")
params.StartingSequenceNumber = aws.String(seqNum)
params.StartingSequenceNumber = aws.String(lastSeqNum)
} else {
params.ShardIteratorType = aws.String("TRIM_HORIZON")
}
resp, err := c.svc.GetShardIterator(params)
resp, err := c.client.GetShardIterator(params)
if err != nil {
c.logger.WithError(err).Error("GetShardIterator")
return nil, err

View file

@ -8,8 +8,8 @@ import (
"github.com/apex/log"
"github.com/apex/log/handlers/text"
"github.com/aws/aws-sdk-go/service/kinesis"
consumer "github.com/harlow/kinesis-consumer"
checkpoint "github.com/harlow/kinesis-consumer/checkpoint/redis"
)
func main() {
@ -22,12 +22,20 @@ func main() {
)
flag.Parse()
c, err := consumer.New(*app, *stream)
// new checkpoint
ck, err := checkpoint.New(*app, *stream)
if err != nil {
log.Fatalf("checkpoint error: %v", err)
}
// new consumer
c, err := consumer.New(ck, *app, *stream)
if err != nil {
log.Fatalf("consumer error: %v", err)
}
err = c.Scan(context.TODO(), func(r *kinesis.Record) bool {
// scan stream
err = c.Scan(context.TODO(), func(r *consumer.Record) bool {
fmt.Println(string(r.Data))
return true // continue scanning
})

View file

@ -18,7 +18,7 @@ func main() {
log.SetHandler(text.New(os.Stderr))
log.SetLevel(log.DebugLevel)
var streamName = flag.String("s", "", "Stream name")
var streamName = flag.String("stream", "", "Stream name")
flag.Parse()
// download file with test data