From 987fada9d3c28c05ac8e5cf34211665a678bcdf3 Mon Sep 17 00:00:00 2001 From: John Calixto Date: Thu, 23 Mar 2023 10:59:58 -0700 Subject: [PATCH] fix: Check token bucket corner cases correctly. Signed-off-by: John Calixto --- .../worker/polling-shard-consumer.go | 6 +- .../worker/polling-shard-consumer_test.go | 169 +++++++++++++++++- 2 files changed, 167 insertions(+), 8 deletions(-) diff --git a/clientlibrary/worker/polling-shard-consumer.go b/clientlibrary/worker/polling-shard-consumer.go index 6e19f5b..e0998ec 100644 --- a/clientlibrary/worker/polling-shard-consumer.go +++ b/clientlibrary/worker/polling-shard-consumer.go @@ -258,14 +258,16 @@ func (sc *PollingShardConsumer) checkCoolOffPeriod() (int, error) { secondsPassed := currentTime.Sub(sc.lastCheckTime).Seconds() sc.lastCheckTime = currentTime sc.remBytes += int(secondsPassed * MaxBytesPerSecond) - transactionReadRate := float64(sc.bytesRead) / (secondsPassed * BytesToMbConversion) if sc.remBytes > MaxBytes { sc.remBytes = MaxBytes } - if sc.remBytes <= sc.bytesRead || transactionReadRate > 2 { + if sc.remBytes < 1 { // Wait until cool down period has passed to prevent ProvisionedThroughputExceededException coolDown := sc.bytesRead / MaxBytesPerSecond + if sc.bytesRead%MaxBytesPerSecond > 0 { + coolDown++ + } return coolDown, maxBytesExceededError } else { sc.remBytes -= sc.bytesRead diff --git a/clientlibrary/worker/polling-shard-consumer_test.go b/clientlibrary/worker/polling-shard-consumer_test.go index c94859d..736b2bd 100644 --- a/clientlibrary/worker/polling-shard-consumer_test.go +++ b/clientlibrary/worker/polling-shard-consumer_test.go @@ -86,6 +86,8 @@ func TestCallGetRecordsAPI(t *testing.T) { // check that correct cool off period is taken for 10mb in 1 second testTime := time.Now() m4 := MockKinesisSubscriberGetter{} + ret4 := kinesis.GetRecordsOutput{Records: nil} + m4.On("GetRecords", mock.Anything, mock.Anything, mock.Anything).Return(&ret4, nil) psc4 := PollingShardConsumer{ commonShardConsumer: commonShardConsumer{kc: &m4}, callsLeft: 2, @@ -100,10 +102,10 @@ func TestCallGetRecordsAPI(t *testing.T) { return testTime.Add(time.Second) } out4, checkSleepVal2, err4 := psc4.callGetRecordsAPI(&gri) - assert.Nil(t, out4) - assert.Equal(t, maxBytesExceededError, err4) + assert.Nil(t, err4) + assert.Equal(t, &ret4, out4) m4.AssertExpectations(t) - if checkSleepVal2 != 5 { + if checkSleepVal2 != 0 { t.Errorf("Incorrect Cool Off Period: %v", checkSleepVal2) } @@ -134,6 +136,8 @@ func TestCallGetRecordsAPI(t *testing.T) { // check for correct cool off period with 8mb in .2 seconds with 6mb remaining testTime3 := time.Now() m6 := MockKinesisSubscriberGetter{} + ret6 := kinesis.GetRecordsOutput{Records: nil} + m6.On("GetRecords", mock.Anything, mock.Anything, mock.Anything).Return(&ret6, nil) psc6 := PollingShardConsumer{ commonShardConsumer: commonShardConsumer{kc: &m6}, callsLeft: 2, @@ -148,10 +152,10 @@ func TestCallGetRecordsAPI(t *testing.T) { return testTime3.Add(time.Second / 5) } out6, checkSleepVal4, err6 := psc6.callGetRecordsAPI(&gri) - assert.Nil(t, out6) - assert.Equal(t, err6, maxBytesExceededError) + assert.Nil(t, err6) + assert.Equal(t, &ret6, out6) m5.AssertExpectations(t) - if checkSleepVal4 != 4 { + if checkSleepVal4 != 0 { t.Errorf("Incorrect Cool Off Period: %v", checkSleepVal4) } @@ -196,3 +200,156 @@ func (m *MockKinesisSubscriberGetter) GetShardIterator(ctx context.Context, para func (m *MockKinesisSubscriberGetter) SubscribeToShard(ctx context.Context, params *kinesis.SubscribeToShardInput, optFns ...func(*kinesis.Options)) (*kinesis.SubscribeToShardOutput, error) { return nil, nil } + +func TestPollingShardConsumer_checkCoolOffPeriod(t *testing.T) { + refTime := time.Now() + type fields struct { + lastCheckTime time.Time + remBytes int + bytesRead int + } + tests := []struct { + name string + fields fields + timeNow time.Time + want int + wantErr bool + }{ + { + "zero time max bytes to spend", + fields{ + time.Time{}, + 0, + 0, + }, + refTime, + 0, + false, + }, + { + "same second, bytes still left to spend", + fields{ + refTime, + MaxBytesPerSecond, + MaxBytesPerSecond - 1, + }, + refTime, + 0, + false, + }, + { + "same second, not many but some bytes still left to spend", + fields{ + refTime, + 8, + MaxBytesPerSecond, + }, + refTime, + 0, + false, + }, + { + "same second, 1 byte still left to spend", + fields{ + refTime, + 1, + MaxBytesPerSecond, + }, + refTime, + 0, + false, + }, + { + "next second, bytes still left to spend", + fields{ + refTime, + 42, + 1024, + }, + refTime.Add(1 * time.Second), + 0, + false, + }, + { + "same second, max bytes per second already spent", + fields{ + refTime, + 0, + MaxBytesPerSecond, + }, + refTime, + 1, + true, + }, + { + "same second, more than max bytes per second already spent", + fields{ + refTime, + 0, + MaxBytesPerSecond + 1, + }, + refTime, + 2, + true, + }, + + // Kinesis prevents reading more than 10 MiB at once + { + "same second, 10 MiB read all at once", + fields{ + refTime, + 0, + 10 * 1024 * 1024, + }, + refTime, + 6, + true, + }, + + { + "same second, 10 MB read all at once", + fields{ + refTime, + 0, + 10 * 1000 * 1000, + }, + refTime, + 5, + true, + }, + { + "5 seconds ago, 10 MB read all at once", + fields{ + refTime, + 0, + 10 * 1000 * 1000, + }, + refTime.Add(5 * time.Second), + 0, + false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sc := &PollingShardConsumer{ + lastCheckTime: tt.fields.lastCheckTime, + remBytes: tt.fields.remBytes, + bytesRead: tt.fields.bytesRead, + } + rateLimitTimeNow = func() time.Time { + return tt.timeNow + } + got, err := sc.checkCoolOffPeriod() + if (err != nil) != tt.wantErr { + t.Errorf("PollingShardConsumer.checkCoolOffPeriod() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("PollingShardConsumer.checkCoolOffPeriod() = %v, want %v", got, tt.want) + } + }) + } + + // restore original time.Now + rateLimitTimeNow = time.Now +}