Leverage the default AWS retry logic
This commit is contained in:
parent
e150d4832b
commit
dded9d0a0e
4 changed files with 13 additions and 135 deletions
|
|
@ -1,16 +0,0 @@
|
|||
package connector
|
||||
|
||||
import (
|
||||
"math"
|
||||
"time"
|
||||
)
|
||||
|
||||
// AWS Exponential Backoff
|
||||
// Wait up to 5 minutes based on the aws exponential backoff algorithm
|
||||
// http://docs.aws.amazon.com/general/latest/gr/api-retries.html
|
||||
func handleAwsWaitTimeExp(attempts int) {
|
||||
if attempts > 0 {
|
||||
waitTime := time.Duration(math.Min(100*math.Pow(2, float64(attempts)), 300000)) * time.Millisecond
|
||||
time.Sleep(waitTime)
|
||||
}
|
||||
}
|
||||
35
consumer.go
35
consumer.go
|
|
@ -18,7 +18,10 @@ var (
|
|||
// NewConsumer creates a new kinesis connection and returns a
|
||||
// new consumer initialized with app and stream name
|
||||
func NewConsumer(appName, streamName string) *Consumer {
|
||||
svc := kinesis.New(session.New())
|
||||
sess := session.New(
|
||||
aws.NewConfig().WithMaxRetries(10),
|
||||
)
|
||||
svc := kinesis.New(sess)
|
||||
|
||||
return &Consumer{
|
||||
appName: appName,
|
||||
|
|
@ -52,19 +55,20 @@ func (c *Consumer) Set(option string, value interface{}) {
|
|||
}
|
||||
}
|
||||
|
||||
// Start takes a handler and then loops over each of the shards
|
||||
// processing each one with the handler.
|
||||
func (c *Consumer) Start(handler Handler) {
|
||||
params := &kinesis.DescribeStreamInput{
|
||||
StreamName: aws.String(c.streamName),
|
||||
}
|
||||
resp, err := c.svc.DescribeStream(
|
||||
&kinesis.DescribeStreamInput{
|
||||
StreamName: aws.String(c.streamName),
|
||||
},
|
||||
)
|
||||
|
||||
// describe stream
|
||||
resp, err := c.svc.DescribeStream(params)
|
||||
if err != nil {
|
||||
logger.Log("fatal", "DescribeStream", "msg", err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// handle shards
|
||||
for _, shard := range resp.StreamDescription.Shards {
|
||||
logger.Log("info", "processing", "stream", c.streamName, "shard", shard.ShardId)
|
||||
go c.handlerLoop(*shard.ShardId, handler)
|
||||
|
|
@ -95,31 +99,18 @@ func (c *Consumer) handlerLoop(shardID string, handler Handler) {
|
|||
|
||||
b := &Buffer{MaxBatchCount: maxBatchCount}
|
||||
shardIterator := resp.ShardIterator
|
||||
errCount := 0
|
||||
|
||||
for {
|
||||
// get records from stream
|
||||
resp, err := c.svc.GetRecords(&kinesis.GetRecordsInput{
|
||||
ShardIterator: shardIterator,
|
||||
})
|
||||
|
||||
// handle recoverable errors, else exit program
|
||||
if err != nil {
|
||||
awsErr, _ := err.(awserr.Error)
|
||||
|
||||
if isRecoverableError(err) {
|
||||
logger.Log("warn", "getRecords", "errorCount", errCount, "code", awsErr.Code())
|
||||
handleAwsWaitTimeExp(errCount)
|
||||
errCount++
|
||||
} else {
|
||||
logger.Log("fatal", "getRecords", awsErr.Code())
|
||||
os.Exit(1)
|
||||
}
|
||||
} else {
|
||||
errCount = 0
|
||||
logger.Log("fatal", "getRecords", awsErr.Code())
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// process records
|
||||
if len(resp.Records) > 0 {
|
||||
for _, r := range resp.Records {
|
||||
b.AddRecord(r)
|
||||
|
|
|
|||
67
errors.go
67
errors.go
|
|
@ -1,67 +0,0 @@
|
|||
package connector
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/url"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
)
|
||||
|
||||
type isRecoverableErrorFunc func(error) bool
|
||||
|
||||
var isRecoverableErrors = []isRecoverableErrorFunc{
|
||||
kinesisIsRecoverableError,
|
||||
netIsRecoverableError,
|
||||
urlIsRecoverableError,
|
||||
}
|
||||
|
||||
// isRecoverableError determines whether the error is recoverable
|
||||
func isRecoverableError(err error) bool {
|
||||
for _, errF := range isRecoverableErrors {
|
||||
if errF(err) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func kinesisIsRecoverableError(err error) bool {
|
||||
recoverableErrorCodes := map[string]bool{
|
||||
"InternalFailure": true,
|
||||
"ProvisionedThroughputExceededException": true,
|
||||
"RequestError": true,
|
||||
"ServiceUnavailable": true,
|
||||
"Throttling": true,
|
||||
}
|
||||
|
||||
if err, ok := err.(awserr.Error); ok {
|
||||
if ok && recoverableErrorCodes[err.Code()] == true {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func urlIsRecoverableError(err error) bool {
|
||||
_, ok := err.(*url.Error)
|
||||
if ok {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func netIsRecoverableError(err error) bool {
|
||||
recoverableErrors := map[string]bool{
|
||||
"connection reset by peer": true,
|
||||
}
|
||||
|
||||
cErr, ok := err.(*net.OpError)
|
||||
if ok && recoverableErrors[cErr.Err.Error()] == true {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
|
@ -1,30 +0,0 @@
|
|||
package connector
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
"github.com/bmizerany/assert"
|
||||
)
|
||||
|
||||
func Test_isRecoverableError(t *testing.T) {
|
||||
testCases := []struct {
|
||||
err error
|
||||
isRecoverable bool
|
||||
}{
|
||||
{err: awserr.New("ProvisionedThroughputExceededException", "", nil), isRecoverable: true},
|
||||
{err: awserr.New("Throttling", "", nil), isRecoverable: true},
|
||||
{err: awserr.New("ServiceUnavailable", "", nil), isRecoverable: true},
|
||||
{err: awserr.New("ExpiredIteratorException", "", nil), isRecoverable: false},
|
||||
{err: &net.OpError{Err: fmt.Errorf("connection reset by peer")}, isRecoverable: true},
|
||||
{err: &net.OpError{Err: fmt.Errorf("unexpected error")}, isRecoverable: false},
|
||||
{err: fmt.Errorf("an arbitrary error"), isRecoverable: false},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
isRecoverable := isRecoverableError(tc.err)
|
||||
assert.Equal(t, isRecoverable, tc.isRecoverable)
|
||||
}
|
||||
}
|
||||
Loading…
Reference in a new issue