kinesis-consumer/consumer.go
2018-12-28 21:19:39 -08:00

325 lines
8.8 KiB
Go

package consumer
import (
"context"
"fmt"
"io/ioutil"
"log"
"sync"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/kinesis"
"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
opentracing "github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/ext"
)
// Record is an alias of record returned from kinesis library
type Record = kinesis.Record
// Option is used to override defaults when creating a new Consumer
type Option func(*Consumer)
// WithCheckpoint overrides the default checkpoint
func WithCheckpoint(checkpoint Checkpoint) Option {
return func(c *Consumer) {
c.checkpoint = checkpoint
}
}
// WithLogger overrides the default logger
func WithLogger(logger Logger) Option {
return func(c *Consumer) {
c.logger = logger
}
}
// WithCounter overrides the default counter
func WithCounter(counter Counter) Option {
return func(c *Consumer) {
c.counter = counter
}
}
// WithClient overrides the default client
func WithClient(client kinesisiface.KinesisAPI) Option {
return func(c *Consumer) {
c.client = client
}
}
// ScanStatus signals the consumer if we should continue scanning for next record
// and whether to checkpoint.
type ScanStatus struct {
Error error
StopScan bool
SkipCheckpoint bool
}
// New creates a kinesis consumer with default settings. Use Option to override
// any of the optional attributes.
func New(streamName string, opts ...Option) (*Consumer, error) {
if streamName == "" {
return nil, fmt.Errorf("must provide stream name")
}
// new consumer with no-op checkpoint, counter, and logger
c := &Consumer{
streamName: streamName,
checkpoint: &noopCheckpoint{},
counter: &noopCounter{},
logger: &noopLogger{
logger: log.New(ioutil.Discard, "", log.LstdFlags),
},
}
// override defaults
for _, opt := range opts {
opt(c)
}
// default client if none provided
if c.client == nil {
newSession, err := session.NewSession(aws.NewConfig())
if err != nil {
return nil, err
}
c.client = kinesis.New(newSession)
}
return c, nil
}
// Consumer wraps the interaction with the Kinesis stream
type Consumer struct {
streamName string
client kinesisiface.KinesisAPI
logger Logger
checkpoint Checkpoint
counter Counter
}
// Scan scans each of the shards of the stream, calls the callback
// func with each of the kinesis records.
func (c *Consumer) Scan(ctx context.Context, fn func(context.Context, *Record) ScanStatus) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
span, ctx := opentracing.StartSpanFromContext(ctx, "consumer.scan")
defer span.Finish()
// get shard ids
shardIDs, err := c.getShardIDs(ctx, c.streamName)
span.SetTag("stream.name", c.streamName)
span.SetTag("shard.count", len(shardIDs))
if err != nil {
span.LogKV("get shardID error", err.Error(), "stream.name", c.streamName)
ext.Error.Set(span, true)
return fmt.Errorf("get shards error: %s", err.Error())
}
if len(shardIDs) == 0 {
span.LogKV("stream.name", c.streamName, "shards.count", len(shardIDs))
ext.Error.Set(span, true)
return fmt.Errorf("no shards available")
}
var (
wg sync.WaitGroup
errc = make(chan error, 1)
)
wg.Add(len(shardIDs))
// process each shard in a separate goroutine
for _, shardID := range shardIDs {
go func(shardID string) {
defer wg.Done()
if err := c.ScanShard(ctx, shardID, fn); err != nil {
span.LogKV("scan shard error", err.Error(), "shardID", shardID)
ext.Error.Set(span, true)
span.Finish()
select {
case errc <- fmt.Errorf("shard %s error: %v", shardID, err):
// first error to occur
default:
// error has already occured
}
}
cancel()
}(shardID)
}
wg.Wait()
close(errc)
return <-errc
}
// ScanShard loops over records on a specific shard, calls the callback func
// for each record and checkpoints the progress of scan.
func (c *Consumer) ScanShard(
ctx context.Context,
shardID string,
fn func(context.Context, *Record) ScanStatus,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "consumer.scanshard")
defer span.Finish()
// get checkpoint
lastSeqNum, err := c.checkpoint.Get(ctx, c.streamName, shardID)
if err != nil {
span.LogKV("checkpoint error", err.Error(), "shardID", shardID)
ext.Error.Set(span, true)
return fmt.Errorf("get checkpoint error: %v", err)
}
// get shard iterator
shardIterator, err := c.getShardIterator(ctx, c.streamName, shardID, lastSeqNum)
if err != nil {
span.LogKV("get shard error", err.Error(), "shardID", shardID, "lastSeqNumber", lastSeqNum)
ext.Error.Set(span, true)
return fmt.Errorf("get shard iterator error: %v", err)
}
c.logger.Log(fmt.Sprintf("scanning shardID %s lastSeqNum %s", shardID, lastSeqNum))
return c.scanPagesOfShard(ctx, shardID, lastSeqNum, shardIterator, fn)
}
func (c *Consumer) scanPagesOfShard(ctx context.Context, shardID, lastSeqNum string, shardIterator *string, fn func(context.Context, *Record) ScanStatus) error {
span := opentracing.SpanFromContext(ctx)
for {
select {
case <-ctx.Done():
span.SetTag("scan", "done")
return nil
default:
span.SetTag("scan", "on")
resp, err := c.client.GetRecordsWithContext(ctx, &kinesis.GetRecordsInput{
ShardIterator: shardIterator,
})
if err != nil {
shardIterator, err = c.getShardIterator(ctx, c.streamName, shardID, lastSeqNum)
if err != nil {
ext.Error.Set(span, true)
span.LogKV("get shard iterator error", err.Error())
return fmt.Errorf("get shard iterator error: %v", err)
}
continue
}
// loop records of page
for _, r := range resp.Records {
ctx = opentracing.ContextWithSpan(ctx, span)
isScanStopped, err := c.handleRecord(ctx, shardID, r, fn)
if err != nil {
span.LogKV("handle record error", err.Error(), "shardID", shardID)
ext.Error.Set(span, true)
return err
}
if isScanStopped {
span.SetTag("scan", "stopped")
return nil
}
lastSeqNum = *r.SequenceNumber
}
if isShardClosed(resp.NextShardIterator, shardIterator) {
span.LogKV("is shard closed", "true")
return nil
}
shardIterator = resp.NextShardIterator
}
}
}
func isShardClosed(nextShardIterator, currentShardIterator *string) bool {
return nextShardIterator == nil || currentShardIterator == nextShardIterator
}
func (c *Consumer) handleRecord(ctx context.Context, shardID string, r *Record, fn func(context.Context, *Record) ScanStatus) (isScanStopped bool, err error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "consumer.handleRecord")
defer span.Finish()
status := fn(ctx, r)
if !status.SkipCheckpoint {
span.LogKV("scan.state", status)
if err := c.checkpoint.Set(ctx, c.streamName, shardID, *r.SequenceNumber); err != nil {
span.LogKV("checkpoint error", err.Error(), "stream.name", c.streamName, "shardID", shardID, "sequenceNumber", *r.SequenceNumber)
ext.Error.Set(span, true)
return false, err
}
}
if err := status.Error; err != nil {
span.LogKV("scan.state", status.Error)
ext.Error.Set(span, true)
return false, err
}
c.counter.Add("records", 1)
if status.StopScan {
span.LogKV("scan.state", "stopped")
return true, nil
}
return false, nil
}
func (c *Consumer) getShardIDs(ctx context.Context, streamName string) ([]string, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "consumer.getShardIDs")
defer span.Finish()
span = span.SetTag("streamName", streamName)
resp, err := c.client.DescribeStreamWithContext(ctx,
&kinesis.DescribeStreamInput{
StreamName: aws.String(streamName),
},
)
if err != nil {
span.LogKV("describe stream error", err.Error())
ext.Error.Set(span, true)
return nil, fmt.Errorf("describe stream error: %v", err)
}
var ss []string
for _, shard := range resp.StreamDescription.Shards {
ss = append(ss, *shard.ShardId)
}
return ss, nil
}
func (c *Consumer) getShardIterator(ctx context.Context, streamName, shardID, lastSeqNum string) (*string, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "consumer.getShardIterator",
opentracing.Tag{Key: "streamName", Value: streamName},
opentracing.Tag{Key: "shardID", Value: shardID},
opentracing.Tag{Key: "lastSeqNum", Value: lastSeqNum})
defer span.Finish()
shard := aws.String(shardID)
stream := aws.String(streamName)
params := &kinesis.GetShardIteratorInput{
ShardId: shard,
StreamName: stream,
}
span = span.SetTag("shardID", shard)
span = span.SetTag("streamName", stream)
if lastSeqNum != "" {
params.ShardIteratorType = aws.String("AFTER_SEQUENCE_NUMBER")
params.StartingSequenceNumber = aws.String(lastSeqNum)
} else {
params.ShardIteratorType = aws.String("TRIM_HORIZON")
}
resp, err := c.client.GetShardIteratorWithContext(ctx, params)
if err != nil {
span.LogKV("get shard error", err.Error())
ext.Error.Set(span, true)
return nil, err
}
return resp.ShardIterator, nil
}