diff --git a/README.md b/README.md
index 9dd9ee6..a28bd50 100644
--- a/README.md
+++ b/README.md
@@ -28,19 +28,16 @@ import(
)
func main() {
- var (
- app = flag.String("app", "", "App name")
- stream = flag.String("stream", "", "Stream name")
- )
+ var stream = flag.String("stream", "", "Stream name")
flag.Parse()
// consumer
- c, err := consumer.New(*app, *stream)
+ c, err := consumer.New(*stream)
if err != nil {
log.Fatalf("consumer error: %v", err)
}
- // scan stream
+ // start
err = c.Scan(context.TODO(), func(r *consumer.Record) bool {
fmt.Println(string(r.Data))
return true // continue scanning
@@ -76,7 +73,7 @@ The Redis checkpoint requries App Name, and Stream Name:
import checkpoint "github.com/harlow/kinesis-consumer/checkpoint/redis"
// redis checkpoint
-ck, err := checkpoint.New(appName, streamName)
+ck, err := checkpoint.New(appName)
if err != nil {
log.Fatalf("new checkpoint error: %v", err)
}
@@ -90,7 +87,7 @@ The DynamoDB checkpoint requires Table Name, App Name, and Stream Name:
import checkpoint "github.com/harlow/kinesis-consumer/checkpoint/ddb"
// ddb checkpoint
-ck, err := checkpoint.New(tableName, appName, streamName)
+ck, err := checkpoint.New(tableName, appName)
if err != nil {
log.Fatalf("new checkpoint error: %v", err)
}
@@ -98,7 +95,12 @@ if err != nil {
To leverage the DDB checkpoint we'll also need to create a table:
-
+```
+Partition key: namespace
+Sort key: shard_id
+```
+
+
## Options
@@ -113,9 +115,7 @@ Override the Kinesis client if there is any special config needed:
client := kinesis.New(session.New(aws.NewConfig()))
// consumer
-c, err := consumer.New(appName, streamName,
- consumer.WithClient(client),
-)
+c, err := consumer.New(streamName, consumer.WithClient(client))
```
### Metrics
@@ -127,9 +127,7 @@ Add optional counter for exposing counts for checkpoints and records processed:
counter := expvar.NewMap("counters")
// consumer
-c, err := consumer.New(appName, streamName,
- consumer.WithCounter(counter),
-)
+c, err := consumer.New(streamName, consumer.WithCounter(counter))
```
The [expvar package](https://golang.org/pkg/expvar/) will display consumer counts:
@@ -150,9 +148,7 @@ The package defaults to `ioutil.Discard` so swallow all logs. This can be custom
logger := log.New(os.Stdout, "consumer-example: ", log.LstdFlags)
// consumer
-c, err := consumer.New(appName, streamName,
- consumer.WithLogger(logger),
-)
+c, err := consumer.New(streamName, consumer.WithLogger(logger))
```
## Contributing
diff --git a/checkpoint/ddb/ddb.go b/checkpoint/ddb/ddb.go
index 84014a0..1c0daa2 100644
--- a/checkpoint/ddb/ddb.go
+++ b/checkpoint/ddb/ddb.go
@@ -12,7 +12,7 @@ import (
)
// New returns a checkpoint that uses DynamoDB for underlying storage
-func New(tableName, appName, streamName string) (*Checkpoint, error) {
+func New(tableName, appName string) (*Checkpoint, error) {
client := dynamodb.New(session.New(aws.NewConfig()))
_, err := client.DescribeTable(&dynamodb.DescribeTableInput{
@@ -23,24 +23,21 @@ func New(tableName, appName, streamName string) (*Checkpoint, error) {
}
return &Checkpoint{
- TableName: tableName,
- AppName: appName,
- StreamName: streamName,
- client: client,
+ tableName: tableName,
+ appName: appName,
+ client: client,
}, nil
}
// Checkpoint stores and retreives the last evaluated key from a DDB scan
type Checkpoint struct {
- AppName string
- StreamName string
- TableName string
-
- client *dynamodb.DynamoDB
+ tableName string
+ appName string
+ client *dynamodb.DynamoDB
}
type item struct {
- ConsumerGroup string `json:"consumer_group"`
+ Namespace string `json:"namespace"`
ShardID string `json:"shard_id"`
SequenceNumber string `json:"sequence_number"`
}
@@ -48,13 +45,15 @@ type item struct {
// Get determines if a checkpoint for a particular Shard exists.
// Typically used to determine whether we should start processing the shard with
// TRIM_HORIZON or AFTER_SEQUENCE_NUMBER (if checkpoint exists).
-func (c *Checkpoint) Get(shardID string) (string, error) {
+func (c *Checkpoint) Get(streamName, shardID string) (string, error) {
+ namespace := fmt.Sprintf("%s-%s", c.appName, streamName)
+
params := &dynamodb.GetItemInput{
- TableName: aws.String(c.TableName),
+ TableName: aws.String(c.tableName),
ConsistentRead: aws.Bool(true),
Key: map[string]*dynamodb.AttributeValue{
- "consumer_group": &dynamodb.AttributeValue{
- S: aws.String(c.consumerGroupName()),
+ "namespace": &dynamodb.AttributeValue{
+ S: aws.String(namespace),
},
"shard_id": &dynamodb.AttributeValue{
S: aws.String(shardID),
@@ -65,7 +64,7 @@ func (c *Checkpoint) Get(shardID string) (string, error) {
resp, err := c.client.GetItem(params)
if err != nil {
if retriableError(err) {
- return c.Get(shardID)
+ return c.Get(streamName, shardID)
}
return "", err
}
@@ -77,13 +76,15 @@ func (c *Checkpoint) Get(shardID string) (string, error) {
// Set stores a checkpoint for a shard (e.g. sequence number of last record processed by application).
// Upon failover, record processing is resumed from this point.
-func (c *Checkpoint) Set(shardID string, sequenceNumber string) error {
+func (c *Checkpoint) Set(streamName, shardID, sequenceNumber string) error {
if sequenceNumber == "" {
return fmt.Errorf("sequence number should not be empty")
}
+ namespace := fmt.Sprintf("%s-%s", c.appName, streamName)
+
item, err := dynamodbattribute.MarshalMap(item{
- ConsumerGroup: c.consumerGroupName(),
+ Namespace: namespace,
ShardID: shardID,
SequenceNumber: sequenceNumber,
})
@@ -93,22 +94,18 @@ func (c *Checkpoint) Set(shardID string, sequenceNumber string) error {
}
_, err = c.client.PutItem(&dynamodb.PutItemInput{
- TableName: aws.String(c.TableName),
+ TableName: aws.String(c.tableName),
Item: item,
})
if err != nil {
if !retriableError(err) {
return err
}
- return c.Set(shardID, sequenceNumber)
+ return c.Set(streamName, shardID, sequenceNumber)
}
return nil
}
-func (c *Checkpoint) consumerGroupName() string {
- return fmt.Sprintf("%s-%s", c.StreamName, c.AppName)
-}
-
func retriableError(err error) bool {
if awsErr, ok := err.(awserr.Error); ok {
if awsErr.Code() == "ProvisionedThroughputExceededException" {
diff --git a/checkpoint/redis/redis.go b/checkpoint/redis/redis.go
index 6842788..e3f7e51 100644
--- a/checkpoint/redis/redis.go
+++ b/checkpoint/redis/redis.go
@@ -10,7 +10,7 @@ import (
const localhost = "127.0.0.1:6379"
// New returns a checkpoint that uses Redis for underlying storage
-func New(appName, streamName string) (*Checkpoint, error) {
+func New(appName string) (*Checkpoint, error) {
addr := os.Getenv("REDIS_URL")
if addr == "" {
addr = localhost
@@ -25,32 +25,30 @@ func New(appName, streamName string) (*Checkpoint, error) {
}
return &Checkpoint{
- appName: appName,
- streamName: streamName,
- client: client,
+ appName: appName,
+ client: client,
}, nil
}
// Checkpoint stores and retreives the last evaluated key from a DDB scan
type Checkpoint struct {
- appName string
- streamName string
- client *redis.Client
+ appName string
+ client *redis.Client
}
// Get fetches the checkpoint for a particular Shard.
-func (c *Checkpoint) Get(shardID string) (string, error) {
- val, _ := c.client.Get(c.key(shardID)).Result()
+func (c *Checkpoint) Get(streamName, shardID string) (string, error) {
+ val, _ := c.client.Get(c.key(streamName, shardID)).Result()
return val, nil
}
// Set stores a checkpoint for a shard (e.g. sequence number of last record processed by application).
// Upon failover, record processing is resumed from this point.
-func (c *Checkpoint) Set(shardID string, sequenceNumber string) error {
+func (c *Checkpoint) Set(streamName, shardID, sequenceNumber string) error {
if sequenceNumber == "" {
return fmt.Errorf("sequence number should not be empty")
}
- err := c.client.Set(c.key(shardID), sequenceNumber, 0).Err()
+ err := c.client.Set(c.key(streamName, shardID), sequenceNumber, 0).Err()
if err != nil {
return err
}
@@ -58,6 +56,6 @@ func (c *Checkpoint) Set(shardID string, sequenceNumber string) error {
}
// key generates a unique Redis key for storage of Checkpoint.
-func (c *Checkpoint) key(shardID string) string {
- return fmt.Sprintf("%v:checkpoint:%v:%v", c.appName, c.streamName, shardID)
+func (c *Checkpoint) key(streamName, shardID string) string {
+ return fmt.Sprintf("%v:checkpoint:%v:%v", c.appName, streamName, shardID)
}
diff --git a/checkpoint/redis/redis_test.go b/checkpoint/redis/redis_test.go
index 7c49190..96ed617 100644
--- a/checkpoint/redis/redis_test.go
+++ b/checkpoint/redis/redis_test.go
@@ -12,16 +12,15 @@ func Test_CheckpointLifecycle(t *testing.T) {
client := redis.NewClient(&redis.Options{Addr: defaultAddr})
c := &Checkpoint{
- appName: "app",
- streamName: "stream",
- client: client,
+ appName: "app",
+ client: client,
}
// set checkpoint
- c.Set("shard_id", "testSeqNum")
+ c.Set("streamName", "shardID", "testSeqNum")
// get checkpoint
- val, err := c.Get("shard_id")
+ val, err := c.Get("streamName", "shardID")
if err != nil {
t.Fatalf("get checkpoint error: %v", err)
}
@@ -30,21 +29,20 @@ func Test_CheckpointLifecycle(t *testing.T) {
t.Fatalf("checkpoint exists expected %s, got %s", "testSeqNum", val)
}
- client.Del(c.key("shard_id"))
+ client.Del(c.key("streamName", "shardID"))
}
func Test_key(t *testing.T) {
client := redis.NewClient(&redis.Options{Addr: defaultAddr})
c := &Checkpoint{
- appName: "app",
- streamName: "stream",
- client: client,
+ appName: "app",
+ client: client,
}
expected := "app:checkpoint:stream:shard"
- if val := c.key("shard"); val != expected {
+ if val := c.key("stream", "shard"); val != expected {
t.Fatalf("checkpoint exists expected %s, got %s", expected, val)
}
}
diff --git a/consumer.go b/consumer.go
index fcc5f89..2e0e688 100644
--- a/consumer.go
+++ b/consumer.go
@@ -25,14 +25,14 @@ func (n noopCounter) Add(string, int64) {}
// Checkpoint interface used track consumer progress in the stream
type Checkpoint interface {
- Get(shardID string) (string, error)
- Set(shardID string, sequenceNumber string) error
+ Get(streamName, shardID string) (string, error)
+ Set(streamName, shardID, sequenceNumber string) error
}
type noopCheckpoint struct{}
-func (n noopCheckpoint) Set(string, string) error { return nil }
-func (n noopCheckpoint) Get(string) (string, error) { return "", nil }
+func (n noopCheckpoint) Set(string, string, string) error { return nil }
+func (n noopCheckpoint) Get(string, string) (string, error) { return "", nil }
// Option is used to override defaults when creating a new Consumer
type Option func(*Consumer) error
@@ -63,17 +63,12 @@ func WithCounter(counter Counter) Option {
// New creates a kinesis consumer with default settings. Use Option to override
// any of the optional attributes.
-func New(app, stream string, opts ...Option) (*Consumer, error) {
- if app == "" {
- return nil, fmt.Errorf("must provide app name")
- }
-
+func New(stream string, opts ...Option) (*Consumer, error) {
if stream == "" {
return nil, fmt.Errorf("must provide stream name")
}
c := &Consumer{
- appName: app,
streamName: stream,
checkpoint: &noopCheckpoint{},
counter: &noopCounter{},
@@ -97,7 +92,6 @@ func New(app, stream string, opts ...Option) (*Consumer, error) {
// Consumer wraps the interaction with the Kinesis stream
type Consumer struct {
- appName string
streamName string
client *kinesis.Kinesis
logger *log.Logger
@@ -141,7 +135,7 @@ func (c *Consumer) Scan(ctx context.Context, fn func(*kinesis.Record) bool) erro
// for each record and checkpoints after each page is processed.
// Note: returning `false` from the callback func will end the scan.
func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn func(*kinesis.Record) bool) {
- lastSeqNum, err := c.checkpoint.Get(shardID)
+ lastSeqNum, err := c.checkpoint.Get(c.streamName, shardID)
if err != nil {
c.logger.Printf("get checkpoint error: %v", err)
return
@@ -190,7 +184,7 @@ loop:
}
}
- if err := c.checkpoint.Set(shardID, lastSeqNum); err != nil {
+ if err := c.checkpoint.Set(c.streamName, shardID, lastSeqNum); err != nil {
c.logger.Printf("set checkpoint error: %v", err)
}
@@ -215,7 +209,7 @@ loop:
}
c.logger.Println("checkpointing", shardID)
- if err := c.checkpoint.Set(shardID, lastSeqNum); err != nil {
+ if err := c.checkpoint.Set(c.streamName, shardID, lastSeqNum); err != nil {
c.logger.Printf("set checkpoint error: %v", err)
}
}
diff --git a/examples/consumer/main.go b/examples/consumer/main.go
index 46c6282..1c9ac47 100644
--- a/examples/consumer/main.go
+++ b/examples/consumer/main.go
@@ -39,13 +39,13 @@ func main() {
)
// checkpoint
- ck, err := checkpoint.New(*app, *stream)
+ ck, err := checkpoint.New(*app)
if err != nil {
log.Fatalf("checkpoint error: %v", err)
}
// consumer
- c, err := consumer.New(*app, *stream,
+ c, err := consumer.New(*stream,
consumer.WithCheckpoint(ck),
consumer.WithLogger(logger),
consumer.WithCounter(counter),