fix: Check token bucket corner cases correctly.
Signed-off-by: John Calixto <jcalixto@vmware.com>
This commit is contained in:
parent
711b72932a
commit
987fada9d3
2 changed files with 167 additions and 8 deletions
|
|
@ -258,14 +258,16 @@ func (sc *PollingShardConsumer) checkCoolOffPeriod() (int, error) {
|
||||||
secondsPassed := currentTime.Sub(sc.lastCheckTime).Seconds()
|
secondsPassed := currentTime.Sub(sc.lastCheckTime).Seconds()
|
||||||
sc.lastCheckTime = currentTime
|
sc.lastCheckTime = currentTime
|
||||||
sc.remBytes += int(secondsPassed * MaxBytesPerSecond)
|
sc.remBytes += int(secondsPassed * MaxBytesPerSecond)
|
||||||
transactionReadRate := float64(sc.bytesRead) / (secondsPassed * BytesToMbConversion)
|
|
||||||
|
|
||||||
if sc.remBytes > MaxBytes {
|
if sc.remBytes > MaxBytes {
|
||||||
sc.remBytes = MaxBytes
|
sc.remBytes = MaxBytes
|
||||||
}
|
}
|
||||||
if sc.remBytes <= sc.bytesRead || transactionReadRate > 2 {
|
if sc.remBytes < 1 {
|
||||||
// Wait until cool down period has passed to prevent ProvisionedThroughputExceededException
|
// Wait until cool down period has passed to prevent ProvisionedThroughputExceededException
|
||||||
coolDown := sc.bytesRead / MaxBytesPerSecond
|
coolDown := sc.bytesRead / MaxBytesPerSecond
|
||||||
|
if sc.bytesRead%MaxBytesPerSecond > 0 {
|
||||||
|
coolDown++
|
||||||
|
}
|
||||||
return coolDown, maxBytesExceededError
|
return coolDown, maxBytesExceededError
|
||||||
} else {
|
} else {
|
||||||
sc.remBytes -= sc.bytesRead
|
sc.remBytes -= sc.bytesRead
|
||||||
|
|
|
||||||
|
|
@ -86,6 +86,8 @@ func TestCallGetRecordsAPI(t *testing.T) {
|
||||||
// check that correct cool off period is taken for 10mb in 1 second
|
// check that correct cool off period is taken for 10mb in 1 second
|
||||||
testTime := time.Now()
|
testTime := time.Now()
|
||||||
m4 := MockKinesisSubscriberGetter{}
|
m4 := MockKinesisSubscriberGetter{}
|
||||||
|
ret4 := kinesis.GetRecordsOutput{Records: nil}
|
||||||
|
m4.On("GetRecords", mock.Anything, mock.Anything, mock.Anything).Return(&ret4, nil)
|
||||||
psc4 := PollingShardConsumer{
|
psc4 := PollingShardConsumer{
|
||||||
commonShardConsumer: commonShardConsumer{kc: &m4},
|
commonShardConsumer: commonShardConsumer{kc: &m4},
|
||||||
callsLeft: 2,
|
callsLeft: 2,
|
||||||
|
|
@ -100,10 +102,10 @@ func TestCallGetRecordsAPI(t *testing.T) {
|
||||||
return testTime.Add(time.Second)
|
return testTime.Add(time.Second)
|
||||||
}
|
}
|
||||||
out4, checkSleepVal2, err4 := psc4.callGetRecordsAPI(&gri)
|
out4, checkSleepVal2, err4 := psc4.callGetRecordsAPI(&gri)
|
||||||
assert.Nil(t, out4)
|
assert.Nil(t, err4)
|
||||||
assert.Equal(t, maxBytesExceededError, err4)
|
assert.Equal(t, &ret4, out4)
|
||||||
m4.AssertExpectations(t)
|
m4.AssertExpectations(t)
|
||||||
if checkSleepVal2 != 5 {
|
if checkSleepVal2 != 0 {
|
||||||
t.Errorf("Incorrect Cool Off Period: %v", checkSleepVal2)
|
t.Errorf("Incorrect Cool Off Period: %v", checkSleepVal2)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -134,6 +136,8 @@ func TestCallGetRecordsAPI(t *testing.T) {
|
||||||
// check for correct cool off period with 8mb in .2 seconds with 6mb remaining
|
// check for correct cool off period with 8mb in .2 seconds with 6mb remaining
|
||||||
testTime3 := time.Now()
|
testTime3 := time.Now()
|
||||||
m6 := MockKinesisSubscriberGetter{}
|
m6 := MockKinesisSubscriberGetter{}
|
||||||
|
ret6 := kinesis.GetRecordsOutput{Records: nil}
|
||||||
|
m6.On("GetRecords", mock.Anything, mock.Anything, mock.Anything).Return(&ret6, nil)
|
||||||
psc6 := PollingShardConsumer{
|
psc6 := PollingShardConsumer{
|
||||||
commonShardConsumer: commonShardConsumer{kc: &m6},
|
commonShardConsumer: commonShardConsumer{kc: &m6},
|
||||||
callsLeft: 2,
|
callsLeft: 2,
|
||||||
|
|
@ -148,10 +152,10 @@ func TestCallGetRecordsAPI(t *testing.T) {
|
||||||
return testTime3.Add(time.Second / 5)
|
return testTime3.Add(time.Second / 5)
|
||||||
}
|
}
|
||||||
out6, checkSleepVal4, err6 := psc6.callGetRecordsAPI(&gri)
|
out6, checkSleepVal4, err6 := psc6.callGetRecordsAPI(&gri)
|
||||||
assert.Nil(t, out6)
|
assert.Nil(t, err6)
|
||||||
assert.Equal(t, err6, maxBytesExceededError)
|
assert.Equal(t, &ret6, out6)
|
||||||
m5.AssertExpectations(t)
|
m5.AssertExpectations(t)
|
||||||
if checkSleepVal4 != 4 {
|
if checkSleepVal4 != 0 {
|
||||||
t.Errorf("Incorrect Cool Off Period: %v", checkSleepVal4)
|
t.Errorf("Incorrect Cool Off Period: %v", checkSleepVal4)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -196,3 +200,156 @@ func (m *MockKinesisSubscriberGetter) GetShardIterator(ctx context.Context, para
|
||||||
func (m *MockKinesisSubscriberGetter) SubscribeToShard(ctx context.Context, params *kinesis.SubscribeToShardInput, optFns ...func(*kinesis.Options)) (*kinesis.SubscribeToShardOutput, error) {
|
func (m *MockKinesisSubscriberGetter) SubscribeToShard(ctx context.Context, params *kinesis.SubscribeToShardInput, optFns ...func(*kinesis.Options)) (*kinesis.SubscribeToShardOutput, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPollingShardConsumer_checkCoolOffPeriod(t *testing.T) {
|
||||||
|
refTime := time.Now()
|
||||||
|
type fields struct {
|
||||||
|
lastCheckTime time.Time
|
||||||
|
remBytes int
|
||||||
|
bytesRead int
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
timeNow time.Time
|
||||||
|
want int
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"zero time max bytes to spend",
|
||||||
|
fields{
|
||||||
|
time.Time{},
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
},
|
||||||
|
refTime,
|
||||||
|
0,
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"same second, bytes still left to spend",
|
||||||
|
fields{
|
||||||
|
refTime,
|
||||||
|
MaxBytesPerSecond,
|
||||||
|
MaxBytesPerSecond - 1,
|
||||||
|
},
|
||||||
|
refTime,
|
||||||
|
0,
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"same second, not many but some bytes still left to spend",
|
||||||
|
fields{
|
||||||
|
refTime,
|
||||||
|
8,
|
||||||
|
MaxBytesPerSecond,
|
||||||
|
},
|
||||||
|
refTime,
|
||||||
|
0,
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"same second, 1 byte still left to spend",
|
||||||
|
fields{
|
||||||
|
refTime,
|
||||||
|
1,
|
||||||
|
MaxBytesPerSecond,
|
||||||
|
},
|
||||||
|
refTime,
|
||||||
|
0,
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"next second, bytes still left to spend",
|
||||||
|
fields{
|
||||||
|
refTime,
|
||||||
|
42,
|
||||||
|
1024,
|
||||||
|
},
|
||||||
|
refTime.Add(1 * time.Second),
|
||||||
|
0,
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"same second, max bytes per second already spent",
|
||||||
|
fields{
|
||||||
|
refTime,
|
||||||
|
0,
|
||||||
|
MaxBytesPerSecond,
|
||||||
|
},
|
||||||
|
refTime,
|
||||||
|
1,
|
||||||
|
true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"same second, more than max bytes per second already spent",
|
||||||
|
fields{
|
||||||
|
refTime,
|
||||||
|
0,
|
||||||
|
MaxBytesPerSecond + 1,
|
||||||
|
},
|
||||||
|
refTime,
|
||||||
|
2,
|
||||||
|
true,
|
||||||
|
},
|
||||||
|
|
||||||
|
// Kinesis prevents reading more than 10 MiB at once
|
||||||
|
{
|
||||||
|
"same second, 10 MiB read all at once",
|
||||||
|
fields{
|
||||||
|
refTime,
|
||||||
|
0,
|
||||||
|
10 * 1024 * 1024,
|
||||||
|
},
|
||||||
|
refTime,
|
||||||
|
6,
|
||||||
|
true,
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
"same second, 10 MB read all at once",
|
||||||
|
fields{
|
||||||
|
refTime,
|
||||||
|
0,
|
||||||
|
10 * 1000 * 1000,
|
||||||
|
},
|
||||||
|
refTime,
|
||||||
|
5,
|
||||||
|
true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"5 seconds ago, 10 MB read all at once",
|
||||||
|
fields{
|
||||||
|
refTime,
|
||||||
|
0,
|
||||||
|
10 * 1000 * 1000,
|
||||||
|
},
|
||||||
|
refTime.Add(5 * time.Second),
|
||||||
|
0,
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
sc := &PollingShardConsumer{
|
||||||
|
lastCheckTime: tt.fields.lastCheckTime,
|
||||||
|
remBytes: tt.fields.remBytes,
|
||||||
|
bytesRead: tt.fields.bytesRead,
|
||||||
|
}
|
||||||
|
rateLimitTimeNow = func() time.Time {
|
||||||
|
return tt.timeNow
|
||||||
|
}
|
||||||
|
got, err := sc.checkCoolOffPeriod()
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("PollingShardConsumer.checkCoolOffPeriod() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("PollingShardConsumer.checkCoolOffPeriod() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// restore original time.Now
|
||||||
|
rateLimitTimeNow = time.Now
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue