Simplify checkpoint interface; reduce input vars

This commit is contained in:
Harlow Ward 2017-11-22 20:01:31 -08:00
parent 3f081bd05a
commit 6401371727
6 changed files with 64 additions and 81 deletions

View file

@ -28,19 +28,16 @@ import(
) )
func main() { func main() {
var ( var stream = flag.String("stream", "", "Stream name")
app = flag.String("app", "", "App name")
stream = flag.String("stream", "", "Stream name")
)
flag.Parse() flag.Parse()
// consumer // consumer
c, err := consumer.New(*app, *stream) c, err := consumer.New(*stream)
if err != nil { if err != nil {
log.Fatalf("consumer error: %v", err) log.Fatalf("consumer error: %v", err)
} }
// scan stream // start
err = c.Scan(context.TODO(), func(r *consumer.Record) bool { err = c.Scan(context.TODO(), func(r *consumer.Record) bool {
fmt.Println(string(r.Data)) fmt.Println(string(r.Data))
return true // continue scanning return true // continue scanning
@ -76,7 +73,7 @@ The Redis checkpoint requries App Name, and Stream Name:
import checkpoint "github.com/harlow/kinesis-consumer/checkpoint/redis" import checkpoint "github.com/harlow/kinesis-consumer/checkpoint/redis"
// redis checkpoint // redis checkpoint
ck, err := checkpoint.New(appName, streamName) ck, err := checkpoint.New(appName)
if err != nil { if err != nil {
log.Fatalf("new checkpoint error: %v", err) log.Fatalf("new checkpoint error: %v", err)
} }
@ -90,7 +87,7 @@ The DynamoDB checkpoint requires Table Name, App Name, and Stream Name:
import checkpoint "github.com/harlow/kinesis-consumer/checkpoint/ddb" import checkpoint "github.com/harlow/kinesis-consumer/checkpoint/ddb"
// ddb checkpoint // ddb checkpoint
ck, err := checkpoint.New(tableName, appName, streamName) ck, err := checkpoint.New(tableName, appName)
if err != nil { if err != nil {
log.Fatalf("new checkpoint error: %v", err) log.Fatalf("new checkpoint error: %v", err)
} }
@ -98,7 +95,12 @@ if err != nil {
To leverage the DDB checkpoint we'll also need to create a table: To leverage the DDB checkpoint we'll also need to create a table:
<img width="659" alt="screen shot 2017-11-20 at 9 16 14 am" src="https://user-images.githubusercontent.com/739782/33033316-db85f848-cdd8-11e7-941a-0a87d8ace479.png"> ```
Partition key: namespace
Sort key: shard_id
```
<img width="727" alt="screen shot 2017-11-22 at 7 59 36 pm" src="https://user-images.githubusercontent.com/739782/33158557-b90e4228-cfbf-11e7-9a99-73b56a446f5f.png">
## Options ## Options
@ -113,9 +115,7 @@ Override the Kinesis client if there is any special config needed:
client := kinesis.New(session.New(aws.NewConfig())) client := kinesis.New(session.New(aws.NewConfig()))
// consumer // consumer
c, err := consumer.New(appName, streamName, c, err := consumer.New(streamName, consumer.WithClient(client))
consumer.WithClient(client),
)
``` ```
### Metrics ### Metrics
@ -127,9 +127,7 @@ Add optional counter for exposing counts for checkpoints and records processed:
counter := expvar.NewMap("counters") counter := expvar.NewMap("counters")
// consumer // consumer
c, err := consumer.New(appName, streamName, c, err := consumer.New(streamName, consumer.WithCounter(counter))
consumer.WithCounter(counter),
)
``` ```
The [expvar package](https://golang.org/pkg/expvar/) will display consumer counts: The [expvar package](https://golang.org/pkg/expvar/) will display consumer counts:
@ -150,9 +148,7 @@ The package defaults to `ioutil.Discard` so swallow all logs. This can be custom
logger := log.New(os.Stdout, "consumer-example: ", log.LstdFlags) logger := log.New(os.Stdout, "consumer-example: ", log.LstdFlags)
// consumer // consumer
c, err := consumer.New(appName, streamName, c, err := consumer.New(streamName, consumer.WithLogger(logger))
consumer.WithLogger(logger),
)
``` ```
## Contributing ## Contributing

View file

@ -12,7 +12,7 @@ import (
) )
// New returns a checkpoint that uses DynamoDB for underlying storage // New returns a checkpoint that uses DynamoDB for underlying storage
func New(tableName, appName, streamName string) (*Checkpoint, error) { func New(tableName, appName string) (*Checkpoint, error) {
client := dynamodb.New(session.New(aws.NewConfig())) client := dynamodb.New(session.New(aws.NewConfig()))
_, err := client.DescribeTable(&dynamodb.DescribeTableInput{ _, err := client.DescribeTable(&dynamodb.DescribeTableInput{
@ -23,24 +23,21 @@ func New(tableName, appName, streamName string) (*Checkpoint, error) {
} }
return &Checkpoint{ return &Checkpoint{
TableName: tableName, tableName: tableName,
AppName: appName, appName: appName,
StreamName: streamName, client: client,
client: client,
}, nil }, nil
} }
// Checkpoint stores and retreives the last evaluated key from a DDB scan // Checkpoint stores and retreives the last evaluated key from a DDB scan
type Checkpoint struct { type Checkpoint struct {
AppName string tableName string
StreamName string appName string
TableName string client *dynamodb.DynamoDB
client *dynamodb.DynamoDB
} }
type item struct { type item struct {
ConsumerGroup string `json:"consumer_group"` Namespace string `json:"namespace"`
ShardID string `json:"shard_id"` ShardID string `json:"shard_id"`
SequenceNumber string `json:"sequence_number"` SequenceNumber string `json:"sequence_number"`
} }
@ -48,13 +45,15 @@ type item struct {
// Get determines if a checkpoint for a particular Shard exists. // Get determines if a checkpoint for a particular Shard exists.
// Typically used to determine whether we should start processing the shard with // Typically used to determine whether we should start processing the shard with
// TRIM_HORIZON or AFTER_SEQUENCE_NUMBER (if checkpoint exists). // TRIM_HORIZON or AFTER_SEQUENCE_NUMBER (if checkpoint exists).
func (c *Checkpoint) Get(shardID string) (string, error) { func (c *Checkpoint) Get(streamName, shardID string) (string, error) {
namespace := fmt.Sprintf("%s-%s", c.appName, streamName)
params := &dynamodb.GetItemInput{ params := &dynamodb.GetItemInput{
TableName: aws.String(c.TableName), TableName: aws.String(c.tableName),
ConsistentRead: aws.Bool(true), ConsistentRead: aws.Bool(true),
Key: map[string]*dynamodb.AttributeValue{ Key: map[string]*dynamodb.AttributeValue{
"consumer_group": &dynamodb.AttributeValue{ "namespace": &dynamodb.AttributeValue{
S: aws.String(c.consumerGroupName()), S: aws.String(namespace),
}, },
"shard_id": &dynamodb.AttributeValue{ "shard_id": &dynamodb.AttributeValue{
S: aws.String(shardID), S: aws.String(shardID),
@ -65,7 +64,7 @@ func (c *Checkpoint) Get(shardID string) (string, error) {
resp, err := c.client.GetItem(params) resp, err := c.client.GetItem(params)
if err != nil { if err != nil {
if retriableError(err) { if retriableError(err) {
return c.Get(shardID) return c.Get(streamName, shardID)
} }
return "", err return "", err
} }
@ -77,13 +76,15 @@ func (c *Checkpoint) Get(shardID string) (string, error) {
// Set stores a checkpoint for a shard (e.g. sequence number of last record processed by application). // Set stores a checkpoint for a shard (e.g. sequence number of last record processed by application).
// Upon failover, record processing is resumed from this point. // Upon failover, record processing is resumed from this point.
func (c *Checkpoint) Set(shardID string, sequenceNumber string) error { func (c *Checkpoint) Set(streamName, shardID, sequenceNumber string) error {
if sequenceNumber == "" { if sequenceNumber == "" {
return fmt.Errorf("sequence number should not be empty") return fmt.Errorf("sequence number should not be empty")
} }
namespace := fmt.Sprintf("%s-%s", c.appName, streamName)
item, err := dynamodbattribute.MarshalMap(item{ item, err := dynamodbattribute.MarshalMap(item{
ConsumerGroup: c.consumerGroupName(), Namespace: namespace,
ShardID: shardID, ShardID: shardID,
SequenceNumber: sequenceNumber, SequenceNumber: sequenceNumber,
}) })
@ -93,22 +94,18 @@ func (c *Checkpoint) Set(shardID string, sequenceNumber string) error {
} }
_, err = c.client.PutItem(&dynamodb.PutItemInput{ _, err = c.client.PutItem(&dynamodb.PutItemInput{
TableName: aws.String(c.TableName), TableName: aws.String(c.tableName),
Item: item, Item: item,
}) })
if err != nil { if err != nil {
if !retriableError(err) { if !retriableError(err) {
return err return err
} }
return c.Set(shardID, sequenceNumber) return c.Set(streamName, shardID, sequenceNumber)
} }
return nil return nil
} }
func (c *Checkpoint) consumerGroupName() string {
return fmt.Sprintf("%s-%s", c.StreamName, c.AppName)
}
func retriableError(err error) bool { func retriableError(err error) bool {
if awsErr, ok := err.(awserr.Error); ok { if awsErr, ok := err.(awserr.Error); ok {
if awsErr.Code() == "ProvisionedThroughputExceededException" { if awsErr.Code() == "ProvisionedThroughputExceededException" {

View file

@ -10,7 +10,7 @@ import (
const localhost = "127.0.0.1:6379" const localhost = "127.0.0.1:6379"
// New returns a checkpoint that uses Redis for underlying storage // New returns a checkpoint that uses Redis for underlying storage
func New(appName, streamName string) (*Checkpoint, error) { func New(appName string) (*Checkpoint, error) {
addr := os.Getenv("REDIS_URL") addr := os.Getenv("REDIS_URL")
if addr == "" { if addr == "" {
addr = localhost addr = localhost
@ -25,32 +25,30 @@ func New(appName, streamName string) (*Checkpoint, error) {
} }
return &Checkpoint{ return &Checkpoint{
appName: appName, appName: appName,
streamName: streamName, client: client,
client: client,
}, nil }, nil
} }
// Checkpoint stores and retreives the last evaluated key from a DDB scan // Checkpoint stores and retreives the last evaluated key from a DDB scan
type Checkpoint struct { type Checkpoint struct {
appName string appName string
streamName string client *redis.Client
client *redis.Client
} }
// Get fetches the checkpoint for a particular Shard. // Get fetches the checkpoint for a particular Shard.
func (c *Checkpoint) Get(shardID string) (string, error) { func (c *Checkpoint) Get(streamName, shardID string) (string, error) {
val, _ := c.client.Get(c.key(shardID)).Result() val, _ := c.client.Get(c.key(streamName, shardID)).Result()
return val, nil return val, nil
} }
// Set stores a checkpoint for a shard (e.g. sequence number of last record processed by application). // Set stores a checkpoint for a shard (e.g. sequence number of last record processed by application).
// Upon failover, record processing is resumed from this point. // Upon failover, record processing is resumed from this point.
func (c *Checkpoint) Set(shardID string, sequenceNumber string) error { func (c *Checkpoint) Set(streamName, shardID, sequenceNumber string) error {
if sequenceNumber == "" { if sequenceNumber == "" {
return fmt.Errorf("sequence number should not be empty") return fmt.Errorf("sequence number should not be empty")
} }
err := c.client.Set(c.key(shardID), sequenceNumber, 0).Err() err := c.client.Set(c.key(streamName, shardID), sequenceNumber, 0).Err()
if err != nil { if err != nil {
return err return err
} }
@ -58,6 +56,6 @@ func (c *Checkpoint) Set(shardID string, sequenceNumber string) error {
} }
// key generates a unique Redis key for storage of Checkpoint. // key generates a unique Redis key for storage of Checkpoint.
func (c *Checkpoint) key(shardID string) string { func (c *Checkpoint) key(streamName, shardID string) string {
return fmt.Sprintf("%v:checkpoint:%v:%v", c.appName, c.streamName, shardID) return fmt.Sprintf("%v:checkpoint:%v:%v", c.appName, streamName, shardID)
} }

View file

@ -12,16 +12,15 @@ func Test_CheckpointLifecycle(t *testing.T) {
client := redis.NewClient(&redis.Options{Addr: defaultAddr}) client := redis.NewClient(&redis.Options{Addr: defaultAddr})
c := &Checkpoint{ c := &Checkpoint{
appName: "app", appName: "app",
streamName: "stream", client: client,
client: client,
} }
// set checkpoint // set checkpoint
c.Set("shard_id", "testSeqNum") c.Set("streamName", "shardID", "testSeqNum")
// get checkpoint // get checkpoint
val, err := c.Get("shard_id") val, err := c.Get("streamName", "shardID")
if err != nil { if err != nil {
t.Fatalf("get checkpoint error: %v", err) t.Fatalf("get checkpoint error: %v", err)
} }
@ -30,21 +29,20 @@ func Test_CheckpointLifecycle(t *testing.T) {
t.Fatalf("checkpoint exists expected %s, got %s", "testSeqNum", val) t.Fatalf("checkpoint exists expected %s, got %s", "testSeqNum", val)
} }
client.Del(c.key("shard_id")) client.Del(c.key("streamName", "shardID"))
} }
func Test_key(t *testing.T) { func Test_key(t *testing.T) {
client := redis.NewClient(&redis.Options{Addr: defaultAddr}) client := redis.NewClient(&redis.Options{Addr: defaultAddr})
c := &Checkpoint{ c := &Checkpoint{
appName: "app", appName: "app",
streamName: "stream", client: client,
client: client,
} }
expected := "app:checkpoint:stream:shard" expected := "app:checkpoint:stream:shard"
if val := c.key("shard"); val != expected { if val := c.key("stream", "shard"); val != expected {
t.Fatalf("checkpoint exists expected %s, got %s", expected, val) t.Fatalf("checkpoint exists expected %s, got %s", expected, val)
} }
} }

View file

@ -25,14 +25,14 @@ func (n noopCounter) Add(string, int64) {}
// Checkpoint interface used track consumer progress in the stream // Checkpoint interface used track consumer progress in the stream
type Checkpoint interface { type Checkpoint interface {
Get(shardID string) (string, error) Get(streamName, shardID string) (string, error)
Set(shardID string, sequenceNumber string) error Set(streamName, shardID, sequenceNumber string) error
} }
type noopCheckpoint struct{} type noopCheckpoint struct{}
func (n noopCheckpoint) Set(string, string) error { return nil } func (n noopCheckpoint) Set(string, string, string) error { return nil }
func (n noopCheckpoint) Get(string) (string, error) { return "", nil } func (n noopCheckpoint) Get(string, string) (string, error) { return "", nil }
// Option is used to override defaults when creating a new Consumer // Option is used to override defaults when creating a new Consumer
type Option func(*Consumer) error type Option func(*Consumer) error
@ -63,17 +63,12 @@ func WithCounter(counter Counter) Option {
// New creates a kinesis consumer with default settings. Use Option to override // New creates a kinesis consumer with default settings. Use Option to override
// any of the optional attributes. // any of the optional attributes.
func New(app, stream string, opts ...Option) (*Consumer, error) { func New(stream string, opts ...Option) (*Consumer, error) {
if app == "" {
return nil, fmt.Errorf("must provide app name")
}
if stream == "" { if stream == "" {
return nil, fmt.Errorf("must provide stream name") return nil, fmt.Errorf("must provide stream name")
} }
c := &Consumer{ c := &Consumer{
appName: app,
streamName: stream, streamName: stream,
checkpoint: &noopCheckpoint{}, checkpoint: &noopCheckpoint{},
counter: &noopCounter{}, counter: &noopCounter{},
@ -97,7 +92,6 @@ func New(app, stream string, opts ...Option) (*Consumer, error) {
// Consumer wraps the interaction with the Kinesis stream // Consumer wraps the interaction with the Kinesis stream
type Consumer struct { type Consumer struct {
appName string
streamName string streamName string
client *kinesis.Kinesis client *kinesis.Kinesis
logger *log.Logger logger *log.Logger
@ -141,7 +135,7 @@ func (c *Consumer) Scan(ctx context.Context, fn func(*kinesis.Record) bool) erro
// for each record and checkpoints after each page is processed. // for each record and checkpoints after each page is processed.
// Note: returning `false` from the callback func will end the scan. // Note: returning `false` from the callback func will end the scan.
func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn func(*kinesis.Record) bool) { func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn func(*kinesis.Record) bool) {
lastSeqNum, err := c.checkpoint.Get(shardID) lastSeqNum, err := c.checkpoint.Get(c.streamName, shardID)
if err != nil { if err != nil {
c.logger.Printf("get checkpoint error: %v", err) c.logger.Printf("get checkpoint error: %v", err)
return return
@ -190,7 +184,7 @@ loop:
} }
} }
if err := c.checkpoint.Set(shardID, lastSeqNum); err != nil { if err := c.checkpoint.Set(c.streamName, shardID, lastSeqNum); err != nil {
c.logger.Printf("set checkpoint error: %v", err) c.logger.Printf("set checkpoint error: %v", err)
} }
@ -215,7 +209,7 @@ loop:
} }
c.logger.Println("checkpointing", shardID) c.logger.Println("checkpointing", shardID)
if err := c.checkpoint.Set(shardID, lastSeqNum); err != nil { if err := c.checkpoint.Set(c.streamName, shardID, lastSeqNum); err != nil {
c.logger.Printf("set checkpoint error: %v", err) c.logger.Printf("set checkpoint error: %v", err)
} }
} }

View file

@ -39,13 +39,13 @@ func main() {
) )
// checkpoint // checkpoint
ck, err := checkpoint.New(*app, *stream) ck, err := checkpoint.New(*app)
if err != nil { if err != nil {
log.Fatalf("checkpoint error: %v", err) log.Fatalf("checkpoint error: %v", err)
} }
// consumer // consumer
c, err := consumer.New(*app, *stream, c, err := consumer.New(*stream,
consumer.WithCheckpoint(ck), consumer.WithCheckpoint(ck),
consumer.WithLogger(logger), consumer.WithLogger(logger),
consumer.WithCounter(counter), consumer.WithCounter(counter),