diff --git a/awsbackoff.go b/awsbackoff.go deleted file mode 100644 index bf3a964..0000000 --- a/awsbackoff.go +++ /dev/null @@ -1,16 +0,0 @@ -package connector - -import ( - "math" - "time" -) - -// AWS Exponential Backoff -// Wait up to 5 minutes based on the aws exponential backoff algorithm -// http://docs.aws.amazon.com/general/latest/gr/api-retries.html -func handleAwsWaitTimeExp(attempts int) { - if attempts > 0 { - waitTime := time.Duration(math.Min(100*math.Pow(2, float64(attempts)), 300000)) * time.Millisecond - time.Sleep(waitTime) - } -} diff --git a/consumer.go b/consumer.go index 8348c26..5847c55 100644 --- a/consumer.go +++ b/consumer.go @@ -18,7 +18,10 @@ var ( // NewConsumer creates a new kinesis connection and returns a // new consumer initialized with app and stream name func NewConsumer(appName, streamName string) *Consumer { - svc := kinesis.New(session.New()) + sess := session.New( + aws.NewConfig().WithMaxRetries(10), + ) + svc := kinesis.New(sess) return &Consumer{ appName: appName, @@ -52,19 +55,20 @@ func (c *Consumer) Set(option string, value interface{}) { } } +// Start takes a handler and then loops over each of the shards +// processing each one with the handler. func (c *Consumer) Start(handler Handler) { - params := &kinesis.DescribeStreamInput{ - StreamName: aws.String(c.streamName), - } + resp, err := c.svc.DescribeStream( + &kinesis.DescribeStreamInput{ + StreamName: aws.String(c.streamName), + }, + ) - // describe stream - resp, err := c.svc.DescribeStream(params) if err != nil { logger.Log("fatal", "DescribeStream", "msg", err.Error()) os.Exit(1) } - // handle shards for _, shard := range resp.StreamDescription.Shards { logger.Log("info", "processing", "stream", c.streamName, "shard", shard.ShardId) go c.handlerLoop(*shard.ShardId, handler) @@ -95,31 +99,18 @@ func (c *Consumer) handlerLoop(shardID string, handler Handler) { b := &Buffer{MaxBatchCount: maxBatchCount} shardIterator := resp.ShardIterator - errCount := 0 for { - // get records from stream resp, err := c.svc.GetRecords(&kinesis.GetRecordsInput{ ShardIterator: shardIterator, }) - // handle recoverable errors, else exit program if err != nil { awsErr, _ := err.(awserr.Error) - - if isRecoverableError(err) { - logger.Log("warn", "getRecords", "errorCount", errCount, "code", awsErr.Code()) - handleAwsWaitTimeExp(errCount) - errCount++ - } else { - logger.Log("fatal", "getRecords", awsErr.Code()) - os.Exit(1) - } - } else { - errCount = 0 + logger.Log("fatal", "getRecords", awsErr.Code()) + os.Exit(1) } - // process records if len(resp.Records) > 0 { for _, r := range resp.Records { b.AddRecord(r) diff --git a/errors.go b/errors.go deleted file mode 100644 index e1da838..0000000 --- a/errors.go +++ /dev/null @@ -1,67 +0,0 @@ -package connector - -import ( - "net" - "net/url" - - "github.com/aws/aws-sdk-go/aws/awserr" -) - -type isRecoverableErrorFunc func(error) bool - -var isRecoverableErrors = []isRecoverableErrorFunc{ - kinesisIsRecoverableError, - netIsRecoverableError, - urlIsRecoverableError, -} - -// isRecoverableError determines whether the error is recoverable -func isRecoverableError(err error) bool { - for _, errF := range isRecoverableErrors { - if errF(err) { - return true - } - } - - return false -} - -func kinesisIsRecoverableError(err error) bool { - recoverableErrorCodes := map[string]bool{ - "InternalFailure": true, - "ProvisionedThroughputExceededException": true, - "RequestError": true, - "ServiceUnavailable": true, - "Throttling": true, - } - - if err, ok := err.(awserr.Error); ok { - if ok && recoverableErrorCodes[err.Code()] == true { - return true - } - } - - return false -} - -func urlIsRecoverableError(err error) bool { - _, ok := err.(*url.Error) - if ok { - return true - } - - return false -} - -func netIsRecoverableError(err error) bool { - recoverableErrors := map[string]bool{ - "connection reset by peer": true, - } - - cErr, ok := err.(*net.OpError) - if ok && recoverableErrors[cErr.Err.Error()] == true { - return true - } - - return false -} diff --git a/errors_test.go b/errors_test.go deleted file mode 100644 index 06be6c1..0000000 --- a/errors_test.go +++ /dev/null @@ -1,30 +0,0 @@ -package connector - -import ( - "fmt" - "net" - "testing" - - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/bmizerany/assert" -) - -func Test_isRecoverableError(t *testing.T) { - testCases := []struct { - err error - isRecoverable bool - }{ - {err: awserr.New("ProvisionedThroughputExceededException", "", nil), isRecoverable: true}, - {err: awserr.New("Throttling", "", nil), isRecoverable: true}, - {err: awserr.New("ServiceUnavailable", "", nil), isRecoverable: true}, - {err: awserr.New("ExpiredIteratorException", "", nil), isRecoverable: false}, - {err: &net.OpError{Err: fmt.Errorf("connection reset by peer")}, isRecoverable: true}, - {err: &net.OpError{Err: fmt.Errorf("unexpected error")}, isRecoverable: false}, - {err: fmt.Errorf("an arbitrary error"), isRecoverable: false}, - } - - for _, tc := range testCases { - isRecoverable := isRecoverableError(tc.err) - assert.Equal(t, isRecoverable, tc.isRecoverable) - } -}