Merge pull request #21 from vmware/spentakota_callGetRecordsAPI
fix: Handle ProvisionedThroughputExceededException throttling
This commit is contained in:
commit
42881449ce
4 changed files with 249 additions and 10 deletions
|
|
@ -136,6 +136,9 @@ const (
|
|||
|
||||
// DefaultLeaseSyncingIntervalMillis Number of milliseconds to wait before syncing with lease table (dynamodDB)
|
||||
DefaultLeaseSyncingIntervalMillis = 60000
|
||||
|
||||
// DefaultMaxRetryCount The default maximum number of retries in case of error
|
||||
DefaultMaxRetryCount = 5
|
||||
)
|
||||
|
||||
type (
|
||||
|
|
@ -283,6 +286,9 @@ type (
|
|||
|
||||
// LeaseSyncingTimeInterval The number of milliseconds to wait before syncing with lease table (dynamoDB)
|
||||
LeaseSyncingTimeIntervalMillis int
|
||||
|
||||
// MaxRetryCount The maximum number of retries in case of error
|
||||
MaxRetryCount int
|
||||
}
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -102,6 +102,7 @@ func NewKinesisClientLibConfigWithCredentials(applicationName, streamName, regio
|
|||
LeaseStealingIntervalMillis: DefaultLeaseStealingIntervalMillis,
|
||||
LeaseStealingClaimTimeoutMillis: DefaultLeaseStealingClaimTimeoutMillis,
|
||||
LeaseSyncingTimeIntervalMillis: DefaultLeaseSyncingIntervalMillis,
|
||||
MaxRetryCount: DefaultMaxRetryCount,
|
||||
Logger: logger.GetDefaultLogger(),
|
||||
}
|
||||
}
|
||||
|
|
@ -211,6 +212,13 @@ func (c *KinesisClientLibConfiguration) WithLogger(logger logger.Logger) *Kinesi
|
|||
return c
|
||||
}
|
||||
|
||||
// WithMaxRetryCount sets the max retry count in case of error.
|
||||
func (c *KinesisClientLibConfiguration) WithMaxRetryCount(maxRetryCount int) *KinesisClientLibConfiguration {
|
||||
checkIsValuePositive("maxRetryCount", maxRetryCount)
|
||||
c.MaxRetryCount = maxRetryCount
|
||||
return c
|
||||
}
|
||||
|
||||
// WithMonitoringService sets the monitoring service to use to publish metrics.
|
||||
func (c *KinesisClientLibConfiguration) WithMonitoringService(mService metrics.MonitoringService) *KinesisClientLibConfiguration {
|
||||
// Nil case is handled downward (at worker creation) so no need to do it here.
|
||||
|
|
|
|||
|
|
@ -44,14 +44,33 @@ import (
|
|||
"github.com/vmware/vmware-go-kcl-v2/clientlibrary/metrics"
|
||||
)
|
||||
|
||||
const (
|
||||
kinesisReadTPSLimit = 5
|
||||
MaxBytes = 10000000
|
||||
MaxBytesPerSecond = 2000000
|
||||
BytesToMbConversion = 1000000
|
||||
)
|
||||
|
||||
var (
|
||||
rateLimitTimeNow = time.Now
|
||||
rateLimitTimeSince = time.Since
|
||||
localTPSExceededError = errors.New("Error GetRecords TPS Exceeded")
|
||||
maxBytesExceededError = errors.New("Error GetRecords Max Bytes For Call Period Exceeded")
|
||||
)
|
||||
|
||||
// PollingShardConsumer is responsible for polling data records from a (specified) shard.
|
||||
// Note: PollingShardConsumer only deal with one shard.
|
||||
type PollingShardConsumer struct {
|
||||
commonShardConsumer
|
||||
streamName string
|
||||
stop *chan struct{}
|
||||
consumerID string
|
||||
mService metrics.MonitoringService
|
||||
streamName string
|
||||
stop *chan struct{}
|
||||
consumerID string
|
||||
mService metrics.MonitoringService
|
||||
currTime time.Time
|
||||
callsLeft int
|
||||
remBytes int
|
||||
lastCheckTime time.Time
|
||||
bytesRead int
|
||||
}
|
||||
|
||||
func (sc *PollingShardConsumer) getShardIterator() (*string, error) {
|
||||
|
|
@ -108,6 +127,12 @@ func (sc *PollingShardConsumer) getRecords() error {
|
|||
recordCheckpointer := NewRecordProcessorCheckpoint(sc.shard, sc.checkpointer)
|
||||
retriedErrors := 0
|
||||
|
||||
// define API call rate limit starting window
|
||||
sc.currTime = rateLimitTimeNow()
|
||||
sc.callsLeft = kinesisReadTPSLimit
|
||||
sc.bytesRead = 0
|
||||
sc.remBytes = MaxBytes
|
||||
|
||||
for {
|
||||
if time.Now().UTC().After(sc.shard.GetLeaseTimeout().Add(-time.Duration(sc.kclConfig.LeaseRefreshPeriodMillis) * time.Millisecond)) {
|
||||
log.Debugf("Refreshing lease on shard: %s for worker: %s", sc.shard.ID, sc.consumerID)
|
||||
|
|
@ -135,14 +160,47 @@ func (sc *PollingShardConsumer) getRecords() error {
|
|||
Limit: aws.Int32(int32(sc.kclConfig.MaxRecords)),
|
||||
ShardIterator: shardIterator,
|
||||
}
|
||||
getResp, err := sc.callGetRecordsAPI(getRecordsArgs)
|
||||
getResp, coolDownPeriod, err := sc.callGetRecordsAPI(getRecordsArgs)
|
||||
if err != nil {
|
||||
//aws-sdk-go-v2 https://github.com/aws/aws-sdk-go-v2/blob/main/CHANGELOG.md#error-handling
|
||||
var throughputExceededErr *types.ProvisionedThroughputExceededException
|
||||
var kmsThrottlingErr *types.KMSThrottlingException
|
||||
if errors.As(err, &throughputExceededErr) || errors.As(err, &kmsThrottlingErr) {
|
||||
if errors.As(err, &throughputExceededErr) {
|
||||
retriedErrors++
|
||||
if retriedErrors > sc.kclConfig.MaxRetryCount {
|
||||
log.Errorf("message", "Throughput Exceeded Error: "+
|
||||
"reached max retry count getting records from shard",
|
||||
"shardId", sc.shard.ID,
|
||||
"retryCount", retriedErrors,
|
||||
"error", err)
|
||||
return err
|
||||
}
|
||||
// If there is insufficient provisioned throughput on the stream,
|
||||
// subsequent calls made within the next 1 second throw ProvisionedThroughputExceededException.
|
||||
// ref: https://docs.aws.amazon.com/streams/latest/dev/service-sizes-and-limits.html
|
||||
sc.waitASecond(sc.currTime)
|
||||
continue
|
||||
}
|
||||
if err == localTPSExceededError {
|
||||
sc.waitASecond(sc.currTime)
|
||||
continue
|
||||
}
|
||||
if err == maxBytesExceededError {
|
||||
time.Sleep(time.Duration(coolDownPeriod) * time.Second)
|
||||
continue
|
||||
}
|
||||
if errors.As(err, &kmsThrottlingErr) {
|
||||
log.Errorf("Error getting records from shard %v: %+v", sc.shard.ID, err)
|
||||
retriedErrors++
|
||||
// Greater than MaxRetryCount so we get the last retry
|
||||
if retriedErrors > sc.kclConfig.MaxRetryCount {
|
||||
log.Errorf("message", "KMS Throttling Error: "+
|
||||
"reached max retry count getting records from shard",
|
||||
"shardId", sc.shard.ID,
|
||||
"retryCount", retriedErrors,
|
||||
"error", err)
|
||||
return err
|
||||
}
|
||||
// exponential backoff
|
||||
// https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Programming.Errors.html#Programming.Errors.RetryAndBackoff
|
||||
time.Sleep(time.Duration(math.Exp2(float64(retriedErrors))*100) * time.Millisecond)
|
||||
|
|
@ -182,7 +240,64 @@ func (sc *PollingShardConsumer) getRecords() error {
|
|||
}
|
||||
}
|
||||
|
||||
func (sc *PollingShardConsumer) callGetRecordsAPI(gri *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
|
||||
getResp, err := sc.kc.GetRecords(context.TODO(), gri)
|
||||
return getResp, err
|
||||
func (sc *PollingShardConsumer) waitASecond(timePassed time.Time) {
|
||||
waitTime := time.Since(timePassed)
|
||||
if waitTime < time.Second {
|
||||
time.Sleep(time.Second - waitTime)
|
||||
}
|
||||
}
|
||||
|
||||
func (sc *PollingShardConsumer) checkCoolOffPeriod() (int, error) {
|
||||
// Each shard can support up to a maximum total data read rate of 2 MB per second via GetRecords.
|
||||
// If a call to GetRecords returns 10 MB, subsequent calls made within the next 5 seconds throw an exception.
|
||||
// ref: https://docs.aws.amazon.com/streams/latest/dev/service-sizes-and-limits.html
|
||||
// check for overspending of byte budget from getRecords call
|
||||
currentTime := rateLimitTimeNow()
|
||||
secondsPassed := currentTime.Sub(sc.lastCheckTime).Seconds()
|
||||
sc.lastCheckTime = currentTime
|
||||
sc.remBytes += int(secondsPassed * MaxBytesPerSecond)
|
||||
transactionReadRate := float64(sc.bytesRead) / (secondsPassed * BytesToMbConversion)
|
||||
|
||||
if sc.remBytes > MaxBytes {
|
||||
sc.remBytes = MaxBytes
|
||||
}
|
||||
if sc.remBytes <= sc.bytesRead || transactionReadRate > 2 {
|
||||
// Wait until cool down period has passed to prevent ProvisionedThroughputExceededException
|
||||
coolDown := sc.bytesRead / MaxBytesPerSecond
|
||||
return coolDown, maxBytesExceededError
|
||||
} else {
|
||||
sc.remBytes -= sc.bytesRead
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (sc *PollingShardConsumer) callGetRecordsAPI(gri *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, int, error) {
|
||||
if sc.bytesRead != 0 {
|
||||
coolDownPeriod, err := sc.checkCoolOffPeriod()
|
||||
if err != nil {
|
||||
return nil, coolDownPeriod, err
|
||||
}
|
||||
}
|
||||
// every new second, we get a fresh set of calls
|
||||
if rateLimitTimeSince(sc.currTime) > time.Second {
|
||||
sc.callsLeft = kinesisReadTPSLimit
|
||||
sc.currTime = rateLimitTimeNow()
|
||||
}
|
||||
|
||||
if sc.callsLeft < 1 {
|
||||
return nil, 0, localTPSExceededError
|
||||
}
|
||||
|
||||
getResp, err := sc.kc.GetRecords(context.TODO(), gri)
|
||||
sc.callsLeft--
|
||||
// Calculate size of records from read transaction
|
||||
sc.bytesRead = 0
|
||||
for _, record := range getResp.Records {
|
||||
sc.bytesRead += len(record.Data)
|
||||
}
|
||||
if sc.lastCheckTime.IsZero() {
|
||||
sc.lastCheckTime = rateLimitTimeNow()
|
||||
}
|
||||
|
||||
return getResp, 0, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ package worker
|
|||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/service/kinesis"
|
||||
|
|
@ -40,10 +41,119 @@ func TestCallGetRecordsAPI(t *testing.T) {
|
|||
gri := kinesis.GetRecordsInput{
|
||||
ShardIterator: aws.String("shard-iterator-01"),
|
||||
}
|
||||
out, err := psc.callGetRecordsAPI(&gri)
|
||||
out, _, err := psc.callGetRecordsAPI(&gri)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, &ret, out)
|
||||
m1.AssertExpectations(t)
|
||||
|
||||
// check that localTPSExceededError is thrown when trying more than 5 TPS
|
||||
m2 := MockKinesisSubscriberGetter{}
|
||||
psc2 := PollingShardConsumer{
|
||||
commonShardConsumer: commonShardConsumer{kc: &m2},
|
||||
callsLeft: 0,
|
||||
}
|
||||
rateLimitTimeSince = func(t time.Time) time.Duration {
|
||||
return 500 * time.Millisecond
|
||||
}
|
||||
out2, _, err2 := psc2.callGetRecordsAPI(&gri)
|
||||
assert.Nil(t, out2)
|
||||
assert.ErrorIs(t, err2, localTPSExceededError)
|
||||
m2.AssertExpectations(t)
|
||||
|
||||
// check that getRecords is called normally in bytesRead = 0 case
|
||||
m3 := MockKinesisSubscriberGetter{}
|
||||
ret3 := kinesis.GetRecordsOutput{}
|
||||
m3.On("GetRecords", mock.Anything, mock.Anything, mock.Anything).Return(&ret3, nil)
|
||||
psc3 := PollingShardConsumer{
|
||||
commonShardConsumer: commonShardConsumer{kc: &m3},
|
||||
callsLeft: 2,
|
||||
bytesRead: 0,
|
||||
}
|
||||
rateLimitTimeSince = func(t time.Time) time.Duration {
|
||||
return 2 * time.Second
|
||||
}
|
||||
out3, checkSleepVal, err3 := psc3.callGetRecordsAPI(&gri)
|
||||
assert.Nil(t, err3)
|
||||
assert.Equal(t, checkSleepVal, 0)
|
||||
assert.Equal(t, &ret3, out3)
|
||||
m3.AssertExpectations(t)
|
||||
|
||||
// check that correct cool off period is taken for 10mb in 1 second
|
||||
testTime := time.Now()
|
||||
m4 := MockKinesisSubscriberGetter{}
|
||||
psc4 := PollingShardConsumer{
|
||||
commonShardConsumer: commonShardConsumer{kc: &m4},
|
||||
callsLeft: 2,
|
||||
bytesRead: MaxBytes,
|
||||
lastCheckTime: testTime,
|
||||
remBytes: MaxBytes,
|
||||
}
|
||||
rateLimitTimeSince = func(t time.Time) time.Duration {
|
||||
return 2 * time.Second
|
||||
}
|
||||
rateLimitTimeNow = func() time.Time {
|
||||
return testTime.Add(time.Second)
|
||||
}
|
||||
out4, checkSleepVal2, err4 := psc4.callGetRecordsAPI(&gri)
|
||||
assert.Nil(t, out4)
|
||||
assert.Equal(t, maxBytesExceededError, err4)
|
||||
m4.AssertExpectations(t)
|
||||
if checkSleepVal2 != 5 {
|
||||
t.Errorf("Incorrect Cool Off Period: %v", checkSleepVal2)
|
||||
}
|
||||
|
||||
// check that no cool off period is taken for 6mb in 3 seconds
|
||||
testTime2 := time.Now()
|
||||
m5 := MockKinesisSubscriberGetter{}
|
||||
ret5 := kinesis.GetRecordsOutput{}
|
||||
m5.On("GetRecords", mock.Anything, mock.Anything, mock.Anything).Return(&ret5, nil)
|
||||
psc5 := PollingShardConsumer{
|
||||
commonShardConsumer: commonShardConsumer{kc: &m5},
|
||||
callsLeft: 2,
|
||||
bytesRead: MaxBytesPerSecond * 3,
|
||||
lastCheckTime: testTime2,
|
||||
remBytes: MaxBytes,
|
||||
}
|
||||
rateLimitTimeSince = func(t time.Time) time.Duration {
|
||||
return 3 * time.Second
|
||||
}
|
||||
rateLimitTimeNow = func() time.Time {
|
||||
return testTime2.Add(time.Second * 3)
|
||||
}
|
||||
out5, checkSleepVal3, err5 := psc5.callGetRecordsAPI(&gri)
|
||||
assert.Nil(t, err5)
|
||||
assert.Equal(t, checkSleepVal3, 0)
|
||||
assert.Equal(t, &ret5, out5)
|
||||
m5.AssertExpectations(t)
|
||||
|
||||
// check for correct cool off period with 8mb in .2 seconds with 6mb remaining
|
||||
testTime3 := time.Now()
|
||||
m6 := MockKinesisSubscriberGetter{}
|
||||
psc6 := PollingShardConsumer{
|
||||
commonShardConsumer: commonShardConsumer{kc: &m6},
|
||||
callsLeft: 2,
|
||||
bytesRead: MaxBytesPerSecond * 4,
|
||||
lastCheckTime: testTime3,
|
||||
remBytes: MaxBytesPerSecond * 3,
|
||||
}
|
||||
rateLimitTimeSince = func(t time.Time) time.Duration {
|
||||
return 3 * time.Second
|
||||
}
|
||||
rateLimitTimeNow = func() time.Time {
|
||||
return testTime3.Add(time.Second / 5)
|
||||
}
|
||||
out6, checkSleepVal4, err6 := psc6.callGetRecordsAPI(&gri)
|
||||
assert.Nil(t, out6)
|
||||
assert.Equal(t, err6, maxBytesExceededError)
|
||||
m5.AssertExpectations(t)
|
||||
if checkSleepVal4 != 4 {
|
||||
t.Errorf("Incorrect Cool Off Period: %v", checkSleepVal4)
|
||||
}
|
||||
|
||||
// restore original func
|
||||
rateLimitTimeNow = time.Now
|
||||
rateLimitTimeSince = time.Since
|
||||
|
||||
}
|
||||
|
||||
type MockKinesisSubscriberGetter struct {
|
||||
|
|
|
|||
Loading…
Reference in a new issue