The Firehose service accepts a maximum batch size of 500 records. While building the example it became clear that finer-grained configuration was needed, so the consumer exposes a `Set` method for tuning `maxBatchCount` and `pollInterval`.
package connector

import (
	"os"
	"time"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/awserr"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/kinesis"
)

var (
	pollInterval  = 1 * time.Second
	maxBatchCount = 1000
)
// NewConsumer creates a new Kinesis connection and returns a
// new consumer initialized with the app and stream name.
func NewConsumer(appName, streamName string) *Consumer {
	svc := kinesis.New(session.New())

	return &Consumer{
		appName:    appName,
		streamName: streamName,
		svc:        svc,
	}
}

type Consumer struct {
	appName    string
	streamName string
	svc        *kinesis.Kinesis
}
// Set sets the package-level `option` to `value`.
func (c *Consumer) Set(option string, value interface{}) {
	var err error

	switch option {
	case "maxBatchCount":
		maxBatchCount = value.(int)
	case "pollInterval":
		pollInterval, err = time.ParseDuration(value.(string))
		if err != nil {
			logger.Log("fatal", "ParseDuration", "msg", "unable to parse pollInterval value")
			os.Exit(1)
		}
	default:
		logger.Log("fatal", "Set", "msg", "unknown option")
		os.Exit(1)
	}
}
// Start describes the stream and spawns a handler loop per shard.
func (c *Consumer) Start(handler Handler) {
	params := &kinesis.DescribeStreamInput{
		StreamName: aws.String(c.streamName),
	}

	// describe stream
	resp, err := c.svc.DescribeStream(params)
	if err != nil {
		logger.Log("fatal", "DescribeStream", "msg", err.Error())
		os.Exit(1)
	}

	// handle shards
	for _, shard := range resp.StreamDescription.Shards {
		logger.Log("info", "processing", "stream", c.streamName, "shard", *shard.ShardId)
		go c.handlerLoop(*shard.ShardId, handler)
	}
}
func (c *Consumer) handlerLoop(shardID string, handler Handler) {
	params := &kinesis.GetShardIteratorInput{
		ShardId:    aws.String(shardID),
		StreamName: aws.String(c.streamName),
	}

	// resume from the last checkpoint if one exists, otherwise
	// start from the oldest record in the shard
	checkpoint := &Checkpoint{AppName: c.appName, StreamName: c.streamName}
	if checkpoint.CheckpointExists(shardID) {
		params.ShardIteratorType = aws.String("AFTER_SEQUENCE_NUMBER")
		params.StartingSequenceNumber = aws.String(checkpoint.SequenceNumber())
	} else {
		params.ShardIteratorType = aws.String("TRIM_HORIZON")
	}

	resp, err := c.svc.GetShardIterator(params)
	if err != nil {
		if awsErr, ok := err.(awserr.Error); ok {
			logger.Log("fatal", "getShardIterator", "code", awsErr.Code(), "msg", awsErr.Message(), "origError", awsErr.OrigErr())
		} else {
			logger.Log("fatal", "getShardIterator", "msg", err.Error())
		}
		os.Exit(1)
	}

	b := &Buffer{MaxBatchCount: maxBatchCount}
	shardIterator := resp.ShardIterator
	errCount := 0

	for {
		// get records from stream
		resp, err := c.svc.GetRecords(&kinesis.GetRecordsInput{
			ShardIterator: shardIterator,
		})

		// back off and retry on recoverable errors, else exit program
		if err != nil {
			if awsErr, ok := err.(awserr.Error); ok && isRecoverableError(err) {
				logger.Log("warn", "getRecords", "errorCount", errCount, "code", awsErr.Code())
				handleAwsWaitTimeExp(errCount)
				errCount++
				continue
			}
			logger.Log("fatal", "getRecords", "msg", err.Error())
			os.Exit(1)
		}
		errCount = 0

		// process records, flushing the buffer and checkpointing the
		// shard whenever a batch is full
		if len(resp.Records) > 0 {
			for _, r := range resp.Records {
				b.AddRecord(r)

				if b.ShouldFlush() {
					handler.HandleRecords(*b)
					checkpoint.SetCheckpoint(shardID, b.LastSeq())
					b.Flush()
				}
			}
		} else if aws.StringValue(resp.NextShardIterator) == "" || aws.StringValue(shardIterator) == aws.StringValue(resp.NextShardIterator) {
			// an empty or unchanged iterator means the shard is closed
			logger.Log("fatal", "nextShardIterator", "msg", "shard iterator no longer advancing")
			os.Exit(1)
		} else {
			logger.Log("info", "sleeping", "msg", "no records to process")
			time.Sleep(pollInterval)
		}

		shardIterator = resp.NextShardIterator
	}
}
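The loop above leans on two helpers that aren't shown in the gist: `isRecoverableError` and `handleAwsWaitTimeExp`. Here is a minimal sketch of what they might look like, living in the same `connector` package; the set of retryable error codes and the 10-second backoff cap are assumptions, not the gist author's actual values.

// isRecoverableError reports whether err is a transient AWS error
// worth retrying. The codes listed here are an assumed set.
func isRecoverableError(err error) bool {
	if awsErr, ok := err.(awserr.Error); ok {
		switch awsErr.Code() {
		case "ProvisionedThroughputExceededException", "ThrottlingException":
			return true
		}
	}
	return false
}

// handleAwsWaitTimeExp sleeps for an exponentially growing interval
// based on the number of consecutive errors seen so far:
// 100ms, 200ms, 400ms, ... capped at 10s (the cap is an assumption).
func handleAwsWaitTimeExp(errCount int) {
	wait := time.Duration(100*(1<<uint(errCount))) * time.Millisecond
	if wait > 10*time.Second {
		wait = 10 * time.Second
	}
	time.Sleep(wait)
}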
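Wiring it together, configuration happens through `Set` before calling `Start`. A hypothetical caller that forwards each batch to Firehose would lower `maxBatchCount` from the 1000-record default to Firehose's 500-record limit; `myFirehoseHandler` below is a placeholder for any type implementing the package's `Handler` interface.

consumer := connector.NewConsumer("firehose-example", "my-stream")

// Firehose rejects batches larger than 500 records, so override
// the 1000-record default before starting the consumer.
consumer.Set("maxBatchCount", 500)
consumer.Set("pollInterval", "5s")

consumer.Start(myFirehoseHandler)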