fix: Check token bucket corner cases correctly.
Signed-off-by: John Calixto <jcalixto@vmware.com>
This commit is contained in:
parent
711b72932a
commit
987fada9d3
2 changed files with 167 additions and 8 deletions
|
|
@ -258,14 +258,16 @@ func (sc *PollingShardConsumer) checkCoolOffPeriod() (int, error) {
|
|||
secondsPassed := currentTime.Sub(sc.lastCheckTime).Seconds()
|
||||
sc.lastCheckTime = currentTime
|
||||
sc.remBytes += int(secondsPassed * MaxBytesPerSecond)
|
||||
transactionReadRate := float64(sc.bytesRead) / (secondsPassed * BytesToMbConversion)
|
||||
|
||||
if sc.remBytes > MaxBytes {
|
||||
sc.remBytes = MaxBytes
|
||||
}
|
||||
if sc.remBytes <= sc.bytesRead || transactionReadRate > 2 {
|
||||
if sc.remBytes < 1 {
|
||||
// Wait until cool down period has passed to prevent ProvisionedThroughputExceededException
|
||||
coolDown := sc.bytesRead / MaxBytesPerSecond
|
||||
if sc.bytesRead%MaxBytesPerSecond > 0 {
|
||||
coolDown++
|
||||
}
|
||||
return coolDown, maxBytesExceededError
|
||||
} else {
|
||||
sc.remBytes -= sc.bytesRead
|
||||
|
|
|
|||
|
|
@ -86,6 +86,8 @@ func TestCallGetRecordsAPI(t *testing.T) {
|
|||
// check that correct cool off period is taken for 10mb in 1 second
|
||||
testTime := time.Now()
|
||||
m4 := MockKinesisSubscriberGetter{}
|
||||
ret4 := kinesis.GetRecordsOutput{Records: nil}
|
||||
m4.On("GetRecords", mock.Anything, mock.Anything, mock.Anything).Return(&ret4, nil)
|
||||
psc4 := PollingShardConsumer{
|
||||
commonShardConsumer: commonShardConsumer{kc: &m4},
|
||||
callsLeft: 2,
|
||||
|
|
@ -100,10 +102,10 @@ func TestCallGetRecordsAPI(t *testing.T) {
|
|||
return testTime.Add(time.Second)
|
||||
}
|
||||
out4, checkSleepVal2, err4 := psc4.callGetRecordsAPI(&gri)
|
||||
assert.Nil(t, out4)
|
||||
assert.Equal(t, maxBytesExceededError, err4)
|
||||
assert.Nil(t, err4)
|
||||
assert.Equal(t, &ret4, out4)
|
||||
m4.AssertExpectations(t)
|
||||
if checkSleepVal2 != 5 {
|
||||
if checkSleepVal2 != 0 {
|
||||
t.Errorf("Incorrect Cool Off Period: %v", checkSleepVal2)
|
||||
}
|
||||
|
||||
|
|
@ -134,6 +136,8 @@ func TestCallGetRecordsAPI(t *testing.T) {
|
|||
// check for correct cool off period with 8mb in .2 seconds with 6mb remaining
|
||||
testTime3 := time.Now()
|
||||
m6 := MockKinesisSubscriberGetter{}
|
||||
ret6 := kinesis.GetRecordsOutput{Records: nil}
|
||||
m6.On("GetRecords", mock.Anything, mock.Anything, mock.Anything).Return(&ret6, nil)
|
||||
psc6 := PollingShardConsumer{
|
||||
commonShardConsumer: commonShardConsumer{kc: &m6},
|
||||
callsLeft: 2,
|
||||
|
|
@ -148,10 +152,10 @@ func TestCallGetRecordsAPI(t *testing.T) {
|
|||
return testTime3.Add(time.Second / 5)
|
||||
}
|
||||
out6, checkSleepVal4, err6 := psc6.callGetRecordsAPI(&gri)
|
||||
assert.Nil(t, out6)
|
||||
assert.Equal(t, err6, maxBytesExceededError)
|
||||
assert.Nil(t, err6)
|
||||
assert.Equal(t, &ret6, out6)
|
||||
m5.AssertExpectations(t)
|
||||
if checkSleepVal4 != 4 {
|
||||
if checkSleepVal4 != 0 {
|
||||
t.Errorf("Incorrect Cool Off Period: %v", checkSleepVal4)
|
||||
}
|
||||
|
||||
|
|
@ -196,3 +200,156 @@ func (m *MockKinesisSubscriberGetter) GetShardIterator(ctx context.Context, para
|
|||
func (m *MockKinesisSubscriberGetter) SubscribeToShard(ctx context.Context, params *kinesis.SubscribeToShardInput, optFns ...func(*kinesis.Options)) (*kinesis.SubscribeToShardOutput, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func TestPollingShardConsumer_checkCoolOffPeriod(t *testing.T) {
|
||||
refTime := time.Now()
|
||||
type fields struct {
|
||||
lastCheckTime time.Time
|
||||
remBytes int
|
||||
bytesRead int
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
timeNow time.Time
|
||||
want int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
"zero time max bytes to spend",
|
||||
fields{
|
||||
time.Time{},
|
||||
0,
|
||||
0,
|
||||
},
|
||||
refTime,
|
||||
0,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"same second, bytes still left to spend",
|
||||
fields{
|
||||
refTime,
|
||||
MaxBytesPerSecond,
|
||||
MaxBytesPerSecond - 1,
|
||||
},
|
||||
refTime,
|
||||
0,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"same second, not many but some bytes still left to spend",
|
||||
fields{
|
||||
refTime,
|
||||
8,
|
||||
MaxBytesPerSecond,
|
||||
},
|
||||
refTime,
|
||||
0,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"same second, 1 byte still left to spend",
|
||||
fields{
|
||||
refTime,
|
||||
1,
|
||||
MaxBytesPerSecond,
|
||||
},
|
||||
refTime,
|
||||
0,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"next second, bytes still left to spend",
|
||||
fields{
|
||||
refTime,
|
||||
42,
|
||||
1024,
|
||||
},
|
||||
refTime.Add(1 * time.Second),
|
||||
0,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"same second, max bytes per second already spent",
|
||||
fields{
|
||||
refTime,
|
||||
0,
|
||||
MaxBytesPerSecond,
|
||||
},
|
||||
refTime,
|
||||
1,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"same second, more than max bytes per second already spent",
|
||||
fields{
|
||||
refTime,
|
||||
0,
|
||||
MaxBytesPerSecond + 1,
|
||||
},
|
||||
refTime,
|
||||
2,
|
||||
true,
|
||||
},
|
||||
|
||||
// Kinesis prevents reading more than 10 MiB at once
|
||||
{
|
||||
"same second, 10 MiB read all at once",
|
||||
fields{
|
||||
refTime,
|
||||
0,
|
||||
10 * 1024 * 1024,
|
||||
},
|
||||
refTime,
|
||||
6,
|
||||
true,
|
||||
},
|
||||
|
||||
{
|
||||
"same second, 10 MB read all at once",
|
||||
fields{
|
||||
refTime,
|
||||
0,
|
||||
10 * 1000 * 1000,
|
||||
},
|
||||
refTime,
|
||||
5,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"5 seconds ago, 10 MB read all at once",
|
||||
fields{
|
||||
refTime,
|
||||
0,
|
||||
10 * 1000 * 1000,
|
||||
},
|
||||
refTime.Add(5 * time.Second),
|
||||
0,
|
||||
false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
sc := &PollingShardConsumer{
|
||||
lastCheckTime: tt.fields.lastCheckTime,
|
||||
remBytes: tt.fields.remBytes,
|
||||
bytesRead: tt.fields.bytesRead,
|
||||
}
|
||||
rateLimitTimeNow = func() time.Time {
|
||||
return tt.timeNow
|
||||
}
|
||||
got, err := sc.checkCoolOffPeriod()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("PollingShardConsumer.checkCoolOffPeriod() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("PollingShardConsumer.checkCoolOffPeriod() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// restore original time.Now
|
||||
rateLimitTimeNow = time.Now
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue