fix: added callGetRecordsAPI tests
Signed-off-by: Shiva Pentakota <spentakota@vmware.com>
parent adcff0b7bb
commit ae3763e478
2 changed files with 218 additions and 66 deletions
@@ -45,9 +45,18 @@ import (
 )
 
 const (
+    kinesisReadTPSLimit = 5
     MaxBytes = 10000000.0
     MaxBytesPerSecond = 2000000.0
-    MaxReadTransactionsPerSecond = 5
+    BytesToMbConversion = 1000000.0
+)
+
+var (
+    rateLimitTimeNow   = time.Now
+    rateLimitTimeSince = time.Since
+    rateLimitSleep     = time.Sleep
+
+    localTPSExceededError = errors.New("Error GetRecords TPS Exceeded")
 )
 
 // PollingShardConsumer is responsible for polling data records from a (specified) shard.
@@ -58,6 +67,11 @@ type PollingShardConsumer struct {
     stop          *chan struct{}
     consumerID    string
     mService      metrics.MonitoringService
+    currTime      time.Time
+    callsLeft     int
+    remBytes      float64
+    lastCheckTime time.Time
+    bytesRead     float64
 }
 
 func (sc *PollingShardConsumer) getShardIterator() (*string, error) {
@@ -113,10 +127,12 @@ func (sc *PollingShardConsumer) getRecords() error {
 
     recordCheckpointer := NewRecordProcessorCheckpoint(sc.shard, sc.checkpointer)
     retriedErrors := 0
-    transactionNum := 0
-    remBytes := MaxBytes
-    var lastCheckTime time.Time
-    var firstTransactionTime time.Time
+    // define API call rate limit starting window
+    sc.currTime = rateLimitTimeNow()
+    sc.callsLeft = kinesisReadTPSLimit
+    sc.bytesRead = 0
+    sc.remBytes = MaxBytes
 
     for {
         if time.Now().UTC().After(sc.shard.GetLeaseTimeout().Add(-time.Duration(sc.kclConfig.LeaseRefreshPeriodMillis) * time.Millisecond)) {
@@ -139,30 +155,18 @@ func (sc *PollingShardConsumer) getRecords() error {
         getRecordsStartTime := time.Now()
 
         log.Debugf("Trying to read %d record from iterator: %v", sc.kclConfig.MaxRecords, aws.ToString(shardIterator))
 
+        // Get records from stream and retry as needed
         getRecordsArgs := &kinesis.GetRecordsInput{
             Limit:         aws.Int32(int32(sc.kclConfig.MaxRecords)),
             ShardIterator: shardIterator,
         }
-        // Each shard can support up to five read transactions per second.
-        if transactionNum > MaxReadTransactionsPerSecond {
-            transactionNum = 0
-            timeDiff := time.Since(firstTransactionTime)
-            if timeDiff < time.Second {
-                time.Sleep(timeDiff)
-            }
-        }
-
-        // Get records from stream and retry as needed
-        // Each read transaction can provide up to 10,000 records with an upper quota of 10 MB per transaction.
-        // ref: https://docs.aws.amazon.com/streams/latest/dev/service-sizes-and-limits.html
-        getResp, err := sc.kc.GetRecords(context.TODO(), getRecordsArgs)
-        getRecordsTransactionTime := time.Now()
+        getResp, err := sc.callGetRecordsAPI(getRecordsArgs)
         if err != nil {
             //aws-sdk-go-v2 https://github.com/aws/aws-sdk-go-v2/blob/main/CHANGELOG.md#error-handling
             var throughputExceededErr *types.ProvisionedThroughputExceededException
             var kmsThrottlingErr *types.KMSThrottlingException
-            if errors.As(err, &throughputExceededErr) {
+            if errors.As(err, &throughputExceededErr) || err == localTPSExceededError {
                 retriedErrors++
                 if retriedErrors > sc.kclConfig.MaxRetryCount {
                     log.Errorf("message", "reached max retry count getting records from shard",
@@ -174,10 +178,7 @@ func (sc *PollingShardConsumer) getRecords() error {
                 // If there is insufficient provisioned throughput on the stream,
                 // subsequent calls made within the next 1 second throw ProvisionedThroughputExceededException.
                 // ref: https://docs.aws.amazon.com/streams/latest/dev/service-sizes-and-limits.html
-                waitTime := time.Since(getRecordsTransactionTime)
-                if waitTime < time.Second {
-                    time.Sleep(time.Second - waitTime)
-                }
+                sc.waitASecond(sc.currTime)
                 continue
             }
             if errors.As(err, &kmsThrottlingErr) {
@@ -196,26 +197,13 @@ func (sc *PollingShardConsumer) getRecords() error {
                 time.Sleep(time.Duration(math.Exp2(float64(retriedErrors))*100) * time.Millisecond)
                 continue
             }
 
             log.Errorf("Error getting records from Kinesis that cannot be retried: %+v Request: %s", err, getRecordsArgs)
             return err
         }
 
         // reset the retry count after success
         retriedErrors = 0
 
-        // Calculate size of records from read transaction
-        numBytes := 0
-        for _, record := range getResp.Records {
-            numBytes = numBytes + len(record.Data)
-        }
-
-        // Add to number of getRecords successful transactions
-        transactionNum++
-        if transactionNum == 1 {
-            firstTransactionTime = getRecordsTransactionTime
-            lastCheckTime = firstTransactionTime
-        }
-
         sc.processRecords(getRecordsStartTime, getResp.Records, getResp.MillisBehindLatest, recordCheckpointer)
 
         // The shard has been closed, so no new records can be read from it
@@ -225,27 +213,6 @@ func (sc *PollingShardConsumer) getRecords() error {
             sc.recordProcessor.Shutdown(shutdownInput)
             return nil
         }
-
-        // Each shard can support up to a maximum total data read rate of 2 MB per second via GetRecords.
-        // If a call to GetRecords returns 10 MB, subsequent calls made within the next 5 seconds throw an exception.
-        // ref: https://docs.aws.amazon.com/streams/latest/dev/service-sizes-and-limits.html
-        // check for overspending of byte budget from getRecords call
-        currTime := time.Now()
-        timePassed := currTime.Sub(lastCheckTime)
-        lastCheckTime = currTime
-
-        remBytes = remBytes + float64(timePassed.Seconds())*(MaxBytes/(float64(time.Second*5)))
-        if remBytes > MaxBytes {
-            remBytes = MaxBytes
-        }
-        if remBytes <= float64(numBytes) {
-            // Wait until cool down period has passed to prevent ProvisionedThroughputExceededException
-            coolDown := numBytes / MaxBytesPerSecond
-            time.Sleep(time.Duration(coolDown) * time.Second)
-        } else {
-            remBytes = remBytes - float64(numBytes)
-        }
-
         shardIterator = getResp.NextShardIterator
 
         // Idle between each read, the user is responsible for checkpoint the progress
@@ -264,3 +231,60 @@ func (sc *PollingShardConsumer) getRecords() error {
         }
     }
 }
+
+func (sc *PollingShardConsumer) waitASecond(timePassed time.Time) {
+    waitTime := time.Since(timePassed)
+    if waitTime < time.Second {
+        time.Sleep(time.Second - waitTime)
+    }
+}
+
+func (sc *PollingShardConsumer) checkCoolOffPeriod() {
+    // Each shard can support up to a maximum total data read rate of 2 MB per second via GetRecords.
+    // If a call to GetRecords returns 10 MB, subsequent calls made within the next 5 seconds throw an exception.
+    // ref: https://docs.aws.amazon.com/streams/latest/dev/service-sizes-and-limits.html
+    // check for overspending of byte budget from getRecords call
+    currentTime := rateLimitTimeNow()
+    timePassed := currentTime.Sub(sc.lastCheckTime)
+    sc.lastCheckTime = currentTime
+    sc.remBytes += timePassed.Seconds() * MaxBytesPerSecond
+    transactionReadRate := sc.bytesRead / (timePassed.Seconds() * BytesToMbConversion)
+    if sc.remBytes > MaxBytes {
+        sc.remBytes = MaxBytes
+    }
+    if sc.remBytes <= sc.bytesRead || transactionReadRate > 2 {
+        // Wait until cool down period has passed to prevent ProvisionedThroughputExceededException
+        coolDown := sc.bytesRead / MaxBytesPerSecond
+        rateLimitSleep(time.Duration(coolDown * float64(time.Second)))
+    } else {
+        sc.remBytes -= sc.bytesRead
+    }
+}
+
+func (sc *PollingShardConsumer) callGetRecordsAPI(gri *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
+
+    if sc.bytesRead != 0 {
+        sc.checkCoolOffPeriod()
+    }
+
+    // every new second, we get a fresh set of calls
+    if rateLimitTimeSince(sc.currTime) > time.Second {
+        sc.callsLeft = kinesisReadTPSLimit
+        sc.currTime = rateLimitTimeNow()
+    }
+
+    if sc.callsLeft < 1 {
+        return nil, localTPSExceededError
+    }
+
+    getResp, err := sc.kc.GetRecords(context.TODO(), gri)
+
+    sc.callsLeft--
+    // Calculate size of records from read transaction
+    sc.bytesRead = 0
+    for _, record := range getResp.Records {
+        sc.bytesRead += float64(len(record.Data))
+    }
+
+    return getResp, err
+}
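Taken together, callGetRecordsAPI, checkCoolOffPeriod, and waitASecond act as a small client-side rate limiter: at most kinesisReadTPSLimit calls per one-second window, plus a byte budget that refills at MaxBytesPerSecond up to a MaxBytes burst and forces a cool-off sleep when a read overdraws it. The standalone Go sketch below illustrates the same idea in isolation, assuming the documented Kinesis per-shard quotas; shardLimiter, allowCall, and recordBytes are names invented for this example and are not part of the library.

package main

import (
    "fmt"
    "time"
)

// Illustrative limits mirroring the Kinesis per-shard quotas used in the diff:
// 5 GetRecords calls per second and roughly 2 MB/s (10 MB burst) of data.
const (
    tpsLimit       = 5
    maxBurstBytes  = 10_000_000.0
    bytesPerSecond = 2_000_000.0
)

// shardLimiter is a hypothetical, simplified stand-in for the bookkeeping the
// commit adds to PollingShardConsumer (callsLeft, remBytes, lastCheckTime).
type shardLimiter struct {
    callsLeft   int
    windowStart time.Time
    remBytes    float64
    lastCheck   time.Time
}

// allowCall reports whether another call may be made in the current
// one-second window, refreshing the window once it has elapsed.
func (l *shardLimiter) allowCall(now time.Time) bool {
    if now.Sub(l.windowStart) > time.Second {
        l.windowStart = now
        l.callsLeft = tpsLimit
    }
    if l.callsLeft < 1 {
        return false
    }
    l.callsLeft--
    return true
}

// recordBytes charges bytesRead against the byte budget and returns how long
// the caller should sleep before the next call (zero if no cool-off is needed).
func (l *shardLimiter) recordBytes(now time.Time, bytesRead float64) time.Duration {
    elapsed := now.Sub(l.lastCheck).Seconds()
    l.lastCheck = now
    l.remBytes += elapsed * bytesPerSecond
    if l.remBytes > maxBurstBytes {
        l.remBytes = maxBurstBytes
    }
    if l.remBytes <= bytesRead {
        return time.Duration(bytesRead / bytesPerSecond * float64(time.Second))
    }
    l.remBytes -= bytesRead
    return 0
}

func main() {
    start := time.Now()
    l := &shardLimiter{callsLeft: tpsLimit, windowStart: start, remBytes: maxBurstBytes, lastCheck: start}
    fmt.Println("call allowed:", l.allowCall(start))
    // Reading a full 10 MB burst should suggest a ~5 s cool-off at 2 MB/s.
    fmt.Println("cool-off:", l.recordBytes(start.Add(time.Second), maxBurstBytes))
}

Run as-is, the example reports a 5 s cool-off for a 10 MB read observed one second after the previous check, which is the same scenario the new TestCallGetRecordsAPI case asserts through checkSleepVal.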
@@ -22,6 +22,7 @@ package worker
 import (
     "context"
     "testing"
+    "time"
 
     "github.com/aws/aws-sdk-go-v2/aws"
     "github.com/aws/aws-sdk-go-v2/service/kinesis"
@@ -44,6 +45,133 @@ func TestCallGetRecordsAPI(t *testing.T) {
     assert.Nil(t, err)
     assert.Equal(t, &ret, out)
     m1.AssertExpectations(t)
+
+    // check that localTPSExceededError is thrown when trying more than 5 TPS
+    m2 := MockKinesisSubscriberGetter{}
+    psc2 := PollingShardConsumer{
+        commonShardConsumer: commonShardConsumer{kc: &m2},
+        callsLeft:           0,
+    }
+    rateLimitTimeSince = func(t time.Time) time.Duration {
+        return 500 * time.Millisecond
+    }
+    out2, err2 := psc2.callGetRecordsAPI(&gri)
+    assert.Nil(t, out2)
+    assert.ErrorIs(t, err2, localTPSExceededError)
+    m2.AssertExpectations(t)
+
+    // check that getRecords is called normally in bytesRead = 0 case
+    m3 := MockKinesisSubscriberGetter{}
+    ret3 := kinesis.GetRecordsOutput{}
+    m3.On("GetRecords", mock.Anything, mock.Anything, mock.Anything).Return(&ret3, nil)
+    psc3 := PollingShardConsumer{
+        commonShardConsumer: commonShardConsumer{kc: &m3},
+        callsLeft:           2,
+        bytesRead:           0,
+    }
+    rateLimitTimeSince = func(t time.Time) time.Duration {
+        return 2 * time.Second
+    }
+    out3, err3 := psc3.callGetRecordsAPI(&gri)
+    assert.Nil(t, err3)
+    assert.Equal(t, &ret3, out3)
+    m3.AssertExpectations(t)
+
+    // check that correct cool off period is taken for 10mb in 1 second
+    testTime := time.Now()
+    m4 := MockKinesisSubscriberGetter{}
+    ret4 := kinesis.GetRecordsOutput{}
+    m4.On("GetRecords", mock.Anything, mock.Anything, mock.Anything).Return(&ret4, nil)
+    psc4 := PollingShardConsumer{
+        commonShardConsumer: commonShardConsumer{kc: &m4},
+        callsLeft:           2,
+        bytesRead:           MaxBytes,
+        lastCheckTime:       testTime,
+        remBytes:            MaxBytes,
+    }
+    rateLimitTimeSince = func(t time.Time) time.Duration {
+        return 2 * time.Second
+    }
+    rateLimitTimeNow = func() time.Time {
+        return testTime.Add(time.Second)
+    }
+    checkSleepVal := 0.0
+    rateLimitSleep = func(d time.Duration) {
+        checkSleepVal = d.Seconds()
+    }
+    out4, err4 := psc4.callGetRecordsAPI(&gri)
+    assert.Nil(t, err4)
+    assert.Equal(t, &ret4, out4)
+    m4.AssertExpectations(t)
+    if checkSleepVal != 5 {
+        t.Errorf("Incorrect Cool Off Period: %v", checkSleepVal)
+    }
+
+    // check that no cool off period is taken for 6mb in 3 seconds
+    testTime2 := time.Now()
+    m5 := MockKinesisSubscriberGetter{}
+    ret5 := kinesis.GetRecordsOutput{}
+    m5.On("GetRecords", mock.Anything, mock.Anything, mock.Anything).Return(&ret5, nil)
+    psc5 := PollingShardConsumer{
+        commonShardConsumer: commonShardConsumer{kc: &m5},
+        callsLeft:           2,
+        bytesRead:           MaxBytesPerSecond * 3,
+        lastCheckTime:       testTime2,
+        remBytes:            MaxBytes,
+    }
+    rateLimitTimeSince = func(t time.Time) time.Duration {
+        return 3 * time.Second
+    }
+    rateLimitTimeNow = func() time.Time {
+        return testTime2.Add(time.Second * 3)
+    }
+    checkSleepVal2 := 0.0
+    rateLimitSleep = func(d time.Duration) {
+        checkSleepVal2 = d.Seconds()
+    }
+    out5, err5 := psc5.callGetRecordsAPI(&gri)
+    assert.Nil(t, err5)
+    assert.Equal(t, &ret5, out5)
+    m5.AssertExpectations(t)
+    if checkSleepVal2 != 0 {
+        t.Errorf("Incorrect Cool Off Period: %v", checkSleepVal2)
+    }
+
+    // check for correct cool off period with 8mb in .2 seconds with 6mb remaining
+    testTime3 := time.Now()
+    m6 := MockKinesisSubscriberGetter{}
+    ret6 := kinesis.GetRecordsOutput{}
+    m6.On("GetRecords", mock.Anything, mock.Anything, mock.Anything).Return(&ret6, nil)
+    psc6 := PollingShardConsumer{
+        commonShardConsumer: commonShardConsumer{kc: &m6},
+        callsLeft:           2,
+        bytesRead:           MaxBytesPerSecond * 4,
+        lastCheckTime:       testTime3,
+        remBytes:            MaxBytes * 3,
+    }
+    rateLimitTimeSince = func(t time.Time) time.Duration {
+        return 3 * time.Second
+    }
+    rateLimitTimeNow = func() time.Time {
+        return testTime3.Add(time.Second / 5)
+    }
+    checkSleepVal3 := 0.0
+    rateLimitSleep = func(d time.Duration) {
+        checkSleepVal3 = d.Seconds()
+    }
+    out6, err6 := psc6.callGetRecordsAPI(&gri)
+    assert.Nil(t, err6)
+    assert.Equal(t, &ret6, out6)
m5.AssertExpectations(t)
|
||||||
|
if checkSleepVal3 != 4 {
|
||||||
|
t.Errorf("Incorrect Cool Off Period: %v", checkSleepVal3)
|
||||||
|
}
|
||||||
|
|
||||||
|
// restore original func
|
||||||
|
rateLimitTimeNow = time.Now
|
||||||
|
rateLimitTimeSince = time.Since
|
||||||
|
rateLimitSleep = time.Sleep
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type MockKinesisSubscriberGetter struct {
|
type MockKinesisSubscriberGetter struct {
|
||||||
|
|
|
||||||
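The new test cases drive callGetRecordsAPI entirely through the package-level hooks added in the consumer file (rateLimitTimeNow, rateLimitTimeSince, rateLimitSleep): the production code calls the hooks, the test swaps in deterministic stand-ins, and the real functions are restored at the end. The sketch below shows the same stubbing pattern in isolation; coolOff and the hook variable names here are illustrative placeholders, not the library's API.

package main

import (
    "fmt"
    "time"
)

// Package-level hooks, as in the diff: production code calls these instead of
// the time package directly so tests can substitute deterministic versions.
var (
    timeNow   = time.Now
    timeSince = time.Since
    sleep     = time.Sleep
)

// coolOff is a hypothetical function using the hooks, standing in for the
// consumer's sleep-based throttling.
func coolOff(started time.Time, budget time.Duration) {
    if elapsed := timeSince(started); elapsed < budget {
        sleep(budget - elapsed)
    }
}

func main() {
    // A test would stub the hooks, exercise the code, then restore them,
    // much like TestCallGetRecordsAPI does with rateLimitTimeNow and friends.
    var slept time.Duration
    timeSince = func(time.Time) time.Duration { return 200 * time.Millisecond }
    sleep = func(d time.Duration) { slept = d }
    defer func() { timeSince, sleep = time.Since, time.Sleep }()

    coolOff(timeNow(), time.Second)
    fmt.Println("would sleep:", slept) // 800ms, without any real waiting
}

Restoring the hooks with defer, as in the sketch, keeps a failing assertion from leaking stubbed clocks into later tests; the commit restores them with explicit assignments at the end of TestCallGetRecordsAPI.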