From ae3763e4787727d788c288e66f478cee1febaddf Mon Sep 17 00:00:00 2001
From: Shiva Pentakota
Date: Mon, 23 Jan 2023 20:50:53 -0800
Subject: [PATCH] fix: added callGetRecordsAPI tests

Signed-off-by: Shiva Pentakota
---
 .../worker/polling-shard-consumer.go      | 156 ++++++++++--------
 .../worker/polling-shard-consumer_test.go | 128 ++++++++++++++
 2 files changed, 218 insertions(+), 66 deletions(-)

diff --git a/clientlibrary/worker/polling-shard-consumer.go b/clientlibrary/worker/polling-shard-consumer.go
index 06e7d78..6c494a6 100644
--- a/clientlibrary/worker/polling-shard-consumer.go
+++ b/clientlibrary/worker/polling-shard-consumer.go
@@ -45,19 +45,33 @@ import (
 )
 
 const (
-    MaxBytes                     = 10000000.0
-    MaxBytesPerSecond            = 2000000.0
-    MaxReadTransactionsPerSecond = 5
+    kinesisReadTPSLimit = 5
+    MaxBytes            = 10000000.0
+    MaxBytesPerSecond   = 2000000.0
+    BytesToMbConversion = 1000000.0
+)
+
+var (
+    rateLimitTimeNow   = time.Now
+    rateLimitTimeSince = time.Since
+    rateLimitSleep     = time.Sleep
+
+    localTPSExceededError = errors.New("Error GetRecords TPS Exceeded")
 )
 
 // PollingShardConsumer is responsible for polling data records from a (specified) shard.
 // Note: PollingShardConsumer only deals with one shard.
 type PollingShardConsumer struct {
     commonShardConsumer
-    streamName string
-    stop       *chan struct{}
-    consumerID string
-    mService   metrics.MonitoringService
+    streamName    string
+    stop          *chan struct{}
+    consumerID    string
+    mService      metrics.MonitoringService
+    currTime      time.Time
+    callsLeft     int
+    remBytes      float64
+    lastCheckTime time.Time
+    bytesRead     float64
 }
 
 func (sc *PollingShardConsumer) getShardIterator() (*string, error) {
@@ -113,10 +127,12 @@ func (sc *PollingShardConsumer) getRecords() error {
     recordCheckpointer := NewRecordProcessorCheckpoint(sc.shard, sc.checkpointer)
 
     retriedErrors := 0
-    transactionNum := 0
-    remBytes := MaxBytes
-    var lastCheckTime time.Time
-    var firstTransactionTime time.Time
+
+    // define API call rate limit starting window
+    sc.currTime = rateLimitTimeNow()
+    sc.callsLeft = kinesisReadTPSLimit
+    sc.bytesRead = 0
+    sc.remBytes = MaxBytes
 
     for {
         if time.Now().UTC().After(sc.shard.GetLeaseTimeout().Add(-time.Duration(sc.kclConfig.LeaseRefreshPeriodMillis) * time.Millisecond)) {
@@ -139,30 +155,18 @@ func (sc *PollingShardConsumer) getRecords() error {
         getRecordsStartTime := time.Now()
 
         log.Debugf("Trying to read %d records from iterator: %v", sc.kclConfig.MaxRecords, aws.ToString(shardIterator))
+
+        // Get records from stream and retry as needed
         getRecordsArgs := &kinesis.GetRecordsInput{
             Limit:         aws.Int32(int32(sc.kclConfig.MaxRecords)),
             ShardIterator: shardIterator,
         }
-
-        // Each shard can support up to five read transactions per second.
-        if transactionNum > MaxReadTransactionsPerSecond {
-            transactionNum = 0
-            timeDiff := time.Since(firstTransactionTime)
-            if timeDiff < time.Second {
-                time.Sleep(timeDiff)
-            }
-        }
-
-        // Get records from stream and retry as needed
-        // Each read transaction can provide up to 10,000 records with an upper quota of 10 MB per transaction.
-        // ref: https://docs.aws.amazon.com/streams/latest/dev/service-sizes-and-limits.html
-        getResp, err := sc.kc.GetRecords(context.TODO(), getRecordsArgs)
-        getRecordsTransactionTime := time.Now()
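+        // The GetRecords call now goes through callGetRecordsAPI (added at the
+        // bottom of this file), which locally enforces the per-shard limits of
+        // five GetRecords calls per second and 2 MB per second before the
+        // service can throw ProvisionedThroughputExceededException.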
+        getResp, err := sc.callGetRecordsAPI(getRecordsArgs)
         if err != nil {
             //aws-sdk-go-v2 https://github.com/aws/aws-sdk-go-v2/blob/main/CHANGELOG.md#error-handling
             var throughputExceededErr *types.ProvisionedThroughputExceededException
             var kmsThrottlingErr *types.KMSThrottlingException
-            if errors.As(err, &throughputExceededErr) {
+            if errors.As(err, &throughputExceededErr) || err == localTPSExceededError {
                 retriedErrors++
                 if retriedErrors > sc.kclConfig.MaxRetryCount {
                     log.Errorf("message", "reached max retry count getting records from shard",
@@ -174,10 +178,7 @@ func (sc *PollingShardConsumer) getRecords() error {
                 // If there is insufficient provisioned throughput on the stream,
                 // subsequent calls made within the next 1 second throw ProvisionedThroughputExceededException.
                 // ref: https://docs.aws.amazon.com/streams/latest/dev/service-sizes-and-limits.html
-                waitTime := time.Since(getRecordsTransactionTime)
-                if waitTime < time.Second {
-                    time.Sleep(time.Second - waitTime)
-                }
+                sc.waitASecond(sc.currTime)
                 continue
             }
             if errors.As(err, &kmsThrottlingErr) {
@@ -196,26 +197,13 @@ func (sc *PollingShardConsumer) getRecords() error {
                 time.Sleep(time.Duration(math.Exp2(float64(retriedErrors))*100) * time.Millisecond)
                 continue
             }
+
             log.Errorf("Error getting records from Kinesis that cannot be retried: %+v Request: %s", err, getRecordsArgs)
             return err
         }
-
         // reset the retry count after success
         retriedErrors = 0
-        // Calculate size of records from read transaction
-        numBytes := 0
-        for _, record := range getResp.Records {
-            numBytes = numBytes + len(record.Data)
-        }
-
-        // Add to number of getRecords successful transactions
-        transactionNum++
-        if transactionNum == 1 {
-            firstTransactionTime = getRecordsTransactionTime
-            lastCheckTime = firstTransactionTime
-        }
-
         sc.processRecords(getRecordsStartTime, getResp.Records, getResp.MillisBehindLatest, recordCheckpointer)
 
         // The shard has been closed, so no new records can be read from it
@@ -225,27 +213,6 @@ func (sc *PollingShardConsumer) getRecords() error {
             sc.recordProcessor.Shutdown(shutdownInput)
             return nil
         }
-
-        // Each shard can support up to a maximum total data read rate of 2 MB per second via GetRecords.
-        // If a call to GetRecords returns 10 MB, subsequent calls made within the next 5 seconds throw an exception.
-        // ref: https://docs.aws.amazon.com/streams/latest/dev/service-sizes-and-limits.html
-        // check for overspending of byte budget from getRecords call
-        currTime := time.Now()
-        timePassed := currTime.Sub(lastCheckTime)
-        lastCheckTime = currTime
-
-        remBytes = remBytes + float64(timePassed.Seconds())*(MaxBytes/(float64(time.Second*5)))
-        if remBytes > MaxBytes {
-            remBytes = MaxBytes
-        }
-        if remBytes <= float64(numBytes) {
-            // Wait until cool down period has passed to prevent ProvisionedThroughputExceededException
-            coolDown := numBytes / MaxBytesPerSecond
-            time.Sleep(time.Duration(coolDown) * time.Second)
-        } else {
-            remBytes = remBytes - float64(numBytes)
-        }
-
         shardIterator = getResp.NextShardIterator
 
         // Idle between each read, the user is responsible for checkpoint the progress
@@ -264,3 +231,60 @@ func (sc *PollingShardConsumer) getRecords() error {
         }
     }
 }
+
+func (sc *PollingShardConsumer) waitASecond(timePassed time.Time) {
+    waitTime := time.Since(timePassed)
+    if waitTime < time.Second {
+        time.Sleep(time.Second - waitTime)
+    }
+}
+
+func (sc *PollingShardConsumer) checkCoolOffPeriod() {
+    // Each shard can support up to a maximum total data read rate of 2 MB per second via GetRecords.
+    // If a call to GetRecords returns 10 MB, subsequent calls made within the next 5 seconds throw an exception.
+    // ref: https://docs.aws.amazon.com/streams/latest/dev/service-sizes-and-limits.html
+    // check for overspending of byte budget from getRecords call
+    currentTime := rateLimitTimeNow()
+    timePassed := currentTime.Sub(sc.lastCheckTime)
+    sc.lastCheckTime = currentTime
+    sc.remBytes += timePassed.Seconds() * MaxBytesPerSecond
+    transactionReadRate := sc.bytesRead / (timePassed.Seconds() * BytesToMbConversion)
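+    // e.g. a 10 MB read observed one second after the last check: the budget
+    // refills by 1 s * 2 MB/s (and is capped at 10 MB), the observed rate of
+    // 10 MB/s exceeds the 2 MB/s limit, so the consumer cools off for
+    // 10 MB / (2 MB/s) = 5 s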
+    if sc.remBytes > MaxBytes {
+        sc.remBytes = MaxBytes
+    }
+    if sc.remBytes <= sc.bytesRead || transactionReadRate > 2 {
+        // Wait until cool down period has passed to prevent ProvisionedThroughputExceededException
+        coolDown := sc.bytesRead / MaxBytesPerSecond
+        rateLimitSleep(time.Duration(coolDown * float64(time.Second)))
+    } else {
+        sc.remBytes -= sc.bytesRead
+    }
+}
+
+func (sc *PollingShardConsumer) callGetRecordsAPI(gri *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
+    if sc.bytesRead != 0 {
+        sc.checkCoolOffPeriod()
+    }
+
+    // every new second, we get a fresh set of calls
+    if rateLimitTimeSince(sc.currTime) > time.Second {
+        sc.callsLeft = kinesisReadTPSLimit
+        sc.currTime = rateLimitTimeNow()
+    }
+
+    if sc.callsLeft < 1 {
+        return nil, localTPSExceededError
+    }
+
+    getResp, err := sc.kc.GetRecords(context.TODO(), gri)
+    sc.callsLeft--
+
+    // Calculate size of records from read transaction; guard against a nil
+    // response so an errored call cannot cause a nil pointer dereference
+    sc.bytesRead = 0
+    if getResp != nil {
+        for _, record := range getResp.Records {
+            sc.bytesRead += float64(len(record.Data))
+        }
+    }
+
+    return getResp, err
+}
diff --git a/clientlibrary/worker/polling-shard-consumer_test.go b/clientlibrary/worker/polling-shard-consumer_test.go
index b6a1fcf..859e24f 100644
--- a/clientlibrary/worker/polling-shard-consumer_test.go
+++ b/clientlibrary/worker/polling-shard-consumer_test.go
@@ -22,6 +22,7 @@ package worker
 import (
     "context"
     "testing"
+    "time"
 
     "github.com/aws/aws-sdk-go-v2/aws"
     "github.com/aws/aws-sdk-go-v2/service/kinesis"
@@ -44,6 +45,133 @@ func TestCallGetRecordsAPI(t *testing.T) {
     assert.Nil(t, err)
     assert.Equal(t, &ret, out)
     m1.AssertExpectations(t)
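+
+    // The cases below swap rateLimitTimeNow/rateLimitTimeSince/rateLimitSleep
+    // for stubs so rate-limit behaviour can be asserted deterministically,
+    // without real clocks or sleeps.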
+
+    // check that localTPSExceededError is thrown when trying more than 5 TPS
+    m2 := MockKinesisSubscriberGetter{}
+    psc2 := PollingShardConsumer{
+        commonShardConsumer: commonShardConsumer{kc: &m2},
+        callsLeft:           0,
+    }
+    rateLimitTimeSince = func(t time.Time) time.Duration {
+        return 500 * time.Millisecond
+    }
+    out2, err2 := psc2.callGetRecordsAPI(&gri)
+    assert.Nil(t, out2)
+    assert.ErrorIs(t, err2, localTPSExceededError)
+    m2.AssertExpectations(t)
+
+    // check that getRecords is called normally in the bytesRead = 0 case
+    m3 := MockKinesisSubscriberGetter{}
+    ret3 := kinesis.GetRecordsOutput{}
+    m3.On("GetRecords", mock.Anything, mock.Anything, mock.Anything).Return(&ret3, nil)
+    psc3 := PollingShardConsumer{
+        commonShardConsumer: commonShardConsumer{kc: &m3},
+        callsLeft:           2,
+        bytesRead:           0,
+    }
+    rateLimitTimeSince = func(t time.Time) time.Duration {
+        return 2 * time.Second
+    }
+    out3, err3 := psc3.callGetRecordsAPI(&gri)
+    assert.Nil(t, err3)
+    assert.Equal(t, &ret3, out3)
+    m3.AssertExpectations(t)
+
+    // check that the correct cool off period is taken for 10 MB read in 1 second
+    testTime := time.Now()
+    m4 := MockKinesisSubscriberGetter{}
+    ret4 := kinesis.GetRecordsOutput{}
+    m4.On("GetRecords", mock.Anything, mock.Anything, mock.Anything).Return(&ret4, nil)
+    psc4 := PollingShardConsumer{
+        commonShardConsumer: commonShardConsumer{kc: &m4},
+        callsLeft:           2,
+        bytesRead:           MaxBytes,
+        lastCheckTime:       testTime,
+        remBytes:            MaxBytes,
+    }
+    rateLimitTimeSince = func(t time.Time) time.Duration {
+        return 2 * time.Second
+    }
+    rateLimitTimeNow = func() time.Time {
+        return testTime.Add(time.Second)
+    }
+    checkSleepVal := 0.0
+    rateLimitSleep = func(d time.Duration) {
+        checkSleepVal = d.Seconds()
+    }
+    out4, err4 := psc4.callGetRecordsAPI(&gri)
+    assert.Nil(t, err4)
+    assert.Equal(t, &ret4, out4)
+    m4.AssertExpectations(t)
+    if checkSleepVal != 5 {
+        t.Errorf("Incorrect Cool Off Period: %v", checkSleepVal)
+    }
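+
+    // 6 MB over 3 seconds averages exactly 2 MB/s, right at the shard limit,
+    // so the next case should not trigger a cool off sleep.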
+
+    // check that no cool off period is taken for 6 MB read in 3 seconds
+    testTime2 := time.Now()
+    m5 := MockKinesisSubscriberGetter{}
+    ret5 := kinesis.GetRecordsOutput{}
+    m5.On("GetRecords", mock.Anything, mock.Anything, mock.Anything).Return(&ret5, nil)
+    psc5 := PollingShardConsumer{
+        commonShardConsumer: commonShardConsumer{kc: &m5},
+        callsLeft:           2,
+        bytesRead:           MaxBytesPerSecond * 3,
+        lastCheckTime:       testTime2,
+        remBytes:            MaxBytes,
+    }
+    rateLimitTimeSince = func(t time.Time) time.Duration {
+        return 3 * time.Second
+    }
+    rateLimitTimeNow = func() time.Time {
+        return testTime2.Add(time.Second * 3)
+    }
+    checkSleepVal2 := 0.0
+    rateLimitSleep = func(d time.Duration) {
+        checkSleepVal2 = d.Seconds()
+    }
+    out5, err5 := psc5.callGetRecordsAPI(&gri)
+    assert.Nil(t, err5)
+    assert.Equal(t, &ret5, out5)
+    m5.AssertExpectations(t)
+    if checkSleepVal2 != 0 {
+        t.Errorf("Incorrect Cool Off Period: %v", checkSleepVal2)
+    }
+
+    // check that the correct cool off period is taken for 8 MB read in 0.2 seconds
+    testTime3 := time.Now()
+    m6 := MockKinesisSubscriberGetter{}
+    ret6 := kinesis.GetRecordsOutput{}
+    m6.On("GetRecords", mock.Anything, mock.Anything, mock.Anything).Return(&ret6, nil)
+    psc6 := PollingShardConsumer{
+        commonShardConsumer: commonShardConsumer{kc: &m6},
+        callsLeft:           2,
+        bytesRead:           MaxBytesPerSecond * 4,
+        lastCheckTime:       testTime3,
+        remBytes:            MaxBytes * 3,
+    }
+    rateLimitTimeSince = func(t time.Time) time.Duration {
+        return 3 * time.Second
+    }
+    rateLimitTimeNow = func() time.Time {
+        return testTime3.Add(time.Second / 5)
+    }
+    checkSleepVal3 := 0.0
+    rateLimitSleep = func(d time.Duration) {
+        checkSleepVal3 = d.Seconds()
+    }
+    out6, err6 := psc6.callGetRecordsAPI(&gri)
+    assert.Nil(t, err6)
+    assert.Equal(t, &ret6, out6)
+    m6.AssertExpectations(t)
+    if checkSleepVal3 != 4 {
+        t.Errorf("Incorrect Cool Off Period: %v", checkSleepVal3)
+    }
+
+    // restore the original functions
+    rateLimitTimeNow = time.Now
+    rateLimitTimeSince = time.Since
+    rateLimitSleep = time.Sleep
 }
 
 type MockKinesisSubscriberGetter struct {