diff --git a/clientlibrary/config/config.go b/clientlibrary/config/config.go
index 5b45678..2d50ca8 100644
--- a/clientlibrary/config/config.go
+++ b/clientlibrary/config/config.go
@@ -136,6 +136,9 @@ const (
 	// DefaultLeaseSyncingIntervalMillis Number of milliseconds to wait before syncing with lease table (dynamodDB)
 	DefaultLeaseSyncingIntervalMillis = 60000
+
+	// DefaultMaxRetryCount The default maximum number of retries in case of error
+	DefaultMaxRetryCount = 5
 )
 
 type (
@@ -283,6 +286,9 @@ type (
 
 		// LeaseSyncingTimeInterval The number of milliseconds to wait before syncing with lease table (dynamoDB)
 		LeaseSyncingTimeIntervalMillis int
+
+		// MaxRetryCount The maximum number of retries in case of error
+		MaxRetryCount int
 	}
 )
diff --git a/clientlibrary/config/kcl-config.go b/clientlibrary/config/kcl-config.go
index 135f3fa..4d7181b 100644
--- a/clientlibrary/config/kcl-config.go
+++ b/clientlibrary/config/kcl-config.go
@@ -102,6 +102,7 @@ func NewKinesisClientLibConfigWithCredentials(applicationName, streamName, regio
 		LeaseStealingIntervalMillis:     DefaultLeaseStealingIntervalMillis,
 		LeaseStealingClaimTimeoutMillis: DefaultLeaseStealingClaimTimeoutMillis,
 		LeaseSyncingTimeIntervalMillis:  DefaultLeaseSyncingIntervalMillis,
+		MaxRetryCount:                   DefaultMaxRetryCount,
 		Logger:                          logger.GetDefaultLogger(),
 	}
 }
@@ -211,6 +212,13 @@ func (c *KinesisClientLibConfiguration) WithLogger(logger logger.Logger) *Kinesi
 	return c
 }
 
+// WithMaxRetryCount sets the max retry count in case of error.
+func (c *KinesisClientLibConfiguration) WithMaxRetryCount(maxRetryCount int) *KinesisClientLibConfiguration {
+	checkIsValuePositive("maxRetryCount", maxRetryCount)
+	c.MaxRetryCount = maxRetryCount
+	return c
+}
+
 // WithMonitoringService sets the monitoring service to use to publish metrics.
 func (c *KinesisClientLibConfiguration) WithMonitoringService(mService metrics.MonitoringService) *KinesisClientLibConfiguration {
 	// Nil case is handled downward (at worker creation) so no need to do it here.
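A minimal usage sketch (not part of the patch) of the new configuration option, assuming the existing NewKinesisClientLibConfig constructor; the application, stream, region, and worker names are placeholders:

	import cfg "github.com/vmware/vmware-go-kcl-v2/clientlibrary/config"

	// Raise the retry ceiling from the default of 5 to 10; omitting the call
	// leaves MaxRetryCount at DefaultMaxRetryCount.
	kclConfig := cfg.NewKinesisClientLibConfig("test-app", "test-stream", "us-west-2", "worker-1").
		WithMaxRetryCount(10)
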
diff --git a/clientlibrary/worker/polling-shard-consumer.go b/clientlibrary/worker/polling-shard-consumer.go
index cd4565a..b96b26d 100644
--- a/clientlibrary/worker/polling-shard-consumer.go
+++ b/clientlibrary/worker/polling-shard-consumer.go
@@ -44,14 +44,33 @@ import (
 	"github.com/vmware/vmware-go-kcl-v2/clientlibrary/metrics"
 )
 
+const (
+	kinesisReadTPSLimit = 5
+	MaxBytes            = 10000000
+	MaxBytesPerSecond   = 2000000
+	BytesToMbConversion = 1000000
+)
+
+var (
+	rateLimitTimeNow      = time.Now
+	rateLimitTimeSince    = time.Since
+	localTPSExceededError = errors.New("Error GetRecords TPS Exceeded")
+	maxBytesExceededError = errors.New("Error GetRecords Max Bytes For Call Period Exceeded")
+)
+
 // PollingShardConsumer is responsible for polling data records from a (specified) shard.
 // Note: PollingShardConsumer only deal with one shard.
 type PollingShardConsumer struct {
 	commonShardConsumer
-	streamName string
-	stop       *chan struct{}
-	consumerID string
-	mService   metrics.MonitoringService
+	streamName    string
+	stop          *chan struct{}
+	consumerID    string
+	mService      metrics.MonitoringService
+	currTime      time.Time
+	callsLeft     int
+	remBytes      int
+	lastCheckTime time.Time
+	bytesRead     int
 }
 
 func (sc *PollingShardConsumer) getShardIterator() (*string, error) {
@@ -108,6 +127,12 @@ func (sc *PollingShardConsumer) getRecords() error {
 	recordCheckpointer := NewRecordProcessorCheckpoint(sc.shard, sc.checkpointer)
 	retriedErrors := 0
 
+	// define API call rate limit starting window
+	sc.currTime = rateLimitTimeNow()
+	sc.callsLeft = kinesisReadTPSLimit
+	sc.bytesRead = 0
+	sc.remBytes = MaxBytes
+
 	for {
 		if time.Now().UTC().After(sc.shard.GetLeaseTimeout().Add(-time.Duration(sc.kclConfig.LeaseRefreshPeriodMillis) * time.Millisecond)) {
 			log.Debugf("Refreshing lease on shard: %s for worker: %s", sc.shard.ID, sc.consumerID)
@@ -135,14 +160,47 @@ func (sc *PollingShardConsumer) getRecords() error {
 			Limit:         aws.Int32(int32(sc.kclConfig.MaxRecords)),
 			ShardIterator: shardIterator,
 		}
-		getResp, err := sc.callGetRecordsAPI(getRecordsArgs)
+		getResp, coolDownPeriod, err := sc.callGetRecordsAPI(getRecordsArgs)
 		if err != nil {
 			//aws-sdk-go-v2 https://github.com/aws/aws-sdk-go-v2/blob/main/CHANGELOG.md#error-handling
 			var throughputExceededErr *types.ProvisionedThroughputExceededException
 			var kmsThrottlingErr *types.KMSThrottlingException
-			if errors.As(err, &throughputExceededErr) || errors.As(err, &kmsThrottlingErr) {
+			if errors.As(err, &throughputExceededErr) {
+				retriedErrors++
+				if retriedErrors > sc.kclConfig.MaxRetryCount {
+					log.Errorf("Throughput Exceeded Error: reached max retry count getting records from shard %s, retry count: %d, error: %v",
+						sc.shard.ID, retriedErrors, err)
+					return err
+				}
+				// If there is insufficient provisioned throughput on the stream,
+				// subsequent calls made within the next 1 second throw ProvisionedThroughputExceededException.
+				// ref: https://docs.aws.amazon.com/streams/latest/dev/service-sizes-and-limits.html
+				sc.waitASecond(sc.currTime)
+				continue
+			}
+			if err == localTPSExceededError {
+				sc.waitASecond(sc.currTime)
+				continue
+			}
+			if err == maxBytesExceededError {
+				time.Sleep(time.Duration(coolDownPeriod) * time.Second)
+				continue
+			}
+			if errors.As(err, &kmsThrottlingErr) {
 				log.Errorf("Error getting records from shard %v: %+v", sc.shard.ID, err)
 				retriedErrors++
+				// Strictly greater than MaxRetryCount so the final retry still runs
+				if retriedErrors > sc.kclConfig.MaxRetryCount {
+					log.Errorf("KMS Throttling Error: reached max retry count getting records from shard %s, retry count: %d, error: %v",
+						sc.shard.ID, retriedErrors, err)
+					return err
+				}
 				// exponential backoff
 				// https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Programming.Errors.html#Programming.Errors.RetryAndBackoff
 				time.Sleep(time.Duration(math.Exp2(float64(retriedErrors))*100) * time.Millisecond)
@@ -182,7 +240,64 @@ func (sc *PollingShardConsumer) getRecords() error {
 	}
 }
 
-func (sc *PollingShardConsumer) callGetRecordsAPI(gri *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
-	getResp, err := sc.kc.GetRecords(context.TODO(), gri)
-	return getResp, err
+func (sc *PollingShardConsumer) waitASecond(timePassed time.Time) {
+	waitTime := time.Since(timePassed)
+	if waitTime < time.Second {
+		time.Sleep(time.Second - waitTime)
+	}
+}
+
+func (sc *PollingShardConsumer) checkCoolOffPeriod() (int, error) {
+	// Each shard can support up to a maximum total data read rate of 2 MB per second via GetRecords.
+	// If a call to GetRecords returns 10 MB, subsequent calls made within the next 5 seconds throw an exception.
+	// ref: https://docs.aws.amazon.com/streams/latest/dev/service-sizes-and-limits.html
+	// check for overspending of byte budget from getRecords call
+	currentTime := rateLimitTimeNow()
+	secondsPassed := currentTime.Sub(sc.lastCheckTime).Seconds()
+	sc.lastCheckTime = currentTime
+	sc.remBytes += int(secondsPassed * MaxBytesPerSecond)
+	transactionReadRate := float64(sc.bytesRead) / (secondsPassed * BytesToMbConversion)
+
+	if sc.remBytes > MaxBytes {
+		sc.remBytes = MaxBytes
+	}
+	if sc.remBytes <= sc.bytesRead || transactionReadRate > 2 {
+		// Wait until cool down period has passed to prevent ProvisionedThroughputExceededException
+		coolDown := sc.bytesRead / MaxBytesPerSecond
+		return coolDown, maxBytesExceededError
+	} else {
+		sc.remBytes -= sc.bytesRead
+	}
+	return 0, nil
+}
+
+func (sc *PollingShardConsumer) callGetRecordsAPI(gri *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, int, error) {
+	if sc.bytesRead != 0 {
+		coolDownPeriod, err := sc.checkCoolOffPeriod()
+		if err != nil {
+			return nil, coolDownPeriod, err
+		}
+	}
+	// every new second, we get a fresh set of calls
+	if rateLimitTimeSince(sc.currTime) > time.Second {
+		sc.callsLeft = kinesisReadTPSLimit
+		sc.currTime = rateLimitTimeNow()
+	}
+
+	if sc.callsLeft < 1 {
+		return nil, 0, localTPSExceededError
+	}
+
+	getResp, err := sc.kc.GetRecords(context.TODO(), gri)
+	sc.callsLeft--
+	if err != nil {
+		// GetRecords failed; getResp may be nil, so skip the byte accounting.
+		return getResp, 0, err
+	}
+	// Calculate size of records from read transaction
+	sc.bytesRead = 0
+	for _, record := range getResp.Records {
+		sc.bytesRead += len(record.Data)
+	}
+	if sc.lastCheckTime.IsZero() {
+		sc.lastCheckTime = rateLimitTimeNow()
+	}
+
+	return getResp, 0, err
 }
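A worked example (illustration only) of the byte-budget arithmetic that checkCoolOffPeriod implements: with a 2 MB/s per-shard read limit, a single GetRecords call that returns 10 MB exhausts the budget and forces a 10,000,000 / 2,000,000 = 5 second cool-off, while 6 MB spread over 3 seconds averages 2 MB/s and needs none. The standalone sketch below mirrors that accounting with hypothetical names:

	package main

	import "fmt"

	const (
		maxBytes          = 10_000_000 // per-call byte budget ceiling
		maxBytesPerSecond = 2_000_000  // per-shard GetRecords read limit
	)

	// coolDownSeconds mirrors the budget check: replenish for the elapsed time,
	// cap at the ceiling, and back off when the last call overspent the budget
	// or exceeded the 2 MB/s average read rate.
	func coolDownSeconds(bytesRead, remBytes int, secondsPassed float64) int {
		remBytes += int(secondsPassed * maxBytesPerSecond)
		if remBytes > maxBytes {
			remBytes = maxBytes
		}
		if remBytes <= bytesRead || float64(bytesRead)/(secondsPassed*1_000_000) > 2 {
			return bytesRead / maxBytesPerSecond
		}
		return 0
	}

	func main() {
		fmt.Println(coolDownSeconds(10_000_000, 10_000_000, 1)) // 5: the 10 MB case above
		fmt.Println(coolDownSeconds(6_000_000, 10_000_000, 3))  // 0: 2 MB/s average stays within the limit
	}
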
diff --git a/clientlibrary/worker/polling-shard-consumer_test.go b/clientlibrary/worker/polling-shard-consumer_test.go
index b6a1fcf..68dffd0 100644
--- a/clientlibrary/worker/polling-shard-consumer_test.go
+++ b/clientlibrary/worker/polling-shard-consumer_test.go
@@ -22,6 +22,7 @@ package worker
 import (
 	"context"
 	"testing"
+	"time"
 
 	"github.com/aws/aws-sdk-go-v2/aws"
 	"github.com/aws/aws-sdk-go-v2/service/kinesis"
@@ -40,10 +41,119 @@ func TestCallGetRecordsAPI(t *testing.T) {
 	gri := kinesis.GetRecordsInput{
 		ShardIterator: aws.String("shard-iterator-01"),
 	}
-	out, err := psc.callGetRecordsAPI(&gri)
+	out, _, err := psc.callGetRecordsAPI(&gri)
 	assert.Nil(t, err)
 	assert.Equal(t, &ret, out)
 	m1.AssertExpectations(t)
+
+	// check that localTPSExceededError is returned when trying more than 5 TPS
+	m2 := MockKinesisSubscriberGetter{}
+	psc2 := PollingShardConsumer{
+		commonShardConsumer: commonShardConsumer{kc: &m2},
+		callsLeft:           0,
+	}
+	rateLimitTimeSince = func(t time.Time) time.Duration {
+		return 500 * time.Millisecond
+	}
+	out2, _, err2 := psc2.callGetRecordsAPI(&gri)
+	assert.Nil(t, out2)
+	assert.ErrorIs(t, err2, localTPSExceededError)
+	m2.AssertExpectations(t)
+
+	// check that getRecords is called normally in bytesRead = 0 case
+	m3 := MockKinesisSubscriberGetter{}
+	ret3 := kinesis.GetRecordsOutput{}
+	m3.On("GetRecords", mock.Anything, mock.Anything, mock.Anything).Return(&ret3, nil)
+	psc3 := PollingShardConsumer{
+		commonShardConsumer: commonShardConsumer{kc: &m3},
+		callsLeft:           2,
+		bytesRead:           0,
+	}
+	rateLimitTimeSince = func(t time.Time) time.Duration {
+		return 2 * time.Second
+	}
+	out3, checkSleepVal, err3 := psc3.callGetRecordsAPI(&gri)
+	assert.Nil(t, err3)
+	assert.Equal(t, checkSleepVal, 0)
+	assert.Equal(t, &ret3, out3)
+	m3.AssertExpectations(t)
+
+	// check that correct cool off period is taken for 10 MB in 1 second
+	testTime := time.Now()
+	m4 := MockKinesisSubscriberGetter{}
+	psc4 := PollingShardConsumer{
+		commonShardConsumer: commonShardConsumer{kc: &m4},
+		callsLeft:           2,
+		bytesRead:           MaxBytes,
+		lastCheckTime:       testTime,
+		remBytes:            MaxBytes,
+	}
+	rateLimitTimeSince = func(t time.Time) time.Duration {
+		return 2 * time.Second
+	}
+	rateLimitTimeNow = func() time.Time {
+		return testTime.Add(time.Second)
+	}
+	out4, checkSleepVal2, err4 := psc4.callGetRecordsAPI(&gri)
+	assert.Nil(t, out4)
+	assert.Equal(t, maxBytesExceededError, err4)
+	m4.AssertExpectations(t)
+	if checkSleepVal2 != 5 {
+		t.Errorf("Incorrect Cool Off Period: %v", checkSleepVal2)
+	}
+
+	// check that no cool off period is taken for 6 MB in 3 seconds
+	testTime2 := time.Now()
+	m5 := MockKinesisSubscriberGetter{}
+	ret5 := kinesis.GetRecordsOutput{}
+	m5.On("GetRecords", mock.Anything, mock.Anything, mock.Anything).Return(&ret5, nil)
+	psc5 := PollingShardConsumer{
+		commonShardConsumer: commonShardConsumer{kc: &m5},
+		callsLeft:           2,
+		bytesRead:           MaxBytesPerSecond * 3,
+		lastCheckTime:       testTime2,
+		remBytes:            MaxBytes,
+	}
+	rateLimitTimeSince = func(t time.Time) time.Duration {
+		return 3 * time.Second
+	}
+	rateLimitTimeNow = func() time.Time {
+		return testTime2.Add(time.Second * 3)
+	}
+	out5, checkSleepVal3, err5 := psc5.callGetRecordsAPI(&gri)
+	assert.Nil(t, err5)
+	assert.Equal(t, checkSleepVal3, 0)
+	assert.Equal(t, &ret5, out5)
+	m5.AssertExpectations(t)
+
+	// check for correct cool off period with 8 MB in 0.2 seconds with 6 MB remaining
+	testTime3 := time.Now()
+	m6 := MockKinesisSubscriberGetter{}
+	psc6 := PollingShardConsumer{
+		commonShardConsumer: commonShardConsumer{kc: &m6},
+		callsLeft:           2,
+		bytesRead:           MaxBytesPerSecond * 4,
+		lastCheckTime:       testTime3,
+		remBytes:            MaxBytesPerSecond * 3,
+	}
+	rateLimitTimeSince = func(t time.Time) time.Duration {
+		return 3 * time.Second
+	}
+	rateLimitTimeNow = func() time.Time {
+		return testTime3.Add(time.Second / 5)
+	}
+	out6, checkSleepVal4, err6 := psc6.callGetRecordsAPI(&gri)
+	assert.Nil(t, out6)
+	assert.Equal(t, err6, maxBytesExceededError)
+	m6.AssertExpectations(t)
+	if checkSleepVal4 != 4 {
+		t.Errorf("Incorrect Cool Off Period: %v", checkSleepVal4)
+	}
+
+	// restore original funcs
+	rateLimitTimeNow = time.Now
+	rateLimitTimeSince = time.Since
 }
 
 type MockKinesisSubscriberGetter struct {