vmware-go-kcl-v2/clientlibrary/worker/polling-shard-consumer_test.go
Shiva Pentakota 66006caf89 fix: add getRecords TPS rate limiting
Signed-off-by: Shiva Pentakota <spentakota@vmware.com>
2023-01-24 11:56:29 -08:00

80 lines
3 KiB
Go

/*
* Copyright (c) 2023 VMware, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
* associated documentation files (the "Software"), to deal in the Software without restriction, including
* without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is furnished to do
* so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all copies or substantial
* portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
* NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
package worker
import (
"context"
"testing"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/kinesis"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
func TestCallGetRecordsAPI(t *testing.T) {
// basic happy path
m1 := MockKinesisSubscriberGetter{}
ret := kinesis.GetRecordsOutput{}
m1.On("GetRecords", mock.Anything, mock.Anything, mock.Anything).Return(&ret, nil)
psc := PollingShardConsumer{
commonShardConsumer: commonShardConsumer{kc: &m1},
}
gri := kinesis.GetRecordsInput{
ShardIterator: aws.String("shard-iterator-01"),
}
out, err := psc.callGetRecordsAPI(&gri)
assert.Nil(t, err)
assert.Equal(t, &ret, out)
m1.AssertExpectations(t)
// check that localTPSExceededError is thrown when trying more than 5 TPS
m2 := MockKinesisSubscriberGetter{}
psc2 := PollingShardConsumer{
commonShardConsumer: commonShardConsumer{kc: &m2},
callsLeft: 0,
}
rateLimitTimeSince = func(t time.Time) time.Duration {
return 500 * time.Millisecond
}
out2, err2 := psc2.callGetRecordsAPI(&gri)
assert.Nil(t, out2)
assert.ErrorIs(t, err2, localTPSExceededError)
m2.AssertExpectations(t)
}
// MockKinesisSubscriberGetter is a testify-backed stand-in for the Kinesis
// client that the shard consumer holds in commonShardConsumer.kc. The embedded
// mock.Mock records calls and serves the responses configured with On(...).
type MockKinesisSubscriberGetter struct{ mock.Mock }
// GetRecords records the call on the embedded mock and returns the values
// configured with On("GetRecords", ...).Return(out, err).
//
// The comma-ok type assertion lets a test configure Return(nil, err) to
// simulate a failed call; a bare assertion would panic on the untyped nil.
func (m *MockKinesisSubscriberGetter) GetRecords(ctx context.Context, params *kinesis.GetRecordsInput, optFns ...func(*kinesis.Options)) (*kinesis.GetRecordsOutput, error) {
	ret := m.Called(ctx, params, optFns)
	// A failed assertion (untyped nil) intentionally yields a nil output.
	out, _ := ret.Get(0).(*kinesis.GetRecordsOutput)
	return out, ret.Error(1)
}
// GetShardIterator is an inert stub. It ignores its arguments, does not record
// the call on the embedded mock, and always yields a nil output with no error.
func (m *MockKinesisSubscriberGetter) GetShardIterator(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) {
	var out *kinesis.GetShardIteratorOutput
	return out, nil
}
// SubscribeToShard is an inert stub. It ignores its arguments, does not record
// the call on the embedded mock, and always yields a nil output with no error.
func (m *MockKinesisSubscriberGetter) SubscribeToShard(ctx context.Context, params *kinesis.SubscribeToShardInput, optFns ...func(*kinesis.Options)) (*kinesis.SubscribeToShardOutput, error) {
	var out *kinesis.SubscribeToShardOutput
	return out, nil
}