diff --git a/pipeline.go b/pipeline.go index 8413125..b3cc44b 100644 --- a/pipeline.go +++ b/pipeline.go @@ -3,6 +3,7 @@ package connector import ( "log" "math" + "net" "reflect" "time" @@ -25,8 +26,34 @@ type Pipeline struct { CheckpointFilteredRecords bool } -var pipelineRecoverableErrorCodes = map[string]bool{ - "ProvisionedThroughputExceededException": true, +type pipelineIsRecoverableErrorFunc func(error) bool + +func pipelineKinesisIsRecoverableError(err error) bool { + recoverableErrorCodes := map[string]bool{ + "ProvisionedThroughputExceededException": true, + } + r := false + cErr, ok := err.(*kinesis.Error) + if ok && recoverableErrorCodes[cErr.Code] == true { + r = true + } + return r +} + +func pipelineNetIsRecoverableError(err error) bool { + recoverableErrors := map[string]bool{ + "connection reset by peer": true, + } + r := false + cErr, ok := err.(*net.OpError) + if ok && recoverableErrors[cErr.Err.Error()] == true { + r = true + } + return r +} + +var pipelineIsRecoverableErrors = []pipelineIsRecoverableErrorFunc{ + pipelineKinesisIsRecoverableError, pipelineNetIsRecoverableError, } // this determines whether the error is recoverable @@ -35,9 +62,11 @@ func (p Pipeline) isRecoverableError(err error) bool { log.Printf("isRecoverableError, type %s, value (+%v)\n", reflect.TypeOf(err).String(), err) - cErr, ok := err.(*kinesis.Error) - if ok && pipelineRecoverableErrorCodes[cErr.Code] == true { - r = true + for _, errF := range pipelineIsRecoverableErrors { + r = errF(err) + if r { + break + } } return r