From b5515931d102b49011ab29e805e6193eca4487eb Mon Sep 17 00:00:00 2001
From: Shiva Pentakota
Date: Tue, 24 Jan 2023 11:59:32 -0800
Subject: [PATCH] fix: add hard cap maxRetries for getRecord errors

Signed-off-by: Shiva Pentakota
---
 clientlibrary/config/config.go                 |  6 ++++++
 clientlibrary/config/kcl-config.go             |  8 ++++++++
 clientlibrary/worker/polling-shard-consumer.go | 16 ++++++++++++++++
 3 files changed, 30 insertions(+)

diff --git a/clientlibrary/config/config.go b/clientlibrary/config/config.go
index 5b45678..2d50ca8 100644
--- a/clientlibrary/config/config.go
+++ b/clientlibrary/config/config.go
@@ -136,6 +136,9 @@ const (
 
 	// DefaultLeaseSyncingIntervalMillis Number of milliseconds to wait before syncing with lease table (dynamodDB)
 	DefaultLeaseSyncingIntervalMillis = 60000
+
+	// DefaultMaxRetryCount The default maximum number of retries in case of error
+	DefaultMaxRetryCount = 5
 )
 
 type (
@@ -283,6 +286,9 @@ type (
 
 		// LeaseSyncingTimeInterval The number of milliseconds to wait before syncing with lease table (dynamoDB)
 		LeaseSyncingTimeIntervalMillis int
+
+		// MaxRetryCount The maximum number of retries in case of error
+		MaxRetryCount int
 	}
 )
 
diff --git a/clientlibrary/config/kcl-config.go b/clientlibrary/config/kcl-config.go
index 135f3fa..4d7181b 100644
--- a/clientlibrary/config/kcl-config.go
+++ b/clientlibrary/config/kcl-config.go
@@ -102,6 +102,7 @@ func NewKinesisClientLibConfigWithCredentials(applicationName, streamName, regio
 		LeaseStealingIntervalMillis:     DefaultLeaseStealingIntervalMillis,
 		LeaseStealingClaimTimeoutMillis: DefaultLeaseStealingClaimTimeoutMillis,
 		LeaseSyncingTimeIntervalMillis:  DefaultLeaseSyncingIntervalMillis,
+		MaxRetryCount:                   DefaultMaxRetryCount,
 		Logger:                          logger.GetDefaultLogger(),
 	}
 }
@@ -211,6 +212,13 @@ func (c *KinesisClientLibConfiguration) WithLogger(logger logger.Logger) *Kinesi
 	return c
 }
 
+// WithMaxRetryCount sets the max retry count in case of error.
+func (c *KinesisClientLibConfiguration) WithMaxRetryCount(maxRetryCount int) *KinesisClientLibConfiguration {
+	checkIsValuePositive("maxRetryCount", maxRetryCount)
+	c.MaxRetryCount = maxRetryCount
+	return c
+}
+
 // WithMonitoringService sets the monitoring service to use to publish metrics.
 func (c *KinesisClientLibConfiguration) WithMonitoringService(mService metrics.MonitoringService) *KinesisClientLibConfiguration {
 	// Nil case is handled downward (at worker creation) so no need to do it here.
diff --git a/clientlibrary/worker/polling-shard-consumer.go b/clientlibrary/worker/polling-shard-consumer.go
index ec973f5..e207583 100644
--- a/clientlibrary/worker/polling-shard-consumer.go
+++ b/clientlibrary/worker/polling-shard-consumer.go
@@ -157,6 +157,14 @@ func (sc *PollingShardConsumer) getRecords() error {
 			var throughputExceededErr *types.ProvisionedThroughputExceededException
 			var kmsThrottlingErr *types.KMSThrottlingException
 			if errors.As(err, &throughputExceededErr) || err == localTPSExceededError {
+				retriedErrors++
+				if retriedErrors > sc.kclConfig.MaxRetryCount {
+					log.Errorf("Reached max retry count %d getting records from shard %v: %+v",
+						retriedErrors,
+						sc.shard.ID,
+						err)
+					return err
+				}
 				// If there is insufficient provisioned throughput on the stream,
 				// subsequent calls made within the next 1 second throw ProvisionedThroughputExceededException.
 				// ref: https://docs.aws.amazon.com/streams/latest/dev/service-sizes-and-limits.html
@@ -166,6 +174,14 @@ func (sc *PollingShardConsumer) getRecords() error {
 			if errors.As(err, &kmsThrottlingErr) {
 				log.Errorf("Error getting records from shard %v: %+v", sc.shard.ID, err)
 				retriedErrors++
+				// Compare with > so the MaxRetryCount-th retry is still attempted before giving up
+				if retriedErrors > sc.kclConfig.MaxRetryCount {
+					log.Errorf("Reached max retry count %d getting records from shard %v: %+v",
+						retriedErrors,
+						sc.shard.ID,
+						err)
+					return err
+				}
 				// exponential backoff
 				// https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Programming.Errors.html#Programming.Errors.RetryAndBackoff
 				time.Sleep(time.Duration(math.Exp2(float64(retriedErrors))*100) * time.Millisecond)
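
Note (reviewer aid, not part of the patch): a minimal sketch of how a consumer
could opt into the new cap introduced above. The import path and the
NewKinesisClientLibConfig constructor are assumptions based on the
vmware-go-kcl-v2 layout and may need adjusting to your setup; the application,
stream, region, and worker names are placeholders.

    package main

    import (
    	"fmt"

    	cfg "github.com/vmware/vmware-go-kcl-v2/clientlibrary/config"
    )

    func main() {
    	// Assumed constructor; callers that never call WithMaxRetryCount get
    	// DefaultMaxRetryCount (5) from the config defaults added in this patch.
    	kclConfig := cfg.NewKinesisClientLibConfig("demo-app", "demo-stream", "us-west-2", "demo-worker").
    		WithMaxRetryCount(10) // raise the getRecords retry cap from the default of 5

    	fmt.Println(kclConfig.MaxRetryCount) // prints 10
    }

With the cap in place, a shard that keeps hitting ProvisionedThroughputExceededException
or KMSThrottlingException returns the error after MaxRetryCount retries instead of
backing off indefinitely.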