kinesis-consumer/consumer.go
Pierre Massat c04f3d8a94 Get a new shard iterator on error (#32)
* Get a new shard iterator on error
* Check for nil instead of empty string
* Get a new shard iterator on error
2017-10-15 17:40:30 -07:00

141 lines
3.2 KiB
Go

package connector
import (
"os"
"github.com/apex/log"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/kinesis"
)
// NewConsumer returns a Consumer for the given configuration. It calls
// config.setDefaults() and then creates a Kinesis client from a new AWS
// session configured with a maximum of 10 retries.
//
// NOTE(review): session.New appears to be deprecated in later aws-sdk-go v1
// releases in favor of session.NewSession; confirm against the SDK version
// this project pins before changing it.
func NewConsumer(config Config) *Consumer {
	config.setDefaults()

	awsCfg := aws.NewConfig().WithMaxRetries(10)
	client := kinesis.New(session.New(awsCfg))

	consumer := new(Consumer)
	consumer.svc = client
	consumer.Config = config
	return consumer
}
// Consumer reads from a Kinesis stream using the embedded Config and the
// Kinesis client created by NewConsumer.
type Consumer struct {
	// svc is the Kinesis API client used for the DescribeStream,
	// GetShardIterator and GetRecords calls.
	svc *kinesis.Kinesis
	// Config is embedded, so its fields (e.g. StreamName, Logger, Checkpoint)
	// are accessed directly on the Consumer.
	Config
}
// Start lists the shards of the configured stream and launches one goroutine
// per shard, each running the handler loop for that shard with the given
// handler. It returns as soon as the goroutines have been started; it does
// not wait for them.
//
// DescribeStream returns shards in pages, so the call is repeated while the
// response reports HasMoreShards, continuing after the last shard ID seen.
// All shard IDs are gathered before any goroutine is started. If a
// DescribeStream call fails, the error is logged and the process exits with
// status 1.
func (c *Consumer) Start(handler Handler) {
	input := &kinesis.DescribeStreamInput{
		StreamName: aws.String(c.StreamName),
	}

	var shardIDs []string
	for {
		resp, err := c.svc.DescribeStream(input)
		if err != nil {
			c.Logger.WithError(err).Error("DescribeStream")
			os.Exit(1)
		}

		shards := resp.StreamDescription.Shards
		for _, shard := range shards {
			shardIDs = append(shardIDs, *shard.ShardId)
		}

		// Stop when this was the last page. The empty-page guard prevents
		// an endless loop if no shard ID is available to continue from.
		more := resp.StreamDescription.HasMoreShards
		if len(shards) == 0 || more == nil || !*more {
			break
		}
		input.ExclusiveStartShardId = shards[len(shards)-1].ShardId
	}

	for _, id := range shardIDs {
		go c.handlerLoop(id, handler)
	}
}
// handlerLoop polls a single shard forever. Each record returned by
// GetRecords is added to a per-shard buffer; whenever the buffer reports
// ShouldFlush, the buffered records are passed to handler, the last sequence
// number is checkpointed for the shard, and the buffer is flushed.
//
// A new shard iterator is requested when GetRecords fails, when the response
// carries no NextShardIterator, or when the returned iterator equals the one
// that was just used. This function never returns.
func (c *Consumer) handlerLoop(shardID string, handler Handler) {
	buf := &Buffer{
		MaxRecordCount: c.BufferSize,
		shardID:        shardID,
	}
	ctx := c.Logger.WithFields(log.Fields{
		"shard": shardID,
	})

	shardIterator := c.getShardIterator(shardID)
	ctx.Info("processing")

	for {
		resp, err := c.svc.GetRecords(
			&kinesis.GetRecordsInput{
				ShardIterator: shardIterator,
			},
		)
		if err != nil {
			ctx.WithError(err).Error("GetRecords")
			// NOTE(review): the fresh iterator starts from the last
			// checkpoint, so records still sitting in buf may be read and
			// buffered a second time; verify against Buffer's behavior.
			shardIterator = c.getShardIterator(shardID)
			continue
		}

		// Ranging over an empty slice is a no-op, so no length check is needed.
		for _, r := range resp.Records {
			buf.AddRecord(r)
			if !buf.ShouldFlush() {
				continue
			}
			handler.HandleRecords(*buf)
			ctx.WithField("count", buf.RecordCount()).Info("flushed")
			c.Checkpoint.SetCheckpoint(shardID, buf.LastSeq())
			buf.Flush()
		}

		next := resp.NextShardIterator
		// Compare the iterator strings, not the *string addresses: two
		// distinct pointers never compare equal even when they hold the
		// same iterator value.
		unchanged := next != nil && shardIterator != nil && *next == *shardIterator
		if next == nil || unchanged {
			// NOTE(review): the Kinesis API documents a nil NextShardIterator
			// as "shard closed"; requesting a new iterator here keeps polling
			// such a shard. Confirm whether the loop should stop instead.
			shardIterator = c.getShardIterator(shardID)
		} else {
			shardIterator = next
		}
	}
}
// getShardIterator requests a shard iterator for shardID on the configured
// stream and returns it.
//
// When the checkpoint store reports a checkpoint for the shard, the iterator
// type is ShardIteratorAfterSequenceNumber with the checkpoint's sequence
// number as the starting point; otherwise the configured ShardIteratorType is
// used. If the GetShardIterator call fails, the error is logged and the
// process exits with status 1.
func (c *Consumer) getShardIterator(shardID string) *string {
	params := &kinesis.GetShardIteratorInput{
		ShardId:    aws.String(shardID),
		StreamName: aws.String(c.StreamName),
	}

	if c.Checkpoint.CheckpointExists(shardID) {
		params.ShardIteratorType = aws.String(string(ShardIteratorAfterSequenceNumber))
		// NOTE(review): SequenceNumber takes no shard ID; presumably the
		// checkpoint remembers the shard from the CheckpointExists call
		// above. Verify against the Checkpoint implementation.
		params.StartingSequenceNumber = aws.String(c.Checkpoint.SequenceNumber())
	} else {
		params.ShardIteratorType = aws.String(string(c.ShardIteratorType))
	}

	resp, err := c.svc.GetShardIterator(params)
	if err != nil {
		c.Logger.WithError(err).Error("GetShardIterator")
		os.Exit(1)
	}
	return resp.ShardIterator
}