From 7d6b1c33d06b8f4d35079072f1f17d337866ea2e Mon Sep 17 00:00:00 2001
From: Shiva Pentakota
Date: Tue, 24 Jan 2023 16:28:22 -0800
Subject: [PATCH] fix: add maxBytes per second getRecord check

Signed-off-by: Shiva Pentakota
---
 .../worker/polling-shard-consumer.go      | 83 +++++++++++++---
 .../worker/polling-shard-consumer_test.go | 99 ++++++++++++++++++-
 2 files changed, 167 insertions(+), 15 deletions(-)

diff --git a/clientlibrary/worker/polling-shard-consumer.go b/clientlibrary/worker/polling-shard-consumer.go
index e207583..b96b26d 100644
--- a/clientlibrary/worker/polling-shard-consumer.go
+++ b/clientlibrary/worker/polling-shard-consumer.go
@@ -46,24 +46,31 @@ import (
 
 const (
 	kinesisReadTPSLimit = 5
+	MaxBytes            = 10000000
+	MaxBytesPerSecond   = 2000000
+	BytesToMbConversion = 1000000
 )
 
 var (
 	rateLimitTimeNow      = time.Now
 	rateLimitTimeSince    = time.Since
 	localTPSExceededError = errors.New("Error GetRecords TPS Exceeded")
+	maxBytesExceededError = errors.New("Error GetRecords Max Bytes For Call Period Exceeded")
 )
 
 // PollingShardConsumer is responsible for polling data records from a (specified) shard.
 // Note: PollingShardConsumer only deal with one shard.
 type PollingShardConsumer struct {
 	commonShardConsumer
-	streamName string
-	stop       *chan struct{}
-	consumerID string
-	mService   metrics.MonitoringService
-	currTime   time.Time
-	callsLeft  int
+	streamName    string
+	stop          *chan struct{}
+	consumerID    string
+	mService      metrics.MonitoringService
+	currTime      time.Time
+	callsLeft     int
+	remBytes      int
+	lastCheckTime time.Time
+	bytesRead     int
 }
 
 func (sc *PollingShardConsumer) getShardIterator() (*string, error) {
@@ -123,6 +130,8 @@ func (sc *PollingShardConsumer) getRecords() error {
 	// define API call rate limit starting window
 	sc.currTime = rateLimitTimeNow()
 	sc.callsLeft = kinesisReadTPSLimit
+	sc.bytesRead = 0
+	sc.remBytes = MaxBytes
 
 	for {
 		if time.Now().UTC().After(sc.shard.GetLeaseTimeout().Add(-time.Duration(sc.kclConfig.LeaseRefreshPeriodMillis) * time.Millisecond)) {
@@ -151,15 +160,16 @@ func (sc *PollingShardConsumer) getRecords() error {
 			Limit:         aws.Int32(int32(sc.kclConfig.MaxRecords)),
 			ShardIterator: shardIterator,
 		}
-		getResp, err := sc.callGetRecordsAPI(getRecordsArgs)
+		getResp, coolDownPeriod, err := sc.callGetRecordsAPI(getRecordsArgs)
 		if err != nil {
 			//aws-sdk-go-v2 https://github.com/aws/aws-sdk-go-v2/blob/main/CHANGELOG.md#error-handling
 			var throughputExceededErr *types.ProvisionedThroughputExceededException
 			var kmsThrottlingErr *types.KMSThrottlingException
-			if errors.As(err, &throughputExceededErr) || err == localTPSExceededError {
+			if errors.As(err, &throughputExceededErr) {
 				retriedErrors++
 				if retriedErrors > sc.kclConfig.MaxRetryCount {
-					log.Errorf("message", "reached max retry count getting records from shard",
+					log.Errorf("message", "Throughput Exceeded Error: "+
+						"reached max retry count getting records from shard",
 						"shardId", sc.shard.ID,
 						"retryCount", retriedErrors,
 						"error", err)
@@ -171,12 +181,21 @@
 				sc.waitASecond(sc.currTime)
 				continue
 			}
+			if err == localTPSExceededError {
+				sc.waitASecond(sc.currTime)
+				continue
+			}
+			if err == maxBytesExceededError {
+				time.Sleep(time.Duration(coolDownPeriod) * time.Second)
+				continue
+			}
 			if errors.As(err, &kmsThrottlingErr) {
 				log.Errorf("Error getting records from shard %v: %+v", sc.shard.ID, err)
 				retriedErrors++
 				// Greater than MaxRetryCount so we get the last retry
 				if retriedErrors > sc.kclConfig.MaxRetryCount {
-					log.Errorf("message", "reached max retry count getting records from shard",
+					log.Errorf("message", "KMS Throttling Error: "+
+						"reached max retry count getting records from shard",
 						"shardId", sc.shard.ID,
 						"retryCount", retriedErrors,
 						"error", err)
@@ -228,7 +247,37 @@ func (sc *PollingShardConsumer) waitASecond(timePassed time.Time) {
 	}
 }
 
-func (sc *PollingShardConsumer) callGetRecordsAPI(gri *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
+func (sc *PollingShardConsumer) checkCoolOffPeriod() (int, error) {
+	// Each shard can support up to a maximum total data read rate of 2 MB per second via GetRecords.
+	// If a call to GetRecords returns 10 MB, subsequent calls made within the next 5 seconds throw an exception.
+	// ref: https://docs.aws.amazon.com/streams/latest/dev/service-sizes-and-limits.html
+	// check for overspending of byte budget from getRecords call
+	currentTime := rateLimitTimeNow()
+	secondsPassed := currentTime.Sub(sc.lastCheckTime).Seconds()
+	sc.lastCheckTime = currentTime
+	sc.remBytes += int(secondsPassed * MaxBytesPerSecond)
+	transactionReadRate := float64(sc.bytesRead) / (secondsPassed * BytesToMbConversion)
+
+	if sc.remBytes > MaxBytes {
+		sc.remBytes = MaxBytes
+	}
+	if sc.remBytes <= sc.bytesRead || transactionReadRate > 2 {
+		// Wait until cool down period has passed to prevent ProvisionedThroughputExceededException
+		coolDown := sc.bytesRead / MaxBytesPerSecond
+		return coolDown, maxBytesExceededError
+	} else {
+		sc.remBytes -= sc.bytesRead
+	}
+	return 0, nil
+}
+
+func (sc *PollingShardConsumer) callGetRecordsAPI(gri *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, int, error) {
+	if sc.bytesRead != 0 {
+		coolDownPeriod, err := sc.checkCoolOffPeriod()
+		if err != nil {
+			return nil, coolDownPeriod, err
+		}
+	}
 	// every new second, we get a fresh set of calls
 	if rateLimitTimeSince(sc.currTime) > time.Second {
 		sc.callsLeft = kinesisReadTPSLimit
@@ -236,11 +285,19 @@ func (sc *PollingShardConsumer) callGetRecordsAPI(gri *kinesis.GetRecordsInput)
 	}
 
 	if sc.callsLeft < 1 {
-		return nil, localTPSExceededError
+		return nil, 0, localTPSExceededError
 	}
 
 	getResp, err := sc.kc.GetRecords(context.TODO(), gri)
 	sc.callsLeft--
+	// Calculate size of records from read transaction
+	sc.bytesRead = 0
+	for _, record := range getResp.Records {
+		sc.bytesRead += len(record.Data)
+	}
+	if sc.lastCheckTime.IsZero() {
+		sc.lastCheckTime = rateLimitTimeNow()
+	}
 
-	return getResp, err
+	return getResp, 0, err
 }
diff --git a/clientlibrary/worker/polling-shard-consumer_test.go b/clientlibrary/worker/polling-shard-consumer_test.go
index 7819be7..68dffd0 100644
--- a/clientlibrary/worker/polling-shard-consumer_test.go
+++ b/clientlibrary/worker/polling-shard-consumer_test.go
@@ -41,7 +41,7 @@ func TestCallGetRecordsAPI(t *testing.T) {
 	gri := kinesis.GetRecordsInput{
 		ShardIterator: aws.String("shard-iterator-01"),
 	}
-	out, err := psc.callGetRecordsAPI(&gri)
+	out, _, err := psc.callGetRecordsAPI(&gri)
 	assert.Nil(t, err)
 	assert.Equal(t, &ret, out)
 	m1.AssertExpectations(t)
@@ -55,10 +55,105 @@ func TestCallGetRecordsAPI(t *testing.T) {
 	rateLimitTimeSince = func(t time.Time) time.Duration {
 		return 500 * time.Millisecond
 	}
-	out2, err2 := psc2.callGetRecordsAPI(&gri)
+	out2, _, err2 := psc2.callGetRecordsAPI(&gri)
 	assert.Nil(t, out2)
 	assert.ErrorIs(t, err2, localTPSExceededError)
 	m2.AssertExpectations(t)
+
+	// check that getRecords is called normally in bytesRead = 0 case
+	m3 := MockKinesisSubscriberGetter{}
+	ret3 := kinesis.GetRecordsOutput{}
+	m3.On("GetRecords", mock.Anything, mock.Anything, mock.Anything).Return(&ret3, nil)
+	psc3 := PollingShardConsumer{
+		commonShardConsumer: commonShardConsumer{kc: &m3},
+		callsLeft:           2,
+		bytesRead:           0,
+	}
+	rateLimitTimeSince = func(t time.Time) time.Duration {
+		return 2 * time.Second
+	}
+	out3, checkSleepVal, err3 := psc3.callGetRecordsAPI(&gri)
+	assert.Nil(t, err3)
+	assert.Equal(t, checkSleepVal, 0)
+	assert.Equal(t, &ret3, out3)
+	m3.AssertExpectations(t)
+
+	// check that correct cool off period is taken for 10mb in 1 second
+	testTime := time.Now()
+	m4 := MockKinesisSubscriberGetter{}
+	psc4 := PollingShardConsumer{
+		commonShardConsumer: commonShardConsumer{kc: &m4},
+		callsLeft:           2,
+		bytesRead:           MaxBytes,
+		lastCheckTime:       testTime,
+		remBytes:            MaxBytes,
+	}
+	rateLimitTimeSince = func(t time.Time) time.Duration {
+		return 2 * time.Second
+	}
+	rateLimitTimeNow = func() time.Time {
+		return testTime.Add(time.Second)
+	}
+	out4, checkSleepVal2, err4 := psc4.callGetRecordsAPI(&gri)
+	assert.Nil(t, out4)
+	assert.Equal(t, maxBytesExceededError, err4)
+	m4.AssertExpectations(t)
+	if checkSleepVal2 != 5 {
+		t.Errorf("Incorrect Cool Off Period: %v", checkSleepVal2)
+	}
+
+	// check that no cool off period is taken for 6mb in 3 seconds
+	testTime2 := time.Now()
+	m5 := MockKinesisSubscriberGetter{}
+	ret5 := kinesis.GetRecordsOutput{}
+	m5.On("GetRecords", mock.Anything, mock.Anything, mock.Anything).Return(&ret5, nil)
+	psc5 := PollingShardConsumer{
+		commonShardConsumer: commonShardConsumer{kc: &m5},
+		callsLeft:           2,
+		bytesRead:           MaxBytesPerSecond * 3,
+		lastCheckTime:       testTime2,
+		remBytes:            MaxBytes,
+	}
+	rateLimitTimeSince = func(t time.Time) time.Duration {
+		return 3 * time.Second
+	}
+	rateLimitTimeNow = func() time.Time {
+		return testTime2.Add(time.Second * 3)
+	}
+	out5, checkSleepVal3, err5 := psc5.callGetRecordsAPI(&gri)
+	assert.Nil(t, err5)
+	assert.Equal(t, checkSleepVal3, 0)
+	assert.Equal(t, &ret5, out5)
+	m5.AssertExpectations(t)
+
+	// check for correct cool off period with 8mb in .2 seconds with 6mb remaining
+	testTime3 := time.Now()
+	m6 := MockKinesisSubscriberGetter{}
+	psc6 := PollingShardConsumer{
+		commonShardConsumer: commonShardConsumer{kc: &m6},
+		callsLeft:           2,
+		bytesRead:           MaxBytesPerSecond * 4,
+		lastCheckTime:       testTime3,
+		remBytes:            MaxBytesPerSecond * 3,
+	}
+	rateLimitTimeSince = func(t time.Time) time.Duration {
+		return 3 * time.Second
+	}
+	rateLimitTimeNow = func() time.Time {
+		return testTime3.Add(time.Second / 5)
+	}
+	out6, checkSleepVal4, err6 := psc6.callGetRecordsAPI(&gri)
+	assert.Nil(t, out6)
+	assert.Equal(t, err6, maxBytesExceededError)
+	m6.AssertExpectations(t)
+	if checkSleepVal4 != 4 {
+		t.Errorf("Incorrect Cool Off Period: %v", checkSleepVal4)
+	}
+
+	// restore original func
+	rateLimitTimeNow = time.Now
+	rateLimitTimeSince = time.Since
+
 }
 
 type MockKinesisSubscriberGetter struct {
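
A minimal standalone Go sketch (not part of the patch) of the byte-budget cool-off arithmetic introduced above. The helper name cooldownSeconds and the hard-coded constants are assumptions made only for this illustration, and the sketch omits the patch's per-call read-rate check.

package main

import "fmt"

const (
	maxBytes          = 10000000 // assumed 10 MB burst budget per shard, mirrors MaxBytes
	maxBytesPerSecond = 2000000  // assumed 2 MB/s sustained read rate per shard, mirrors MaxBytesPerSecond
)

// cooldownSeconds is a hypothetical helper mirroring checkCoolOffPeriod's core rule:
// refill the byte budget for the time elapsed since the last call, cap it at the
// burst limit, and if the bytes just read exhaust the budget, back off for
// bytesRead / maxBytesPerSecond seconds before the next GetRecords call.
func cooldownSeconds(bytesRead, remBytes int, secondsSinceLastCall float64) int {
	remBytes += int(secondsSinceLastCall * maxBytesPerSecond) // refill budget for elapsed time
	if remBytes > maxBytes {
		remBytes = maxBytes // budget never exceeds the 10 MB burst
	}
	if remBytes <= bytesRead {
		return bytesRead / maxBytesPerSecond // e.g. a 10 MB read implies a 5 s cool-off
	}
	return 0
}

func main() {
	fmt.Println(cooldownSeconds(10000000, 10000000, 1)) // 10 MB read 1 s after the last call -> 5
	fmt.Println(cooldownSeconds(6000000, 10000000, 3))  // 6 MB read 3 s after the last call -> 0
}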