fix: Check token bucket corner cases correctly.

Signed-off-by: John Calixto <jcalixto@vmware.com>
This commit is contained in:
John Calixto 2023-03-23 10:59:58 -07:00
parent 711b72932a
commit 987fada9d3
2 changed files with 167 additions and 8 deletions

View file

@ -258,14 +258,16 @@ func (sc *PollingShardConsumer) checkCoolOffPeriod() (int, error) {
secondsPassed := currentTime.Sub(sc.lastCheckTime).Seconds()
sc.lastCheckTime = currentTime
sc.remBytes += int(secondsPassed * MaxBytesPerSecond)
transactionReadRate := float64(sc.bytesRead) / (secondsPassed * BytesToMbConversion)
if sc.remBytes > MaxBytes {
sc.remBytes = MaxBytes
}
if sc.remBytes <= sc.bytesRead || transactionReadRate > 2 {
if sc.remBytes < 1 {
// Wait until cool down period has passed to prevent ProvisionedThroughputExceededException
coolDown := sc.bytesRead / MaxBytesPerSecond
if sc.bytesRead%MaxBytesPerSecond > 0 {
coolDown++
}
return coolDown, maxBytesExceededError
} else {
sc.remBytes -= sc.bytesRead

View file

@ -86,6 +86,8 @@ func TestCallGetRecordsAPI(t *testing.T) {
// check that correct cool off period is taken for 10mb in 1 second
testTime := time.Now()
m4 := MockKinesisSubscriberGetter{}
ret4 := kinesis.GetRecordsOutput{Records: nil}
m4.On("GetRecords", mock.Anything, mock.Anything, mock.Anything).Return(&ret4, nil)
psc4 := PollingShardConsumer{
commonShardConsumer: commonShardConsumer{kc: &m4},
callsLeft: 2,
@ -100,10 +102,10 @@ func TestCallGetRecordsAPI(t *testing.T) {
return testTime.Add(time.Second)
}
out4, checkSleepVal2, err4 := psc4.callGetRecordsAPI(&gri)
assert.Nil(t, out4)
assert.Equal(t, maxBytesExceededError, err4)
assert.Nil(t, err4)
assert.Equal(t, &ret4, out4)
m4.AssertExpectations(t)
if checkSleepVal2 != 5 {
if checkSleepVal2 != 0 {
t.Errorf("Incorrect Cool Off Period: %v", checkSleepVal2)
}
@ -134,6 +136,8 @@ func TestCallGetRecordsAPI(t *testing.T) {
// check for correct cool off period with 8mb in .2 seconds with 6mb remaining
testTime3 := time.Now()
m6 := MockKinesisSubscriberGetter{}
ret6 := kinesis.GetRecordsOutput{Records: nil}
m6.On("GetRecords", mock.Anything, mock.Anything, mock.Anything).Return(&ret6, nil)
psc6 := PollingShardConsumer{
commonShardConsumer: commonShardConsumer{kc: &m6},
callsLeft: 2,
@ -148,10 +152,10 @@ func TestCallGetRecordsAPI(t *testing.T) {
return testTime3.Add(time.Second / 5)
}
out6, checkSleepVal4, err6 := psc6.callGetRecordsAPI(&gri)
assert.Nil(t, out6)
assert.Equal(t, err6, maxBytesExceededError)
assert.Nil(t, err6)
assert.Equal(t, &ret6, out6)
m5.AssertExpectations(t)
if checkSleepVal4 != 4 {
if checkSleepVal4 != 0 {
t.Errorf("Incorrect Cool Off Period: %v", checkSleepVal4)
}
@ -196,3 +200,156 @@ func (m *MockKinesisSubscriberGetter) GetShardIterator(ctx context.Context, para
func (m *MockKinesisSubscriberGetter) SubscribeToShard(ctx context.Context, params *kinesis.SubscribeToShardInput, optFns ...func(*kinesis.Options)) (*kinesis.SubscribeToShardOutput, error) {
return nil, nil
}
func TestPollingShardConsumer_checkCoolOffPeriod(t *testing.T) {
refTime := time.Now()
type fields struct {
lastCheckTime time.Time
remBytes int
bytesRead int
}
tests := []struct {
name string
fields fields
timeNow time.Time
want int
wantErr bool
}{
{
"zero time max bytes to spend",
fields{
time.Time{},
0,
0,
},
refTime,
0,
false,
},
{
"same second, bytes still left to spend",
fields{
refTime,
MaxBytesPerSecond,
MaxBytesPerSecond - 1,
},
refTime,
0,
false,
},
{
"same second, not many but some bytes still left to spend",
fields{
refTime,
8,
MaxBytesPerSecond,
},
refTime,
0,
false,
},
{
"same second, 1 byte still left to spend",
fields{
refTime,
1,
MaxBytesPerSecond,
},
refTime,
0,
false,
},
{
"next second, bytes still left to spend",
fields{
refTime,
42,
1024,
},
refTime.Add(1 * time.Second),
0,
false,
},
{
"same second, max bytes per second already spent",
fields{
refTime,
0,
MaxBytesPerSecond,
},
refTime,
1,
true,
},
{
"same second, more than max bytes per second already spent",
fields{
refTime,
0,
MaxBytesPerSecond + 1,
},
refTime,
2,
true,
},
// Kinesis prevents reading more than 10 MiB at once
{
"same second, 10 MiB read all at once",
fields{
refTime,
0,
10 * 1024 * 1024,
},
refTime,
6,
true,
},
{
"same second, 10 MB read all at once",
fields{
refTime,
0,
10 * 1000 * 1000,
},
refTime,
5,
true,
},
{
"5 seconds ago, 10 MB read all at once",
fields{
refTime,
0,
10 * 1000 * 1000,
},
refTime.Add(5 * time.Second),
0,
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
sc := &PollingShardConsumer{
lastCheckTime: tt.fields.lastCheckTime,
remBytes: tt.fields.remBytes,
bytesRead: tt.fields.bytesRead,
}
rateLimitTimeNow = func() time.Time {
return tt.timeNow
}
got, err := sc.checkCoolOffPeriod()
if (err != nil) != tt.wantErr {
t.Errorf("PollingShardConsumer.checkCoolOffPeriod() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("PollingShardConsumer.checkCoolOffPeriod() = %v, want %v", got, tt.want)
}
})
}
// restore original time.Now
rateLimitTimeNow = time.Now
}