fix: add getRecords TPS rate limiting

Signed-off-by: Shiva Pentakota <spentakota@vmware.com>

parent 981dc2df11
commit 66006caf89

2 changed files with 58 additions and 1 deletion
@@ -44,6 +44,16 @@ import (
 	"github.com/vmware/vmware-go-kcl-v2/clientlibrary/metrics"
 )
 
+const (
+	kinesisReadTPSLimit = 5
+)
+
+var (
+	rateLimitTimeNow      = time.Now
+	rateLimitTimeSince    = time.Since
+	localTPSExceededError = errors.New("Error GetRecords TPS Exceeded")
+)
+
 // PollingShardConsumer is responsible for polling data records from a (specified) shard.
 // Note: PollingShardConsumer only deal with one shard.
 type PollingShardConsumer struct {
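The cap of 5 matches the documented Kinesis limit of five GetRecords transactions per second per shard. time.Now and time.Since are rebound through package-level variables so the unit test below can substitute a fake clock, and localTPSExceededError is the sentinel returned when the local call budget for the current second is spent.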
@@ -52,6 +62,8 @@ type PollingShardConsumer struct {
 	stop       *chan struct{}
 	consumerID string
 	mService   metrics.MonitoringService
+	currTime   time.Time
+	callsLeft  int
 }
 
 func (sc *PollingShardConsumer) getShardIterator() (*string, error) {
@@ -108,6 +120,10 @@ func (sc *PollingShardConsumer) getRecords() error {
 	recordCheckpointer := NewRecordProcessorCheckpoint(sc.shard, sc.checkpointer)
 	retriedErrors := 0
 
+	// define API call rate limit starting window
+	sc.currTime = rateLimitTimeNow()
+	sc.callsLeft = kinesisReadTPSLimit
+
 	for {
 		if time.Now().UTC().After(sc.shard.GetLeaseTimeout().Add(-time.Duration(sc.kclConfig.LeaseRefreshPeriodMillis) * time.Millisecond)) {
 			log.Debugf("Refreshing lease on shard: %s for worker: %s", sc.shard.ID, sc.consumerID)
@@ -140,7 +156,14 @@ func (sc *PollingShardConsumer) getRecords() error {
 			//aws-sdk-go-v2 https://github.com/aws/aws-sdk-go-v2/blob/main/CHANGELOG.md#error-handling
 			var throughputExceededErr *types.ProvisionedThroughputExceededException
 			var kmsThrottlingErr *types.KMSThrottlingException
-			if errors.As(err, &throughputExceededErr) || errors.As(err, &kmsThrottlingErr) {
+			if errors.As(err, &throughputExceededErr) || err == localTPSExceededError {
+				// If there is insufficient provisioned throughput on the stream,
+				// subsequent calls made within the next 1 second throw ProvisionedThroughputExceededException.
+				// ref: https://docs.aws.amazon.com/streams/latest/dev/service-sizes-and-limits.html
+				sc.waitASecond(sc.currTime)
+				continue
+			}
+			if errors.As(err, &kmsThrottlingErr) {
 				log.Errorf("Error getting records from shard %v: %+v", sc.shard.ID, err)
 				retriedErrors++
 				// exponential backoff
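With this change, a throughput error from Kinesis and the local sentinel are handled the same way: waitASecond (added in the next hunk) sleeps for whatever remains of the current one-second window, so if 300 ms have elapsed the consumer pauses roughly 700 ms before retrying. KMSThrottlingException is split out and keeps the original exponential-backoff retry path.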
@@ -182,7 +205,26 @@ func (sc *PollingShardConsumer) getRecords() error {
 	}
 }
 
+func (sc *PollingShardConsumer) waitASecond(timePassed time.Time) {
+	waitTime := time.Since(timePassed)
+	if waitTime < time.Second {
+		time.Sleep(time.Second - waitTime)
+	}
+}
+
 func (sc *PollingShardConsumer) callGetRecordsAPI(gri *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
+	// every new second, we get a fresh set of calls
+	if rateLimitTimeSince(sc.currTime) > time.Second {
+		sc.callsLeft = kinesisReadTPSLimit
+		sc.currTime = rateLimitTimeNow()
+	}
+
+	if sc.callsLeft < 1 {
+		return nil, localTPSExceededError
+	}
+
 	getResp, err := sc.kc.GetRecords(context.TODO(), gri)
+	sc.callsLeft--
+
 	return getResp, err
 }
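Together these pieces form a fixed-window rate limiter: the first call after the window expires refreshes the budget, each call consumes one unit, and an exhausted budget returns the sentinel immediately so the caller can wait out the window. A minimal standalone sketch of the same technique, with hypothetical names (windowLimiter, allow) that are not part of this commit:

package main

import (
	"errors"
	"fmt"
	"time"
)

// Kinesis documents at most 5 GetRecords transactions per second per shard.
const readTPSLimit = 5

var errTPSExceeded = errors.New("GetRecords TPS exceeded")

// windowLimiter grants up to readTPSLimit calls per one-second window.
type windowLimiter struct {
	windowStart time.Time
	callsLeft   int
}

// allow consumes one call from the current window, opening a fresh window
// (and budget) once more than a second has passed since windowStart.
func (w *windowLimiter) allow() error {
	if time.Since(w.windowStart) > time.Second {
		w.windowStart = time.Now()
		w.callsLeft = readTPSLimit
	}
	if w.callsLeft < 1 {
		return errTPSExceeded
	}
	w.callsLeft--
	return nil
}

func main() {
	l := &windowLimiter{}
	for i := 1; i <= 7; i++ {
		if err := l.allow(); err != nil {
			fmt.Printf("call %d: %v\n", i, err) // calls 6 and 7 exceed the budget
			continue
		}
		fmt.Printf("call %d: ok\n", i)
	}
}

One known trade-off of fixed windows, visible in the diff as well: a burst straddling a window boundary can briefly land up to twice the limit. Combined with the retry-on-throttle handling above, that is usually acceptable for a 5 TPS soft limit.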
@@ -22,6 +22,7 @@ package worker
 import (
 	"context"
 	"testing"
+	"time"
 
 	"github.com/aws/aws-sdk-go-v2/aws"
 	"github.com/aws/aws-sdk-go-v2/service/kinesis"
@@ -44,6 +45,20 @@ func TestCallGetRecordsAPI(t *testing.T) {
 	assert.Nil(t, err)
 	assert.Equal(t, &ret, out)
 	m1.AssertExpectations(t)
+
+	// check that localTPSExceededError is thrown when trying more than 5 TPS
+	m2 := MockKinesisSubscriberGetter{}
+	psc2 := PollingShardConsumer{
+		commonShardConsumer: commonShardConsumer{kc: &m2},
+		callsLeft:           0,
+	}
+	rateLimitTimeSince = func(t time.Time) time.Duration {
+		return 500 * time.Millisecond
+	}
+	out2, err2 := psc2.callGetRecordsAPI(&gri)
+	assert.Nil(t, out2)
+	assert.ErrorIs(t, err2, localTPSExceededError)
+	m2.AssertExpectations(t)
 }
 
 type MockKinesisSubscriberGetter struct {
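The test stubs rateLimitTimeSince so the window never expires, then asserts that a consumer with callsLeft == 0 gets the sentinel back without the mock ever being called. One caveat: the stub is package-global, so a later test relying on real timing would inherit it. A sketch of the usual save-and-restore pattern (hypothetical test name; assumes the package-level variables from this commit):

func TestTPSLimitDeterministic(t *testing.T) {
	origSince := rateLimitTimeSince
	defer func() { rateLimitTimeSince = origSince }() // restore the real clock after the test

	// Freeze "elapsed time" inside the current window so callsLeft never refreshes.
	rateLimitTimeSince = func(time.Time) time.Duration { return 500 * time.Millisecond }

	// With a spent budget and a frozen window, the call fails fast
	// before the Kinesis client is ever touched.
	psc := PollingShardConsumer{callsLeft: 0}
	_, err := psc.callGetRecordsAPI(&kinesis.GetRecordsInput{})
	assert.ErrorIs(t, err, localTPSExceededError)
}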