fix: add getRecords TPS rate limiting

Signed-off-by: Shiva Pentakota <spentakota@vmware.com>
Shiva Pentakota 2023-01-24 11:56:29 -08:00
parent 981dc2df11
commit 66006caf89
2 changed files with 58 additions and 1 deletion


@@ -44,6 +44,16 @@ import (
 	"github.com/vmware/vmware-go-kcl-v2/clientlibrary/metrics"
 )
 
+const (
+	kinesisReadTPSLimit = 5
+)
+
+var (
+	rateLimitTimeNow      = time.Now
+	rateLimitTimeSince    = time.Since
+	localTPSExceededError = errors.New("Error GetRecords TPS Exceeded")
+)
+
 // PollingShardConsumer is responsible for polling data records from a (specified) shard.
 // Note: PollingShardConsumer only deal with one shard.
 type PollingShardConsumer struct {
@@ -52,6 +62,8 @@ type PollingShardConsumer struct {
 	stop       *chan struct{}
 	consumerID string
 	mService   metrics.MonitoringService
+	currTime   time.Time
+	callsLeft  int
 }
 
 func (sc *PollingShardConsumer) getShardIterator() (*string, error) {
@@ -108,6 +120,10 @@ func (sc *PollingShardConsumer) getRecords() error {
 	recordCheckpointer := NewRecordProcessorCheckpoint(sc.shard, sc.checkpointer)
 	retriedErrors := 0
 
+	// define API call rate limit starting window
+	sc.currTime = rateLimitTimeNow()
+	sc.callsLeft = kinesisReadTPSLimit
+
 	for {
 		if time.Now().UTC().After(sc.shard.GetLeaseTimeout().Add(-time.Duration(sc.kclConfig.LeaseRefreshPeriodMillis) * time.Millisecond)) {
 			log.Debugf("Refreshing lease on shard: %s for worker: %s", sc.shard.ID, sc.consumerID)
@@ -140,7 +156,14 @@ func (sc *PollingShardConsumer) getRecords() error {
 			//aws-sdk-go-v2 https://github.com/aws/aws-sdk-go-v2/blob/main/CHANGELOG.md#error-handling
 			var throughputExceededErr *types.ProvisionedThroughputExceededException
 			var kmsThrottlingErr *types.KMSThrottlingException
-			if errors.As(err, &throughputExceededErr) || errors.As(err, &kmsThrottlingErr) {
+			if errors.As(err, &throughputExceededErr) || err == localTPSExceededError {
+				// If there is insufficient provisioned throughput on the stream,
+				// subsequent calls made within the next 1 second throw ProvisionedThroughputExceededException.
+				// ref: https://docs.aws.amazon.com/streams/latest/dev/service-sizes-and-limits.html
+				sc.waitASecond(sc.currTime)
+				continue
+			}
+			if errors.As(err, &kmsThrottlingErr) {
 				log.Errorf("Error getting records from shard %v: %+v", sc.shard.ID, err)
 				retriedErrors++
 				// exponential backoff
@@ -182,7 +205,26 @@ func (sc *PollingShardConsumer) getRecords() error {
 	}
 }
 
+func (sc *PollingShardConsumer) waitASecond(timePassed time.Time) {
+	waitTime := time.Since(timePassed)
+	if waitTime < time.Second {
+		time.Sleep(time.Second - waitTime)
+	}
+}
+
 func (sc *PollingShardConsumer) callGetRecordsAPI(gri *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
+	// every new second, we get a fresh set of calls
+	if rateLimitTimeSince(sc.currTime) > time.Second {
+		sc.callsLeft = kinesisReadTPSLimit
+		sc.currTime = rateLimitTimeNow()
+	}
+	if sc.callsLeft < 1 {
+		return nil, localTPSExceededError
+	}
 	getResp, err := sc.kc.GetRecords(context.TODO(), gri)
+	sc.callsLeft--
 	return getResp, err
 }
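
Taken together, the additions implement a fixed one-second window: each window grants kinesisReadTPSLimit GetRecords calls, callGetRecordsAPI returns localTPSExceededError once the budget is spent, and getRecords sleeps out the remainder of the window via waitASecond before retrying. Below is a minimal standalone sketch of the same pattern; the names limiter, allow, and readTPSLimit are illustrative only and not part of the library.

// Minimal sketch of a fixed one-second window with a call budget.
package main

import (
	"errors"
	"fmt"
	"time"
)

const readTPSLimit = 5

var errTPSExceeded = errors.New("TPS exceeded")

type limiter struct {
	windowStart time.Time
	callsLeft   int
}

// allow spends one call from the current window, refreshing the budget
// once more than a second has elapsed since windowStart.
func (l *limiter) allow() error {
	if time.Since(l.windowStart) > time.Second {
		l.callsLeft = readTPSLimit
		l.windowStart = time.Now()
	}
	if l.callsLeft < 1 {
		return errTPSExceeded
	}
	l.callsLeft--
	return nil
}

func main() {
	l := &limiter{windowStart: time.Now(), callsLeft: readTPSLimit}
	for i := 0; i < 7; i++ {
		if err := l.allow(); err != nil {
			// sleep out the rest of the window, as waitASecond does,
			// then retry the same call
			time.Sleep(time.Second - time.Since(l.windowStart))
			i--
			continue
		}
		fmt.Println("call", i+1, "made")
	}
}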


@@ -22,6 +22,7 @@ package worker
 import (
 	"context"
 	"testing"
+	"time"
 
 	"github.com/aws/aws-sdk-go-v2/aws"
 	"github.com/aws/aws-sdk-go-v2/service/kinesis"
@@ -44,6 +45,20 @@ func TestCallGetRecordsAPI(t *testing.T) {
 	assert.Nil(t, err)
 	assert.Equal(t, &ret, out)
 	m1.AssertExpectations(t)
+
+	// check that localTPSExceededError is thrown when trying more than 5 TPS
+	m2 := MockKinesisSubscriberGetter{}
+	psc2 := PollingShardConsumer{
+		commonShardConsumer: commonShardConsumer{kc: &m2},
+		callsLeft:           0,
+	}
+	rateLimitTimeSince = func(t time.Time) time.Duration {
+		return 500 * time.Millisecond
+	}
+	out2, err2 := psc2.callGetRecordsAPI(&gri)
+	assert.Nil(t, out2)
+	assert.ErrorIs(t, err2, localTPSExceededError)
+	m2.AssertExpectations(t)
 }
 
 type MockKinesisSubscriberGetter struct {
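
The rateLimitTimeNow and rateLimitTimeSince package variables act as test seams, which is what lets the test above substitute a fake rateLimitTimeSince. A companion pattern, sketched here with a hypothetical test name and not part of this commit, is to restore the originals with defer so the stub does not leak into other tests in the worker package.

// Sketch only: assumes it lives alongside the test above in package worker.
func TestStubbedClockRestore(t *testing.T) {
	origSince, origNow := rateLimitTimeSince, rateLimitTimeNow
	defer func() { rateLimitTimeSince, rateLimitTimeNow = origSince, origNow }()

	// pretend more than a second has elapsed, so callGetRecordsAPI would
	// refresh callsLeft back to kinesisReadTPSLimit before calling Kinesis
	rateLimitTimeSince = func(time.Time) time.Duration { return 2 * time.Second }
	rateLimitTimeNow = func() time.Time { return time.Now() }

	// ... exercise callGetRecordsAPI as in TestCallGetRecordsAPI above ...
}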