2017-11-20 16:21:40 +00:00
|
|
|
package consumer
|
2016-02-03 05:04:22 +00:00
|
|
|
|
|
|
|
|
import (
|
2017-11-20 16:21:40 +00:00
|
|
|
"context"
|
|
|
|
|
"fmt"
|
2017-11-22 18:46:39 +00:00
|
|
|
"io/ioutil"
|
|
|
|
|
"log"
|
2017-11-20 16:21:40 +00:00
|
|
|
"sync"
|
2016-02-03 05:04:22 +00:00
|
|
|
|
|
|
|
|
"github.com/aws/aws-sdk-go/aws"
|
|
|
|
|
"github.com/aws/aws-sdk-go/aws/session"
|
|
|
|
|
"github.com/aws/aws-sdk-go/service/kinesis"
|
2017-11-20 16:21:40 +00:00
|
|
|
"github.com/harlow/kinesis-consumer/checkpoint"
|
2016-02-03 05:04:22 +00:00
|
|
|
)
|
|
|
|
|
|
2017-11-21 16:58:16 +00:00
|
|
|
// Record is an alias for kinesis.Record, the record type handed to the scan
// callbacks in this package.
type Record = kinesis.Record
|
|
|
|
|
|
2017-11-20 16:21:40 +00:00
|
|
|
// Option is used to override defaults when creating a new Consumer
|
|
|
|
|
type Option func(*Consumer) error
|
2016-05-01 05:23:35 +00:00
|
|
|
|
2017-11-20 16:21:40 +00:00
|
|
|
// WithCheckpoint overrides the default checkpoint
|
|
|
|
|
func WithCheckpoint(checkpoint checkpoint.Checkpoint) Option {
|
|
|
|
|
return func(c *Consumer) error {
|
|
|
|
|
c.checkpoint = checkpoint
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// WithLogger overrides the default logger
|
2017-11-22 18:46:39 +00:00
|
|
|
func WithLogger(logger *log.Logger) Option {
|
2017-11-20 16:21:40 +00:00
|
|
|
return func(c *Consumer) error {
|
|
|
|
|
c.logger = logger
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// New creates a kinesis consumer with default settings. Use Option to override
|
|
|
|
|
// any of the optional attributes.
|
2017-11-21 16:58:16 +00:00
|
|
|
func New(checkpoint checkpoint.Checkpoint, app, stream string, opts ...Option) (*Consumer, error) {
|
|
|
|
|
if checkpoint == nil {
|
|
|
|
|
return nil, fmt.Errorf("must provide checkpoint")
|
2017-11-20 16:21:40 +00:00
|
|
|
}
|
|
|
|
|
|
2017-11-21 16:58:16 +00:00
|
|
|
if app == "" {
|
|
|
|
|
return nil, fmt.Errorf("must provide app name")
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if stream == "" {
|
|
|
|
|
return nil, fmt.Errorf("must provide stream name")
|
2017-11-20 16:21:40 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
c := &Consumer{
|
2017-11-21 16:58:16 +00:00
|
|
|
checkpoint: checkpoint,
|
2017-11-22 18:46:39 +00:00
|
|
|
appName: app,
|
2017-11-21 16:58:16 +00:00
|
|
|
streamName: stream,
|
2017-11-20 16:21:40 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// set options
|
|
|
|
|
for _, opt := range opts {
|
|
|
|
|
if err := opt(c); err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// provide default logger
|
|
|
|
|
if c.logger == nil {
|
2017-11-22 18:46:39 +00:00
|
|
|
c.logger = log.New(ioutil.Discard, "kinesis-consumer: ", log.LstdFlags)
|
2017-11-20 16:21:40 +00:00
|
|
|
}
|
2016-02-03 05:04:22 +00:00
|
|
|
|
2017-11-20 16:21:40 +00:00
|
|
|
// provide a default kinesis client
|
2017-11-21 16:58:16 +00:00
|
|
|
if c.client == nil {
|
|
|
|
|
c.client = kinesis.New(session.New(aws.NewConfig()))
|
2017-11-20 16:21:40 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return c, nil
|
2016-02-03 05:04:22 +00:00
|
|
|
}
|
|
|
|
|
|
2017-10-16 00:45:38 +00:00
|
|
|
// Consumer wraps the interaction with the Kinesis stream
|
2016-02-03 05:04:22 +00:00
|
|
|
type Consumer struct {
|
2017-11-22 18:46:39 +00:00
|
|
|
appName string
|
2017-11-20 16:21:40 +00:00
|
|
|
streamName string
|
2017-11-21 16:58:16 +00:00
|
|
|
client *kinesis.Kinesis
|
2017-11-22 18:46:39 +00:00
|
|
|
logger *log.Logger
|
2017-11-20 16:21:40 +00:00
|
|
|
checkpoint checkpoint.Checkpoint
|
2016-05-01 05:23:35 +00:00
|
|
|
}
|
|
|
|
|
|
2017-11-20 16:21:40 +00:00
|
|
|
// Scan scans each of the shards of the stream, calls the callback
|
2017-11-20 19:45:30 +00:00
|
|
|
// func with each of the kinesis records.
|
|
|
|
|
func (c *Consumer) Scan(ctx context.Context, fn func(*kinesis.Record) bool) error {
|
2017-11-20 16:21:40 +00:00
|
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
|
|
|
defer cancel()
|
|
|
|
|
|
2017-11-20 19:45:30 +00:00
|
|
|
// grab the stream details
|
2017-11-21 16:58:16 +00:00
|
|
|
resp, err := c.client.DescribeStream(
|
2016-05-01 00:04:44 +00:00
|
|
|
&kinesis.DescribeStreamInput{
|
2017-11-20 16:21:40 +00:00
|
|
|
StreamName: aws.String(c.streamName),
|
2016-05-01 00:04:44 +00:00
|
|
|
},
|
|
|
|
|
)
|
2016-02-03 05:04:22 +00:00
|
|
|
if err != nil {
|
2017-11-20 19:45:30 +00:00
|
|
|
return err
|
2016-02-03 05:04:22 +00:00
|
|
|
}
|
|
|
|
|
|
2017-11-20 16:21:40 +00:00
|
|
|
var wg sync.WaitGroup
|
|
|
|
|
wg.Add(len(resp.StreamDescription.Shards))
|
|
|
|
|
|
2017-11-20 19:45:30 +00:00
|
|
|
// launch goroutine to process each of the shards
|
2016-02-03 05:04:22 +00:00
|
|
|
for _, shard := range resp.StreamDescription.Shards {
|
2017-11-20 16:21:40 +00:00
|
|
|
go func(shardID string) {
|
|
|
|
|
defer wg.Done()
|
|
|
|
|
c.ScanShard(ctx, shardID, fn)
|
|
|
|
|
cancel()
|
|
|
|
|
}(*shard.ShardId)
|
2016-02-03 05:04:22 +00:00
|
|
|
}
|
2017-11-20 16:21:40 +00:00
|
|
|
|
|
|
|
|
wg.Wait()
|
2017-11-20 19:45:30 +00:00
|
|
|
return nil
|
2016-02-03 05:04:22 +00:00
|
|
|
}
|
|
|
|
|
|
2017-11-20 19:45:30 +00:00
|
|
|
// ScanShard loops over records on a specific shard, calls the callback func
|
|
|
|
|
// for each record and checkpoints after each page is processed.
|
|
|
|
|
// Note: returning `false` from the callback func will end the scan.
|
2017-11-20 16:21:40 +00:00
|
|
|
func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn func(*kinesis.Record) bool) {
|
2017-11-21 16:58:16 +00:00
|
|
|
lastSeqNum, err := c.checkpoint.Get(shardID)
|
|
|
|
|
if err != nil {
|
2017-11-22 18:46:39 +00:00
|
|
|
c.logger.Printf("get checkpoint error: %v", err)
|
2017-11-21 16:58:16 +00:00
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
shardIterator, err := c.getShardIterator(shardID, lastSeqNum)
|
2017-11-20 16:21:40 +00:00
|
|
|
if err != nil {
|
2017-11-22 18:46:39 +00:00
|
|
|
c.logger.Printf("get shard iterator error: %v", err)
|
2017-11-20 16:21:40 +00:00
|
|
|
return
|
2016-05-01 01:05:04 +00:00
|
|
|
}
|
2016-05-08 01:05:52 +00:00
|
|
|
|
2017-11-22 18:46:39 +00:00
|
|
|
c.logger.Println("scanning", shardID)
|
2016-02-03 05:04:22 +00:00
|
|
|
|
2017-11-20 16:21:40 +00:00
|
|
|
loop:
|
2016-02-03 05:04:22 +00:00
|
|
|
for {
|
2017-11-20 16:21:40 +00:00
|
|
|
select {
|
|
|
|
|
case <-ctx.Done():
|
|
|
|
|
break loop
|
|
|
|
|
default:
|
2017-11-21 16:58:16 +00:00
|
|
|
resp, err := c.client.GetRecords(
|
2017-11-20 16:21:40 +00:00
|
|
|
&kinesis.GetRecordsInput{
|
|
|
|
|
ShardIterator: shardIterator,
|
|
|
|
|
},
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if err != nil {
|
2017-11-21 16:58:16 +00:00
|
|
|
shardIterator, err = c.getShardIterator(shardID, lastSeqNum)
|
2017-11-20 16:21:40 +00:00
|
|
|
if err != nil {
|
2017-11-22 18:46:39 +00:00
|
|
|
c.logger.Printf("get shard iterator error: %v", err)
|
2017-11-20 16:21:40 +00:00
|
|
|
break loop
|
|
|
|
|
}
|
|
|
|
|
continue
|
|
|
|
|
}
|
2016-02-03 05:04:22 +00:00
|
|
|
|
2017-11-20 16:21:40 +00:00
|
|
|
if len(resp.Records) > 0 {
|
|
|
|
|
for _, r := range resp.Records {
|
|
|
|
|
select {
|
|
|
|
|
case <-ctx.Done():
|
|
|
|
|
break loop
|
|
|
|
|
default:
|
2017-11-21 16:58:16 +00:00
|
|
|
lastSeqNum = *r.SequenceNumber
|
2017-11-20 16:21:40 +00:00
|
|
|
if ok := fn(r); !ok {
|
|
|
|
|
break loop
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
2016-02-03 05:04:22 +00:00
|
|
|
|
2017-11-22 18:46:39 +00:00
|
|
|
c.logger.Println("checkpointing", shardID, len(resp.Records))
|
2017-11-21 16:58:16 +00:00
|
|
|
if err := c.checkpoint.Set(shardID, lastSeqNum); err != nil {
|
2017-11-22 18:46:39 +00:00
|
|
|
c.logger.Printf("set checkpoint error: %v", err)
|
2017-11-20 17:37:30 +00:00
|
|
|
}
|
2017-11-20 16:21:40 +00:00
|
|
|
}
|
2016-02-03 05:04:22 +00:00
|
|
|
|
2017-11-20 16:21:40 +00:00
|
|
|
if resp.NextShardIterator == nil || shardIterator == resp.NextShardIterator {
|
2017-11-21 16:58:16 +00:00
|
|
|
shardIterator, err = c.getShardIterator(shardID, lastSeqNum)
|
2017-11-20 16:21:40 +00:00
|
|
|
if err != nil {
|
2017-11-22 18:46:39 +00:00
|
|
|
c.logger.Printf("get shard iterator error: %v", err)
|
2017-11-20 16:21:40 +00:00
|
|
|
break loop
|
2016-02-08 21:21:54 +00:00
|
|
|
}
|
2017-11-20 16:21:40 +00:00
|
|
|
} else {
|
|
|
|
|
shardIterator = resp.NextShardIterator
|
2016-02-03 05:04:22 +00:00
|
|
|
}
|
|
|
|
|
}
|
2017-11-20 16:21:40 +00:00
|
|
|
}
|
2016-02-03 05:04:22 +00:00
|
|
|
|
2017-11-21 16:58:16 +00:00
|
|
|
if lastSeqNum == "" {
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
2017-11-22 18:46:39 +00:00
|
|
|
c.logger.Println("checkpointing", shardID)
|
2017-11-21 16:58:16 +00:00
|
|
|
if err := c.checkpoint.Set(shardID, lastSeqNum); err != nil {
|
2017-11-22 18:46:39 +00:00
|
|
|
c.logger.Printf("set checkpoint error: %v", err)
|
2017-10-16 00:40:30 +00:00
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2017-11-21 16:58:16 +00:00
|
|
|
func (c *Consumer) getShardIterator(shardID, lastSeqNum string) (*string, error) {
|
2017-10-16 00:40:30 +00:00
|
|
|
params := &kinesis.GetShardIteratorInput{
|
|
|
|
|
ShardId: aws.String(shardID),
|
2017-11-20 16:21:40 +00:00
|
|
|
StreamName: aws.String(c.streamName),
|
2017-10-16 00:40:30 +00:00
|
|
|
}
|
|
|
|
|
|
2017-11-21 16:58:16 +00:00
|
|
|
if lastSeqNum != "" {
|
2017-10-16 00:45:38 +00:00
|
|
|
params.ShardIteratorType = aws.String("AFTER_SEQUENCE_NUMBER")
|
2017-11-21 16:58:16 +00:00
|
|
|
params.StartingSequenceNumber = aws.String(lastSeqNum)
|
2017-10-16 00:40:30 +00:00
|
|
|
} else {
|
2017-10-16 00:45:38 +00:00
|
|
|
params.ShardIteratorType = aws.String("TRIM_HORIZON")
|
2017-10-16 00:40:30 +00:00
|
|
|
}
|
|
|
|
|
|
2017-11-21 16:58:16 +00:00
|
|
|
resp, err := c.client.GetShardIterator(params)
|
2017-10-16 00:40:30 +00:00
|
|
|
if err != nil {
|
2017-11-20 16:21:40 +00:00
|
|
|
return nil, err
|
2016-02-03 05:04:22 +00:00
|
|
|
}
|
2017-10-16 00:40:30 +00:00
|
|
|
|
2017-11-20 16:21:40 +00:00
|
|
|
return resp.ShardIterator, nil
|
2016-02-03 05:04:22 +00:00
|
|
|
}
|