kinesis-consumer/consumer.go
Harlow Ward 6ee965ec0a
Add DDB as consumer checkpoint option (#37)
* Simplify the checkpoint interface
* Add DDB backend for checkpoint persistence

Implements: https://github.com/harlow/kinesis-consumer/issues/26
2017-11-20 09:37:30 -08:00

234 lines
5.2 KiB
Go

package consumer
import (
"context"
"fmt"
"os"
"sync"
"github.com/apex/log"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/kinesis"
"github.com/harlow/kinesis-consumer/checkpoint"
"github.com/harlow/kinesis-consumer/checkpoint/redis"
)
// Option is a functional option that New applies to the Consumer it is
// building in order to override one of its defaults. If an Option returns a
// non-nil error, New stops and returns that error.
type Option func(*Consumer) error
// WithClient returns an Option that makes the Consumer talk to Kinesis
// through the supplied client instead of the default one.
func WithClient(client *kinesis.Kinesis) Option {
	setClient := func(consumer *Consumer) error {
		consumer.svc = client
		return nil
	}
	return setClient
}
// WithCheckpoint returns an Option that replaces the default checkpoint
// store with the one supplied.
func WithCheckpoint(checkpoint checkpoint.Checkpoint) Option {
	setCheckpoint := func(consumer *Consumer) error {
		consumer.checkpoint = checkpoint
		return nil
	}
	return setCheckpoint
}
// WithLogger returns an Option that replaces the default logger with the
// one supplied.
func WithLogger(logger log.Interface) Option {
	setLogger := func(consumer *Consumer) error {
		consumer.logger = logger
		return nil
	}
	return setLogger
}
// New creates a Kinesis consumer for the given application and stream.
//
// Both appName and streamName are required; an empty value yields an error.
// The options are applied in order and the first one that fails aborts
// construction with its error. Anything still unset afterwards receives a
// default: a logger tagged with the package, app and stream names, a Kinesis
// client built from a new AWS session, and a Redis-backed checkpoint.
func New(appName, streamName string, opts ...Option) (*Consumer, error) {
	if appName == "" {
		return nil, fmt.Errorf("must provide app name to consumer")
	}

	if streamName == "" {
		return nil, fmt.Errorf("must provide stream name to consumer")
	}

	c := &Consumer{
		appName:    appName,
		streamName: streamName,
	}

	// set options
	for _, opt := range opts {
		if err := opt(c); err != nil {
			return nil, err
		}
	}

	// provide default logger
	if c.logger == nil {
		c.logger = log.Log.WithFields(log.Fields{
			"package": "kinesis-consumer",
			"app":     appName,
			"stream":  streamName,
		})
	}

	// provide a default kinesis client
	if c.svc == nil {
		// session.New is deprecated because it has no way to report a
		// failure; NewSession returns the error here instead of deferring it
		// to the first Kinesis request.
		sess, err := session.NewSession(aws.NewConfig())
		if err != nil {
			return nil, err
		}
		c.svc = kinesis.New(sess)
	}

	// provide default Redis checkpoint
	if c.checkpoint == nil {
		ck, err := redis.New(appName, streamName)
		if err != nil {
			return nil, err
		}
		c.checkpoint = ck
	}

	return c, nil
}
// Consumer wraps the interaction with the Kinesis stream. It bundles the
// stream identity with the client, logger and checkpoint store that the
// scanning methods use.
type Consumer struct {
// appName is the application name passed to New.
appName string
// streamName is the name of the Kinesis stream to read from.
streamName string
// svc is the Kinesis client used for every API call.
svc *kinesis.Kinesis
// logger receives progress and error messages.
logger log.Interface
// checkpoint stores and retrieves the last processed sequence number of
// each shard.
checkpoint checkpoint.Checkpoint
}
// Scan discovers every shard of the stream and scans them concurrently,
// calling fn with each Kinesis record. It blocks until all shard scanners
// have returned.
//
// The scanners share a context derived from ctx. As soon as any one of them
// returns (fn returned false, ctx was cancelled, or a shard iterator could
// not be obtained) that context is cancelled and the others stop too. fn is
// invoked from one goroutine per shard, so it must be safe for concurrent
// use.
//
// The shard list is collected page by page: DescribeStream may return only
// part of the shards and signals that through HasMoreShards. If
// DescribeStream fails, the error is logged and the process exits with
// status 1.
func (c *Consumer) Scan(ctx context.Context, fn func(*kinesis.Record) bool) {
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	var (
		shardIDs     []string
		startShardID *string // nil on the first page
	)
	for {
		resp, err := c.svc.DescribeStream(
			&kinesis.DescribeStreamInput{
				StreamName:            aws.String(c.streamName),
				ExclusiveStartShardId: startShardID,
			},
		)
		if err != nil {
			c.logger.WithError(err).Error("DescribeStream")
			os.Exit(1)
		}

		shards := resp.StreamDescription.Shards
		for _, shard := range shards {
			shardIDs = append(shardIDs, *shard.ShardId)
		}

		// An empty page cannot move the cursor forward, so stop rather than
		// risk requesting the same page forever.
		if len(shards) == 0 || !aws.BoolValue(resp.StreamDescription.HasMoreShards) {
			break
		}
		startShardID = shards[len(shards)-1].ShardId
	}

	var wg sync.WaitGroup
	wg.Add(len(shardIDs))

	// scan each of the shards
	for _, shardID := range shardIDs {
		go func(shardID string) {
			defer wg.Done()
			c.ScanShard(ctx, shardID, fn)
			// One scanner finishing stops all the others.
			cancel()
		}(shardID)
	}

	wg.Wait()
}
// ScanShard reads records from a single Kinesis shard and passes each one to
// fn, checkpointing the last seen sequence number after every non-empty page.
//
// Reading starts at the iterator returned by getShardIterator and continues
// until ctx is cancelled, fn returns false, or a shard iterator can no longer
// be obtained. When GetRecords fails, the error is logged and a fresh
// iterator is requested before retrying. A fresh iterator is also requested
// when the response carries no next iterator or the iterator did not advance.
// On exit the most recent sequence number, if any, is checkpointed once more
// so a page that was interrupted part-way is not lost.
func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn func(*kinesis.Record) bool) {
	var (
		logger         = c.logger.WithFields(log.Fields{"shard": shardID})
		sequenceNumber string
	)

	shardIterator, err := c.getShardIterator(shardID)
	if err != nil {
		logger.WithError(err).Error("getShardIterator")
		return
	}

	logger.Info("scanning shard")

loop:
	for {
		select {
		case <-ctx.Done():
			break loop
		default:
			resp, err := c.svc.GetRecords(
				&kinesis.GetRecordsInput{
					ShardIterator: shardIterator,
				},
			)
			if err != nil {
				// Report the failure instead of swallowing it, then retry
				// with a fresh iterator.
				logger.WithError(err).Error("GetRecords")

				shardIterator, err = c.getShardIterator(shardID)
				if err != nil {
					logger.WithError(err).Error("getShardIterator")
					break loop
				}
				continue
			}

			if len(resp.Records) > 0 {
				for _, r := range resp.Records {
					select {
					case <-ctx.Done():
						break loop
					default:
						sequenceNumber = *r.SequenceNumber
						if ok := fn(r); !ok {
							break loop
						}
					}
				}

				logger.WithField("records", len(resp.Records)).Info("checkpoint")
				if err := c.checkpoint.Set(shardID, sequenceNumber); err != nil {
					logger.WithError(err).Error("set checkpoint error")
				}
			}

			// Compare the iterator values, not the pointers: two distinct
			// *string values are never equal as pointers even when they hold
			// the same iterator.
			next := resp.NextShardIterator
			if next == nil || (shardIterator != nil && *shardIterator == *next) {
				shardIterator, err = c.getShardIterator(shardID)
				if err != nil {
					logger.WithError(err).Error("getShardIterator")
					break loop
				}
			} else {
				shardIterator = next
			}
		}
	}

	// Persist progress made in a page that was cut short by cancellation or
	// by fn returning false.
	if sequenceNumber != "" {
		if err := c.checkpoint.Set(shardID, sequenceNumber); err != nil {
			logger.WithError(err).Error("set checkpoint error")
		}
	}
}
// getShardIterator requests a Kinesis shard iterator for shardID. When the
// checkpoint holds a sequence number for the shard, the iterator is placed
// just after it (AFTER_SEQUENCE_NUMBER); otherwise it uses TRIM_HORIZON.
// Errors from the checkpoint lookup or from Kinesis are returned to the
// caller; the Kinesis error is also logged here.
func (c *Consumer) getShardIterator(shardID string) (*string, error) {
	lastSeq, err := c.checkpoint.Get(shardID)
	if err != nil {
		return nil, err
	}

	input := &kinesis.GetShardIteratorInput{
		ShardId:    aws.String(shardID),
		StreamName: aws.String(c.streamName),
	}
	if lastSeq == "" {
		// No checkpoint stored for this shard.
		input.ShardIteratorType = aws.String("TRIM_HORIZON")
	} else {
		input.ShardIteratorType = aws.String("AFTER_SEQUENCE_NUMBER")
		input.StartingSequenceNumber = aws.String(lastSeq)
	}

	out, err := c.svc.GetShardIterator(input)
	if err != nil {
		c.logger.WithError(err).Error("GetShardIterator")
		return nil, err
	}

	return out.ShardIterator, nil
}