fix: add check for GetRecords error within callGetRecordsAPI

Signed-off-by: Shiva Pentakota <spentakota@vmware.com>
This commit is contained in:
Shiva Pentakota 2023-02-01 08:00:49 -08:00
parent 42881449ce
commit 04c5062ace
2 changed files with 28 additions and 1 deletions

View file

@ -287,9 +287,13 @@ func (sc *PollingShardConsumer) callGetRecordsAPI(gri *kinesis.GetRecordsInput)
if sc.callsLeft < 1 {
return nil, 0, localTPSExceededError
}
getResp, err := sc.kc.GetRecords(context.TODO(), gri)
sc.callsLeft--
if err != nil {
return getResp, 0, err
}
// Calculate size of records from read transaction
sc.bytesRead = 0
for _, record := range getResp.Records {

View file

@ -21,6 +21,7 @@ package worker
import (
"context"
"errors"
"testing"
"time"
@ -30,6 +31,10 @@ import (
"github.com/stretchr/testify/mock"
)
var (
testGetRecordsError = errors.New("GetRecords Error")
)
func TestCallGetRecordsAPI(t *testing.T) {
// basic happy path
m1 := MockKinesisSubscriberGetter{}
@ -150,6 +155,24 @@ func TestCallGetRecordsAPI(t *testing.T) {
t.Errorf("Incorrect Cool Off Period: %v", checkSleepVal4)
}
// case where getRecords throws error
m7 := MockKinesisSubscriberGetter{}
ret7 := kinesis.GetRecordsOutput{Records: nil}
m7.On("GetRecords", mock.Anything, mock.Anything, mock.Anything).Return(&ret7, testGetRecordsError)
psc7 := PollingShardConsumer{
commonShardConsumer: commonShardConsumer{kc: &m7},
callsLeft: 2,
bytesRead: 0,
}
rateLimitTimeSince = func(t time.Time) time.Duration {
return 2 * time.Second
}
out7, checkSleepVal7, err7 := psc7.callGetRecordsAPI(&gri)
assert.Equal(t, err7, testGetRecordsError)
assert.Equal(t, checkSleepVal7, 0)
assert.Equal(t, out7, &ret7)
m7.AssertExpectations(t)
// restore original func
rateLimitTimeNow = time.Now
rateLimitTimeSince = time.Since