diff --git a/README.md b/README.md
index 0c686b9..038ec58 100644
--- a/README.md
+++ b/README.md
@@ -17,26 +17,37 @@ Get the package source:
The consumer leverages a handler func that accepts a Kinesis record. The `Scan` method will consume all shards concurrently and call the callback func as it receives records from the stream.
```go
-import consumer "github.com/harlow/kinesis-consumer"
+import(
+ // ...
+ consumer "github.com/harlow/kinesis-consumer"
+ checkpoint "github.com/harlow/kinesis-consumer/checkpoint/redis"
+)
func main() {
log.SetHandler(text.New(os.Stderr))
log.SetLevel(log.DebugLevel)
var (
- app = flag.String("app", "", "App name") // name of consumer group
+ app = flag.String("app", "", "App name")
stream = flag.String("stream", "", "Stream name")
)
flag.Parse()
- c, err := consumer.New(*app, *stream)
+ // new checkpoint
+ ck, err := checkpoint.New(*app, *stream)
+ if err != nil {
+ log.Fatalf("checkpoint error: %v", err)
+ }
+
+ // new consumer
+ c, err := consumer.New(ck, *app, *stream)
if err != nil {
log.Fatalf("consumer error: %v", err)
}
- err = c.Scan(context.TODO(), func(r *kinesis.Record) bool {
+ // scan stream
+ err = c.Scan(context.TODO(), func(r *consumer.Record) bool {
fmt.Println(string(r.Data))
-
return true // continue scanning
})
if err != nil {
@@ -48,17 +59,55 @@ func main() {
}
```
-### Configuration
+### Checkpoint
-The consumer requires the following config:
+To record the progress of the consumer in the stream we use a checkpoint to store the last sequence number the consumer has read from a particular shard.
-* App Name (used for checkpoints)
-* Stream Name (kinesis stream name)
+This will allow consumers to re-launch and pick up at the position in the stream where they left off.
-It also accepts the following optional overrides:
+The unique identifier for a consumer is `[appName, streamName, shardID]`
+
+
+
+There are two types of checkpoints:
+
+### Redis
+
+The Redis checkpoint requires App Name and Stream Name:
+
+```go
+import checkpoint "github.com/harlow/kinesis-consumer/checkpoint/redis"
+
+// redis checkpoint
+ck, err := checkpoint.New(appName, streamName)
+if err != nil {
+ log.Fatalf("new checkpoint error: %v", err)
+}
+```
+
+### DynamoDB
+
+The DynamoDB checkpoint requires Table Name, App Name, and Stream Name:
+
+```go
+import checkpoint "github.com/harlow/kinesis-consumer/checkpoint/ddb"
+
+// ddb checkpoint
+ck, err := checkpoint.New(tableName, appName, streamName)
+if err != nil {
+ log.Fatalf("new checkpoint error: %v", err)
+}
+```
+
+To leverage the DDB checkpoint we'll also need to create a table:
+
+
+
+### Options
+
+The consumer allows the following optional overrides:
* Kinesis Client
-* Checkpoint Storage
* Logger
```go
@@ -67,46 +116,12 @@ svc := kinesis.New(session.New(aws.NewConfig()))
// new consumer with custom client
c, err := consumer.New(
- appName,
+ ck, appName,
streamName,
consumer.WithClient(svc),
)
```
-### Checkpoint Storage
-
-To record the progress of the consumer in the stream we store the last sequence number the consumer has read from a particular shard. This will allow consumers to re-launch and pick up at the position in the stream where they left off.
-
-
-
-
-The default checkpoint uses Redis on localhost; to set a custom Redis URL use ENV vars:
-
-```
-REDIS_URL=redis.yoursite.com:6379
-```
-
-To leverage DynamoDB as the backend for checkpoint we'll need a new table:
-
-
-
-Then override the checkpoint config option:
-
-```go
-// ddb checkpoint
-ck, err := checkpoint.New(tableName, appName, streamName)
-if err != nil {
- log.Fatalf("new checkpoint error: %v", err)
-}
-
-// consumer with checkpoint
-c, err := consumer.New(
- appName,
- streamName,
- consumer.WithCheckpoint(ck),
-)
-```
-
### Logging
[Apex Log](https://medium.com/@tjholowaychuk/apex-log-e8d9627f4a9a#.5x1uo1767) is used for logging Info. Override the logs format with other [Log Handlers](https://github.com/apex/log/tree/master/_examples). For example using the "json" log handler:
diff --git a/checkpoint/ddb/ddb.go b/checkpoint/ddb/ddb.go
index d7628d6..84014a0 100644
--- a/checkpoint/ddb/ddb.go
+++ b/checkpoint/ddb/ddb.go
@@ -78,6 +78,10 @@ func (c *Checkpoint) Get(shardID string) (string, error) {
// Set stores a checkpoint for a shard (e.g. sequence number of last record processed by application).
// Upon failover, record processing is resumed from this point.
func (c *Checkpoint) Set(shardID string, sequenceNumber string) error {
+ if sequenceNumber == "" {
+ return fmt.Errorf("sequence number should not be empty")
+ }
+
item, err := dynamodbattribute.MarshalMap(item{
ConsumerGroup: c.consumerGroupName(),
ShardID: shardID,
diff --git a/checkpoint/redis/redis.go b/checkpoint/redis/redis.go
index c8a3b36..6842788 100644
--- a/checkpoint/redis/redis.go
+++ b/checkpoint/redis/redis.go
@@ -25,38 +25,39 @@ func New(appName, streamName string) (*Checkpoint, error) {
}
return &Checkpoint{
- AppName: appName,
- StreamName: streamName,
+ appName: appName,
+ streamName: streamName,
client: client,
}, nil
}
// Checkpoint stores and retreives the last evaluated key from a DDB scan
type Checkpoint struct {
- AppName string
- StreamName string
-
- client *redis.Client
+ appName string
+ streamName string
+ client *redis.Client
}
-// Get determines if a checkpoint for a particular Shard exists.
-// Typically used to determine whether we should start processing the shard with
-// TRIM_HORIZON or AFTER_SEQUENCE_NUMBER (if checkpoint exists).
+// Get fetches the checkpoint for a particular Shard.
func (c *Checkpoint) Get(shardID string) (string, error) {
- return c.client.Get(c.key(shardID)).Result()
+ val, _ := c.client.Get(c.key(shardID)).Result()
+ return val, nil
}
// Set stores a checkpoint for a shard (e.g. sequence number of last record processed by application).
// Upon failover, record processing is resumed from this point.
func (c *Checkpoint) Set(shardID string, sequenceNumber string) error {
+ if sequenceNumber == "" {
+ return fmt.Errorf("sequence number should not be empty")
+ }
err := c.client.Set(c.key(shardID), sequenceNumber, 0).Err()
if err != nil {
- return fmt.Errorf("redis checkpoint error: %v", err)
+ return err
}
return nil
}
// key generates a unique Redis key for storage of Checkpoint.
func (c *Checkpoint) key(shardID string) string {
- return fmt.Sprintf("%v:checkpoint:%v:%v", c.AppName, c.StreamName, shardID)
+ return fmt.Sprintf("%v:checkpoint:%v:%v", c.appName, c.streamName, shardID)
}
diff --git a/checkpoint/redis/redis_test.go b/checkpoint/redis/redis_test.go
index 0cadec7..7c49190 100644
--- a/checkpoint/redis/redis_test.go
+++ b/checkpoint/redis/redis_test.go
@@ -12,33 +12,33 @@ func Test_CheckpointLifecycle(t *testing.T) {
client := redis.NewClient(&redis.Options{Addr: defaultAddr})
c := &Checkpoint{
- AppName: "app",
- StreamName: "stream",
+ appName: "app",
+ streamName: "stream",
client: client,
}
// set checkpoint
- c.SetCheckpoint("shard_id", "testSeqNum")
-
- // checkpoint exists
- if val := c.CheckpointExists("shard_id"); val != true {
- t.Fatalf("checkpoint exists expected true, got %t", val)
- }
+ c.Set("shard_id", "testSeqNum")
// get checkpoint
- if val := c.SequenceNumber(); val != "testSeqNum" {
+ val, err := c.Get("shard_id")
+ if err != nil {
+ t.Fatalf("get checkpoint error: %v", err)
+ }
+
+ if val != "testSeqNum" {
t.Fatalf("checkpoint exists expected %s, got %s", "testSeqNum", val)
}
- client.Del("app:checkpoint:stream:shard_id")
+ client.Del(c.key("shard_id"))
}
func Test_key(t *testing.T) {
client := redis.NewClient(&redis.Options{Addr: defaultAddr})
c := &Checkpoint{
- AppName: "app",
- StreamName: "stream",
+ appName: "app",
+ streamName: "stream",
client: client,
}
diff --git a/consumer.go b/consumer.go
index 03f2deb..13f9114 100644
--- a/consumer.go
+++ b/consumer.go
@@ -10,20 +10,13 @@ import (
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/kinesis"
"github.com/harlow/kinesis-consumer/checkpoint"
- "github.com/harlow/kinesis-consumer/checkpoint/redis"
)
+type Record = kinesis.Record
+
// Option is used to override defaults when creating a new Consumer
type Option func(*Consumer) error
-// WithClient the Kinesis client
-func WithClient(client *kinesis.Kinesis) Option {
- return func(c *Consumer) error {
- c.svc = client
- return nil
- }
-}
-
// WithCheckpoint overrides the default checkpoint
func WithCheckpoint(checkpoint checkpoint.Checkpoint) Option {
return func(c *Consumer) error {
@@ -42,18 +35,23 @@ func WithLogger(logger log.Interface) Option {
// New creates a kinesis consumer with default settings. Use Option to override
// any of the optional attributes.
-func New(appName, streamName string, opts ...Option) (*Consumer, error) {
- if appName == "" {
- return nil, fmt.Errorf("must provide app name to consumer")
+func New(checkpoint checkpoint.Checkpoint, app, stream string, opts ...Option) (*Consumer, error) {
+ if checkpoint == nil {
+ return nil, fmt.Errorf("must provide checkpoint")
}
- if streamName == "" {
- return nil, fmt.Errorf("must provide stream name to consumer")
+ if app == "" {
+ return nil, fmt.Errorf("must provide app name")
+ }
+
+ if stream == "" {
+ return nil, fmt.Errorf("must provide stream name")
}
c := &Consumer{
- appName: appName,
- streamName: streamName,
+ checkpoint: checkpoint,
+ appName: app,
+ streamName: stream,
}
// set options
@@ -67,23 +65,14 @@ func New(appName, streamName string, opts ...Option) (*Consumer, error) {
if c.logger == nil {
c.logger = log.Log.WithFields(log.Fields{
"package": "kinesis-consumer",
- "app": appName,
- "stream": streamName,
+ "app": app,
+ "stream": stream,
})
}
// provide a default kinesis client
- if c.svc == nil {
- c.svc = kinesis.New(session.New(aws.NewConfig()))
- }
-
- // provide default Redis checkpoint
- if c.checkpoint == nil {
- ck, err := redis.New(appName, streamName)
- if err != nil {
- return nil, err
- }
- c.checkpoint = ck
+ if c.client == nil {
+ c.client = kinesis.New(session.New(aws.NewConfig()))
}
return c, nil
@@ -91,9 +80,9 @@ func New(appName, streamName string, opts ...Option) (*Consumer, error) {
// Consumer wraps the interaction with the Kinesis stream
type Consumer struct {
- appName string
+ appName string
streamName string
- svc *kinesis.Kinesis
+ client *kinesis.Kinesis
logger log.Interface
checkpoint checkpoint.Checkpoint
}
@@ -105,7 +94,7 @@ func (c *Consumer) Scan(ctx context.Context, fn func(*kinesis.Record) bool) erro
defer cancel()
// grab the stream details
- resp, err := c.svc.DescribeStream(
+ resp, err := c.client.DescribeStream(
&kinesis.DescribeStreamInput{
StreamName: aws.String(c.streamName),
},
@@ -134,12 +123,15 @@ func (c *Consumer) Scan(ctx context.Context, fn func(*kinesis.Record) bool) erro
// for each record and checkpoints after each page is processed.
// Note: returning `false` from the callback func will end the scan.
func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn func(*kinesis.Record) bool) {
- var (
- logger = c.logger.WithFields(log.Fields{"shard": shardID})
- sequenceNumber string
- )
+ var logger = c.logger.WithFields(log.Fields{"shard": shardID})
- shardIterator, err := c.getShardIterator(shardID)
+ lastSeqNum, err := c.checkpoint.Get(shardID)
+ if err != nil {
+ logger.WithError(err).Error("get checkpoint")
+ return
+ }
+
+ shardIterator, err := c.getShardIterator(shardID, lastSeqNum)
if err != nil {
logger.WithError(err).Error("getShardIterator")
return
@@ -153,14 +145,14 @@ loop:
case <-ctx.Done():
break loop
default:
- resp, err := c.svc.GetRecords(
+ resp, err := c.client.GetRecords(
&kinesis.GetRecordsInput{
ShardIterator: shardIterator,
},
)
if err != nil {
- shardIterator, err = c.getShardIterator(shardID)
+ shardIterator, err = c.getShardIterator(shardID, lastSeqNum)
if err != nil {
logger.WithError(err).Error("getShardIterator")
break loop
@@ -174,21 +166,21 @@ loop:
case <-ctx.Done():
break loop
default:
- sequenceNumber = *r.SequenceNumber
+ lastSeqNum = *r.SequenceNumber
if ok := fn(r); !ok {
break loop
}
}
}
- logger.WithField("records", len(resp.Records)).Info("checkpoint")
- if err := c.checkpoint.Set(shardID, sequenceNumber); err != nil {
+ logger.WithField("count", len(resp.Records)).Info("checkpoint")
+ if err := c.checkpoint.Set(shardID, lastSeqNum); err != nil {
c.logger.WithError(err).Error("set checkpoint error")
}
}
if resp.NextShardIterator == nil || shardIterator == resp.NextShardIterator {
- shardIterator, err = c.getShardIterator(shardID)
+ shardIterator, err = c.getShardIterator(shardID, lastSeqNum)
if err != nil {
logger.WithError(err).Error("getShardIterator")
break loop
@@ -199,32 +191,29 @@ loop:
}
}
- if sequenceNumber != "" {
- if err := c.checkpoint.Set(shardID, sequenceNumber); err != nil {
- c.logger.WithError(err).Error("set checkpoint error")
- }
+ if lastSeqNum == "" {
+ return
+ }
+
+ if err := c.checkpoint.Set(shardID, lastSeqNum); err != nil {
+ c.logger.WithError(err).Error("set checkpoint error")
}
}
-func (c *Consumer) getShardIterator(shardID string) (*string, error) {
+func (c *Consumer) getShardIterator(shardID, lastSeqNum string) (*string, error) {
params := &kinesis.GetShardIteratorInput{
ShardId: aws.String(shardID),
StreamName: aws.String(c.streamName),
}
- seqNum, err := c.checkpoint.Get(shardID)
- if err != nil {
- return nil, err
- }
-
- if seqNum != "" {
+ if lastSeqNum != "" {
params.ShardIteratorType = aws.String("AFTER_SEQUENCE_NUMBER")
- params.StartingSequenceNumber = aws.String(seqNum)
+ params.StartingSequenceNumber = aws.String(lastSeqNum)
} else {
params.ShardIteratorType = aws.String("TRIM_HORIZON")
}
- resp, err := c.svc.GetShardIterator(params)
+ resp, err := c.client.GetShardIterator(params)
if err != nil {
c.logger.WithError(err).Error("GetShardIterator")
return nil, err
diff --git a/examples/consumer/main.go b/examples/consumer/main.go
index 8c52270..6613feb 100644
--- a/examples/consumer/main.go
+++ b/examples/consumer/main.go
@@ -8,8 +8,8 @@ import (
"github.com/apex/log"
"github.com/apex/log/handlers/text"
- "github.com/aws/aws-sdk-go/service/kinesis"
consumer "github.com/harlow/kinesis-consumer"
+ checkpoint "github.com/harlow/kinesis-consumer/checkpoint/redis"
)
func main() {
@@ -22,12 +22,20 @@ func main() {
)
flag.Parse()
- c, err := consumer.New(*app, *stream)
+ // new checkpoint
+ ck, err := checkpoint.New(*app, *stream)
+ if err != nil {
+ log.Fatalf("checkpoint error: %v", err)
+ }
+
+ // new consumer
+ c, err := consumer.New(ck, *app, *stream)
if err != nil {
log.Fatalf("consumer error: %v", err)
}
- err = c.Scan(context.TODO(), func(r *kinesis.Record) bool {
+ // scan stream
+ err = c.Scan(context.TODO(), func(r *consumer.Record) bool {
fmt.Println(string(r.Data))
return true // continue scanning
})
diff --git a/examples/producer/main.go b/examples/producer/main.go
index 40d1591..f30efc6 100644
--- a/examples/producer/main.go
+++ b/examples/producer/main.go
@@ -18,7 +18,7 @@ func main() {
log.SetHandler(text.New(os.Stderr))
log.SetLevel(log.DebugLevel)
- var streamName = flag.String("s", "", "Stream name")
+ var streamName = flag.String("stream", "", "Stream name")
flag.Parse()
// download file with test data