fix: add maxBytes per second getRecord check

Signed-off-by: Shiva Pentakota <spentakota@vmware.com>
Shiva Pentakota 2023-01-24 16:28:22 -08:00
parent b5515931d1
commit 7d6b1c33d0
2 changed files with 167 additions and 15 deletions
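In short, this change gives each polling shard consumer a rolling read budget of MaxBytes (10 MB), refilled at MaxBytesPerSecond (2 MB/s), and has callGetRecordsAPI return a cool-off period whenever a GetRecords response overspends that budget; getRecords then sleeps for the cool-off before retrying. The standalone sketch below mirrors the names from the diff but is an illustration only, not the committed code, and it omits the per-call read-rate check and the 5-calls-per-second TPS limit:

package main

import (
	"fmt"
	"time"
)

const (
	MaxBytes          = 10000000 // 10 MB burst of reads a shard tolerates
	MaxBytesPerSecond = 2000000  // 2 MB/s sustained read rate per shard
)

// byteBudget is a simplified stand-in for the remBytes / lastCheckTime
// bookkeeping the commit adds to PollingShardConsumer.
type byteBudget struct {
	remBytes  int
	lastCheck time.Time
}

// charge refills the budget for the elapsed time, caps it at MaxBytes, and
// either deducts bytesRead or reports how many seconds to cool off.
func (b *byteBudget) charge(bytesRead int, now time.Time) (coolOffSeconds int, ok bool) {
	elapsed := now.Sub(b.lastCheck).Seconds()
	b.lastCheck = now
	b.remBytes += int(elapsed * MaxBytesPerSecond)
	if b.remBytes > MaxBytes {
		b.remBytes = MaxBytes
	}
	if b.remBytes <= bytesRead {
		return bytesRead / MaxBytesPerSecond, false
	}
	b.remBytes -= bytesRead
	return 0, true
}

func main() {
	start := time.Now()
	b := &byteBudget{remBytes: MaxBytes, lastCheck: start}
	// A 10 MB response one second later exhausts the budget:
	// cool off for 10 MB / 2 MB/s = 5 seconds before calling GetRecords again.
	fmt.Println(b.charge(MaxBytes, start.Add(time.Second))) // 5 false
}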


@@ -46,24 +46,31 @@ import (
 const (
 	kinesisReadTPSLimit = 5
+	MaxBytes = 10000000
+	MaxBytesPerSecond = 2000000
+	BytesToMbConversion = 1000000
 )
 
 var (
 	rateLimitTimeNow = time.Now
 	rateLimitTimeSince = time.Since
 	localTPSExceededError = errors.New("Error GetRecords TPS Exceeded")
+	maxBytesExceededError = errors.New("Error GetRecords Max Bytes For Call Period Exceeded")
 )
 
 // PollingShardConsumer is responsible for polling data records from a (specified) shard.
 // Note: PollingShardConsumer only deal with one shard.
 type PollingShardConsumer struct {
 	commonShardConsumer
 	streamName string
 	stop *chan struct{}
 	consumerID string
 	mService metrics.MonitoringService
 	currTime time.Time
 	callsLeft int
+	remBytes int
+	lastCheckTime time.Time
+	bytesRead int
 }
 
 func (sc *PollingShardConsumer) getShardIterator() (*string, error) {
@@ -123,6 +130,8 @@ func (sc *PollingShardConsumer) getRecords() error {
 	// define API call rate limit starting window
 	sc.currTime = rateLimitTimeNow()
 	sc.callsLeft = kinesisReadTPSLimit
+	sc.bytesRead = 0
+	sc.remBytes = MaxBytes
 
 	for {
 		if time.Now().UTC().After(sc.shard.GetLeaseTimeout().Add(-time.Duration(sc.kclConfig.LeaseRefreshPeriodMillis) * time.Millisecond)) {
@@ -151,15 +160,16 @@
 			Limit: aws.Int32(int32(sc.kclConfig.MaxRecords)),
 			ShardIterator: shardIterator,
 		}
-		getResp, err := sc.callGetRecordsAPI(getRecordsArgs)
+		getResp, coolDownPeriod, err := sc.callGetRecordsAPI(getRecordsArgs)
 		if err != nil {
 			//aws-sdk-go-v2 https://github.com/aws/aws-sdk-go-v2/blob/main/CHANGELOG.md#error-handling
 			var throughputExceededErr *types.ProvisionedThroughputExceededException
 			var kmsThrottlingErr *types.KMSThrottlingException
-			if errors.As(err, &throughputExceededErr) || err == localTPSExceededError {
+			if errors.As(err, &throughputExceededErr) {
 				retriedErrors++
 				if retriedErrors > sc.kclConfig.MaxRetryCount {
-					log.Errorf("message", "reached max retry count getting records from shard",
+					log.Errorf("message", "Throughput Exceeded Error: "+
+						"reached max retry count getting records from shard",
						"shardId", sc.shard.ID,
						"retryCount", retriedErrors,
						"error", err)
@@ -171,12 +181,21 @@
 				sc.waitASecond(sc.currTime)
 				continue
 			}
+			if err == localTPSExceededError {
+				sc.waitASecond(sc.currTime)
+				continue
+			}
+			if err == maxBytesExceededError {
+				time.Sleep(time.Duration(coolDownPeriod) * time.Second)
+				continue
+			}
 			if errors.As(err, &kmsThrottlingErr) {
 				log.Errorf("Error getting records from shard %v: %+v", sc.shard.ID, err)
 				retriedErrors++
 				// Greater than MaxRetryCount so we get the last retry
 				if retriedErrors > sc.kclConfig.MaxRetryCount {
-					log.Errorf("message", "reached max retry count getting records from shard",
+					log.Errorf("message", "KMS Throttling Error: "+
+						"reached max retry count getting records from shard",
						"shardId", sc.shard.ID,
						"retryCount", retriedErrors,
						"error", err)
@@ -228,7 +247,37 @@ func (sc *PollingShardConsumer) waitASecond(timePassed time.Time) {
 	}
 }
 
-func (sc *PollingShardConsumer) callGetRecordsAPI(gri *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
+func (sc *PollingShardConsumer) checkCoolOffPeriod() (int, error) {
+	// Each shard can support up to a maximum total data read rate of 2 MB per second via GetRecords.
+	// If a call to GetRecords returns 10 MB, subsequent calls made within the next 5 seconds throw an exception.
+	// ref: https://docs.aws.amazon.com/streams/latest/dev/service-sizes-and-limits.html
+	// check for overspending of byte budget from getRecords call
+	currentTime := rateLimitTimeNow()
+	secondsPassed := currentTime.Sub(sc.lastCheckTime).Seconds()
+	sc.lastCheckTime = currentTime
+	sc.remBytes += int(secondsPassed * MaxBytesPerSecond)
+	transactionReadRate := float64(sc.bytesRead) / (secondsPassed * BytesToMbConversion)
+	if sc.remBytes > MaxBytes {
+		sc.remBytes = MaxBytes
+	}
+	if sc.remBytes <= sc.bytesRead || transactionReadRate > 2 {
+		// Wait until cool down period has passed to prevent ProvisionedThroughputExceededException
+		coolDown := sc.bytesRead / MaxBytesPerSecond
+		return coolDown, maxBytesExceededError
+	} else {
+		sc.remBytes -= sc.bytesRead
+	}
+	return 0, nil
+}
+
+func (sc *PollingShardConsumer) callGetRecordsAPI(gri *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, int, error) {
+	if sc.bytesRead != 0 {
+		coolDownPeriod, err := sc.checkCoolOffPeriod()
+		if err != nil {
+			return nil, coolDownPeriod, err
+		}
+	}
 	// every new second, we get a fresh set of calls
 	if rateLimitTimeSince(sc.currTime) > time.Second {
 		sc.callsLeft = kinesisReadTPSLimit
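As a concrete check of the two overspend conditions in checkCoolOffPeriod above, the "8 MB read 0.2 seconds after the last check, with 6 MB of budget remaining" case exercised by the test in the second changed file works out as follows (an illustration only, not part of the commit):

package main

import "fmt"

const (
	MaxBytesPerSecond   = 2000000
	BytesToMbConversion = 1000000
)

func main() {
	bytesRead := 8000000         // bytes returned by the last GetRecords call
	secondsPassed := 0.2         // time since the previous budget check
	remBytes := 6000000 + 400000 // 6 MB left, plus 0.2 s * 2 MB/s of refill

	rate := float64(bytesRead) / (secondsPassed * BytesToMbConversion)
	fmt.Println(remBytes <= bytesRead || rate > 2) // true: budget overspent, and 40 MB/s > 2 MB/s
	fmt.Println(bytesRead / MaxBytesPerSecond)     // 4: seconds of cool-off, as the test expects
}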
@@ -236,11 +285,19 @@ func (sc *PollingShardConsumer) callGetRecordsAPI(gri *kinesis.GetRecordsInput)
 	}
 
 	if sc.callsLeft < 1 {
-		return nil, localTPSExceededError
+		return nil, 0, localTPSExceededError
 	}
 	getResp, err := sc.kc.GetRecords(context.TODO(), gri)
 	sc.callsLeft--
-	return getResp, err
+	// Calculate size of records from read transaction
+	sc.bytesRead = 0
+	for _, record := range getResp.Records {
+		sc.bytesRead += len(record.Data)
+	}
+	if sc.lastCheckTime.IsZero() {
+		sc.lastCheckTime = rateLimitTimeNow()
+	}
+	return getResp, 0, err
 }


@@ -41,7 +41,7 @@ func TestCallGetRecordsAPI(t *testing.T) {
 	gri := kinesis.GetRecordsInput{
 		ShardIterator: aws.String("shard-iterator-01"),
 	}
-	out, err := psc.callGetRecordsAPI(&gri)
+	out, _, err := psc.callGetRecordsAPI(&gri)
 	assert.Nil(t, err)
 	assert.Equal(t, &ret, out)
 	m1.AssertExpectations(t)
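The remaining test cases below drive the limiter with a stubbed clock: the production code reads time only through the package-level rateLimitTimeNow and rateLimitTimeSince variables, so each case substitutes fixed values and the real functions are restored at the end. A minimal, self-contained illustration of that pattern (withStubbedClock is a hypothetical helper, not part of this repository):

package main

import (
	"fmt"
	"time"
)

// Production code reads time only through these variables, which is what
// lets a test substitute deterministic values.
var (
	rateLimitTimeNow   = time.Now
	rateLimitTimeSince = time.Since
)

// withStubbedClock swaps the variables for stubs, runs body, and restores
// the real functions afterwards.
func withStubbedClock(now time.Time, elapsed time.Duration, body func()) {
	origNow, origSince := rateLimitTimeNow, rateLimitTimeSince
	defer func() { rateLimitTimeNow, rateLimitTimeSince = origNow, origSince }()
	rateLimitTimeNow = func() time.Time { return now }
	rateLimitTimeSince = func(time.Time) time.Duration { return elapsed }
	body()
}

func main() {
	withStubbedClock(time.Unix(0, 0), 2*time.Second, func() {
		fmt.Println(rateLimitTimeNow(), rateLimitTimeSince(time.Time{}))
	})
}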
@@ -55,10 +55,105 @@ func TestCallGetRecordsAPI(t *testing.T) {
 	rateLimitTimeSince = func(t time.Time) time.Duration {
 		return 500 * time.Millisecond
 	}
-	out2, err2 := psc2.callGetRecordsAPI(&gri)
+	out2, _, err2 := psc2.callGetRecordsAPI(&gri)
 	assert.Nil(t, out2)
 	assert.ErrorIs(t, err2, localTPSExceededError)
 	m2.AssertExpectations(t)
+
+	// check that getRecords is called normally in bytesRead = 0 case
+	m3 := MockKinesisSubscriberGetter{}
+	ret3 := kinesis.GetRecordsOutput{}
+	m3.On("GetRecords", mock.Anything, mock.Anything, mock.Anything).Return(&ret3, nil)
+	psc3 := PollingShardConsumer{
+		commonShardConsumer: commonShardConsumer{kc: &m3},
+		callsLeft: 2,
+		bytesRead: 0,
+	}
+	rateLimitTimeSince = func(t time.Time) time.Duration {
+		return 2 * time.Second
+	}
+	out3, checkSleepVal, err3 := psc3.callGetRecordsAPI(&gri)
+	assert.Nil(t, err3)
+	assert.Equal(t, checkSleepVal, 0)
+	assert.Equal(t, &ret3, out3)
+	m3.AssertExpectations(t)
+
+	// check that correct cool off period is taken for 10mb in 1 second
+	testTime := time.Now()
+	m4 := MockKinesisSubscriberGetter{}
+	psc4 := PollingShardConsumer{
+		commonShardConsumer: commonShardConsumer{kc: &m4},
+		callsLeft: 2,
+		bytesRead: MaxBytes,
+		lastCheckTime: testTime,
+		remBytes: MaxBytes,
+	}
+	rateLimitTimeSince = func(t time.Time) time.Duration {
+		return 2 * time.Second
+	}
+	rateLimitTimeNow = func() time.Time {
+		return testTime.Add(time.Second)
+	}
+	out4, checkSleepVal2, err4 := psc4.callGetRecordsAPI(&gri)
+	assert.Nil(t, out4)
+	assert.Equal(t, maxBytesExceededError, err4)
+	m4.AssertExpectations(t)
+	if checkSleepVal2 != 5 {
+		t.Errorf("Incorrect Cool Off Period: %v", checkSleepVal2)
+	}
+
+	// check that no cool off period is taken for 6mb in 3 seconds
+	testTime2 := time.Now()
+	m5 := MockKinesisSubscriberGetter{}
+	ret5 := kinesis.GetRecordsOutput{}
+	m5.On("GetRecords", mock.Anything, mock.Anything, mock.Anything).Return(&ret5, nil)
+	psc5 := PollingShardConsumer{
+		commonShardConsumer: commonShardConsumer{kc: &m5},
+		callsLeft: 2,
+		bytesRead: MaxBytesPerSecond * 3,
+		lastCheckTime: testTime2,
+		remBytes: MaxBytes,
+	}
+	rateLimitTimeSince = func(t time.Time) time.Duration {
+		return 3 * time.Second
+	}
+	rateLimitTimeNow = func() time.Time {
+		return testTime2.Add(time.Second * 3)
+	}
+	out5, checkSleepVal3, err5 := psc5.callGetRecordsAPI(&gri)
+	assert.Nil(t, err5)
+	assert.Equal(t, checkSleepVal3, 0)
+	assert.Equal(t, &ret5, out5)
+	m5.AssertExpectations(t)
+
+	// check for correct cool off period with 8mb in .2 seconds with 6mb remaining
+	testTime3 := time.Now()
+	m6 := MockKinesisSubscriberGetter{}
+	psc6 := PollingShardConsumer{
+		commonShardConsumer: commonShardConsumer{kc: &m6},
+		callsLeft: 2,
+		bytesRead: MaxBytesPerSecond * 4,
+		lastCheckTime: testTime3,
+		remBytes: MaxBytesPerSecond * 3,
+	}
+	rateLimitTimeSince = func(t time.Time) time.Duration {
+		return 3 * time.Second
+	}
+	rateLimitTimeNow = func() time.Time {
+		return testTime3.Add(time.Second / 5)
+	}
+	out6, checkSleepVal4, err6 := psc6.callGetRecordsAPI(&gri)
+	assert.Nil(t, out6)
+	assert.Equal(t, err6, maxBytesExceededError)
+	m5.AssertExpectations(t)
+	if checkSleepVal4 != 4 {
+		t.Errorf("Incorrect Cool Off Period: %v", checkSleepVal4)
+	}
+
+	// restore original func
+	rateLimitTimeNow = time.Now
+	rateLimitTimeSince = time.Since
 }
 
 type MockKinesisSubscriberGetter struct {