Simplify checkpoint interface; reduce input vars

This commit is contained in:
Harlow Ward 2017-11-22 20:01:31 -08:00
parent 3f081bd05a
commit 6401371727
6 changed files with 64 additions and 81 deletions

View file

@ -28,19 +28,16 @@ import(
)
func main() {
var (
app = flag.String("app", "", "App name")
stream = flag.String("stream", "", "Stream name")
)
var stream = flag.String("stream", "", "Stream name")
flag.Parse()
// consumer
c, err := consumer.New(*app, *stream)
c, err := consumer.New(*stream)
if err != nil {
log.Fatalf("consumer error: %v", err)
}
// scan stream
// start
err = c.Scan(context.TODO(), func(r *consumer.Record) bool {
fmt.Println(string(r.Data))
return true // continue scanning
@ -76,7 +73,7 @@ The Redis checkpoint requires App Name, and Stream Name:
import checkpoint "github.com/harlow/kinesis-consumer/checkpoint/redis"
// redis checkpoint
ck, err := checkpoint.New(appName, streamName)
ck, err := checkpoint.New(appName)
if err != nil {
log.Fatalf("new checkpoint error: %v", err)
}
@ -90,7 +87,7 @@ The DynamoDB checkpoint requires Table Name, App Name, and Stream Name:
import checkpoint "github.com/harlow/kinesis-consumer/checkpoint/ddb"
// ddb checkpoint
ck, err := checkpoint.New(tableName, appName, streamName)
ck, err := checkpoint.New(tableName, appName)
if err != nil {
log.Fatalf("new checkpoint error: %v", err)
}
@ -98,7 +95,12 @@ if err != nil {
To leverage the DDB checkpoint we'll also need to create a table:
<img width="659" alt="screen shot 2017-11-20 at 9 16 14 am" src="https://user-images.githubusercontent.com/739782/33033316-db85f848-cdd8-11e7-941a-0a87d8ace479.png">
```
Partition key: namespace
Sort key: shard_id
```
<img width="727" alt="screen shot 2017-11-22 at 7 59 36 pm" src="https://user-images.githubusercontent.com/739782/33158557-b90e4228-cfbf-11e7-9a99-73b56a446f5f.png">
## Options
@ -113,9 +115,7 @@ Override the Kinesis client if there is any special config needed:
client := kinesis.New(session.New(aws.NewConfig()))
// consumer
c, err := consumer.New(appName, streamName,
consumer.WithClient(client),
)
c, err := consumer.New(streamName, consumer.WithClient(client))
```
### Metrics
@ -127,9 +127,7 @@ Add optional counter for exposing counts for checkpoints and records processed:
counter := expvar.NewMap("counters")
// consumer
c, err := consumer.New(appName, streamName,
consumer.WithCounter(counter),
)
c, err := consumer.New(streamName, consumer.WithCounter(counter))
```
The [expvar package](https://golang.org/pkg/expvar/) will display consumer counts:
@ -150,9 +148,7 @@ The package defaults to `ioutil.Discard` so it swallows all logs. This can be custom
logger := log.New(os.Stdout, "consumer-example: ", log.LstdFlags)
// consumer
c, err := consumer.New(appName, streamName,
consumer.WithLogger(logger),
)
c, err := consumer.New(streamName, consumer.WithLogger(logger))
```
## Contributing

View file

@ -12,7 +12,7 @@ import (
)
// New returns a checkpoint that uses DynamoDB for underlying storage
func New(tableName, appName, streamName string) (*Checkpoint, error) {
func New(tableName, appName string) (*Checkpoint, error) {
client := dynamodb.New(session.New(aws.NewConfig()))
_, err := client.DescribeTable(&dynamodb.DescribeTableInput{
@ -23,24 +23,21 @@ func New(tableName, appName, streamName string) (*Checkpoint, error) {
}
return &Checkpoint{
TableName: tableName,
AppName: appName,
StreamName: streamName,
client: client,
tableName: tableName,
appName: appName,
client: client,
}, nil
}
// Checkpoint stores and retrieves the last evaluated key from a DDB scan
type Checkpoint struct {
AppName string
StreamName string
TableName string
client *dynamodb.DynamoDB
tableName string
appName string
client *dynamodb.DynamoDB
}
type item struct {
ConsumerGroup string `json:"consumer_group"`
Namespace string `json:"namespace"`
ShardID string `json:"shard_id"`
SequenceNumber string `json:"sequence_number"`
}
@ -48,13 +45,15 @@ type item struct {
// Get determines if a checkpoint for a particular Shard exists.
// Typically used to determine whether we should start processing the shard with
// TRIM_HORIZON or AFTER_SEQUENCE_NUMBER (if checkpoint exists).
func (c *Checkpoint) Get(shardID string) (string, error) {
func (c *Checkpoint) Get(streamName, shardID string) (string, error) {
namespace := fmt.Sprintf("%s-%s", c.appName, streamName)
params := &dynamodb.GetItemInput{
TableName: aws.String(c.TableName),
TableName: aws.String(c.tableName),
ConsistentRead: aws.Bool(true),
Key: map[string]*dynamodb.AttributeValue{
"consumer_group": &dynamodb.AttributeValue{
S: aws.String(c.consumerGroupName()),
"namespace": &dynamodb.AttributeValue{
S: aws.String(namespace),
},
"shard_id": &dynamodb.AttributeValue{
S: aws.String(shardID),
@ -65,7 +64,7 @@ func (c *Checkpoint) Get(shardID string) (string, error) {
resp, err := c.client.GetItem(params)
if err != nil {
if retriableError(err) {
return c.Get(shardID)
return c.Get(streamName, shardID)
}
return "", err
}
@ -77,13 +76,15 @@ func (c *Checkpoint) Get(shardID string) (string, error) {
// Set stores a checkpoint for a shard (e.g. sequence number of last record processed by application).
// Upon failover, record processing is resumed from this point.
func (c *Checkpoint) Set(shardID string, sequenceNumber string) error {
func (c *Checkpoint) Set(streamName, shardID, sequenceNumber string) error {
if sequenceNumber == "" {
return fmt.Errorf("sequence number should not be empty")
}
namespace := fmt.Sprintf("%s-%s", c.appName, streamName)
item, err := dynamodbattribute.MarshalMap(item{
ConsumerGroup: c.consumerGroupName(),
Namespace: namespace,
ShardID: shardID,
SequenceNumber: sequenceNumber,
})
@ -93,22 +94,18 @@ func (c *Checkpoint) Set(shardID string, sequenceNumber string) error {
}
_, err = c.client.PutItem(&dynamodb.PutItemInput{
TableName: aws.String(c.TableName),
TableName: aws.String(c.tableName),
Item: item,
})
if err != nil {
if !retriableError(err) {
return err
}
return c.Set(shardID, sequenceNumber)
return c.Set(streamName, shardID, sequenceNumber)
}
return nil
}
func (c *Checkpoint) consumerGroupName() string {
return fmt.Sprintf("%s-%s", c.StreamName, c.AppName)
}
func retriableError(err error) bool {
if awsErr, ok := err.(awserr.Error); ok {
if awsErr.Code() == "ProvisionedThroughputExceededException" {

View file

@ -10,7 +10,7 @@ import (
const localhost = "127.0.0.1:6379"
// New returns a checkpoint that uses Redis for underlying storage
func New(appName, streamName string) (*Checkpoint, error) {
func New(appName string) (*Checkpoint, error) {
addr := os.Getenv("REDIS_URL")
if addr == "" {
addr = localhost
@ -25,32 +25,30 @@ func New(appName, streamName string) (*Checkpoint, error) {
}
return &Checkpoint{
appName: appName,
streamName: streamName,
client: client,
appName: appName,
client: client,
}, nil
}
// Checkpoint stores and retrieves per-shard checkpoints (sequence numbers) using Redis
type Checkpoint struct {
appName string
streamName string
client *redis.Client
appName string
client *redis.Client
}
// Get fetches the checkpoint for a particular Shard.
func (c *Checkpoint) Get(shardID string) (string, error) {
val, _ := c.client.Get(c.key(shardID)).Result()
func (c *Checkpoint) Get(streamName, shardID string) (string, error) {
val, _ := c.client.Get(c.key(streamName, shardID)).Result()
return val, nil
}
// Set stores a checkpoint for a shard (e.g. sequence number of last record processed by application).
// Upon failover, record processing is resumed from this point.
func (c *Checkpoint) Set(shardID string, sequenceNumber string) error {
func (c *Checkpoint) Set(streamName, shardID, sequenceNumber string) error {
if sequenceNumber == "" {
return fmt.Errorf("sequence number should not be empty")
}
err := c.client.Set(c.key(shardID), sequenceNumber, 0).Err()
err := c.client.Set(c.key(streamName, shardID), sequenceNumber, 0).Err()
if err != nil {
return err
}
@ -58,6 +56,6 @@ func (c *Checkpoint) Set(shardID string, sequenceNumber string) error {
}
// key generates a unique Redis key for storage of Checkpoint.
func (c *Checkpoint) key(shardID string) string {
return fmt.Sprintf("%v:checkpoint:%v:%v", c.appName, c.streamName, shardID)
func (c *Checkpoint) key(streamName, shardID string) string {
return fmt.Sprintf("%v:checkpoint:%v:%v", c.appName, streamName, shardID)
}

View file

@ -12,16 +12,15 @@ func Test_CheckpointLifecycle(t *testing.T) {
client := redis.NewClient(&redis.Options{Addr: defaultAddr})
c := &Checkpoint{
appName: "app",
streamName: "stream",
client: client,
appName: "app",
client: client,
}
// set checkpoint
c.Set("shard_id", "testSeqNum")
c.Set("streamName", "shardID", "testSeqNum")
// get checkpoint
val, err := c.Get("shard_id")
val, err := c.Get("streamName", "shardID")
if err != nil {
t.Fatalf("get checkpoint error: %v", err)
}
@ -30,21 +29,20 @@ func Test_CheckpointLifecycle(t *testing.T) {
t.Fatalf("checkpoint exists expected %s, got %s", "testSeqNum", val)
}
client.Del(c.key("shard_id"))
client.Del(c.key("streamName", "shardID"))
}
func Test_key(t *testing.T) {
client := redis.NewClient(&redis.Options{Addr: defaultAddr})
c := &Checkpoint{
appName: "app",
streamName: "stream",
client: client,
appName: "app",
client: client,
}
expected := "app:checkpoint:stream:shard"
if val := c.key("shard"); val != expected {
if val := c.key("stream", "shard"); val != expected {
t.Fatalf("checkpoint exists expected %s, got %s", expected, val)
}
}

View file

@ -25,14 +25,14 @@ func (n noopCounter) Add(string, int64) {}
// Checkpoint interface used to track consumer progress in the stream
type Checkpoint interface {
Get(shardID string) (string, error)
Set(shardID string, sequenceNumber string) error
Get(streamName, shardID string) (string, error)
Set(streamName, shardID, sequenceNumber string) error
}
type noopCheckpoint struct{}
func (n noopCheckpoint) Set(string, string) error { return nil }
func (n noopCheckpoint) Get(string) (string, error) { return "", nil }
func (n noopCheckpoint) Set(string, string, string) error { return nil }
func (n noopCheckpoint) Get(string, string) (string, error) { return "", nil }
// Option is used to override defaults when creating a new Consumer
type Option func(*Consumer) error
@ -63,17 +63,12 @@ func WithCounter(counter Counter) Option {
// New creates a kinesis consumer with default settings. Use Option to override
// any of the optional attributes.
func New(app, stream string, opts ...Option) (*Consumer, error) {
if app == "" {
return nil, fmt.Errorf("must provide app name")
}
func New(stream string, opts ...Option) (*Consumer, error) {
if stream == "" {
return nil, fmt.Errorf("must provide stream name")
}
c := &Consumer{
appName: app,
streamName: stream,
checkpoint: &noopCheckpoint{},
counter: &noopCounter{},
@ -97,7 +92,6 @@ func New(app, stream string, opts ...Option) (*Consumer, error) {
// Consumer wraps the interaction with the Kinesis stream
type Consumer struct {
appName string
streamName string
client *kinesis.Kinesis
logger *log.Logger
@ -141,7 +135,7 @@ func (c *Consumer) Scan(ctx context.Context, fn func(*kinesis.Record) bool) erro
// for each record and checkpoints after each page is processed.
// Note: returning `false` from the callback func will end the scan.
func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn func(*kinesis.Record) bool) {
lastSeqNum, err := c.checkpoint.Get(shardID)
lastSeqNum, err := c.checkpoint.Get(c.streamName, shardID)
if err != nil {
c.logger.Printf("get checkpoint error: %v", err)
return
@ -190,7 +184,7 @@ loop:
}
}
if err := c.checkpoint.Set(shardID, lastSeqNum); err != nil {
if err := c.checkpoint.Set(c.streamName, shardID, lastSeqNum); err != nil {
c.logger.Printf("set checkpoint error: %v", err)
}
@ -215,7 +209,7 @@ loop:
}
c.logger.Println("checkpointing", shardID)
if err := c.checkpoint.Set(shardID, lastSeqNum); err != nil {
if err := c.checkpoint.Set(c.streamName, shardID, lastSeqNum); err != nil {
c.logger.Printf("set checkpoint error: %v", err)
}
}

View file

@ -39,13 +39,13 @@ func main() {
)
// checkpoint
ck, err := checkpoint.New(*app, *stream)
ck, err := checkpoint.New(*app)
if err != nil {
log.Fatalf("checkpoint error: %v", err)
}
// consumer
c, err := consumer.New(*app, *stream,
c, err := consumer.New(*stream,
consumer.WithCheckpoint(ck),
consumer.WithLogger(logger),
consumer.WithCounter(counter),