diff --git a/clientlibrary/worker/polling-shard-consumer.go b/clientlibrary/worker/polling-shard-consumer.go index b96b26d..28c1f6a 100644 --- a/clientlibrary/worker/polling-shard-consumer.go +++ b/clientlibrary/worker/polling-shard-consumer.go @@ -287,9 +287,13 @@ func (sc *PollingShardConsumer) callGetRecordsAPI(gri *kinesis.GetRecordsInput) if sc.callsLeft < 1 { return nil, 0, localTPSExceededError } - getResp, err := sc.kc.GetRecords(context.TODO(), gri) sc.callsLeft-- + + if err != nil { + return getResp, 0, err + } + // Calculate size of records from read transaction sc.bytesRead = 0 for _, record := range getResp.Records { diff --git a/clientlibrary/worker/polling-shard-consumer_test.go b/clientlibrary/worker/polling-shard-consumer_test.go index 68dffd0..c94859d 100644 --- a/clientlibrary/worker/polling-shard-consumer_test.go +++ b/clientlibrary/worker/polling-shard-consumer_test.go @@ -21,6 +21,7 @@ package worker import ( "context" + "errors" "testing" "time" @@ -30,6 +31,10 @@ import ( "github.com/stretchr/testify/mock" ) +var ( + testGetRecordsError = errors.New("GetRecords Error") +) + func TestCallGetRecordsAPI(t *testing.T) { // basic happy path m1 := MockKinesisSubscriberGetter{} @@ -150,6 +155,24 @@ func TestCallGetRecordsAPI(t *testing.T) { t.Errorf("Incorrect Cool Off Period: %v", checkSleepVal4) } + // case where getRecords throws error + m7 := MockKinesisSubscriberGetter{} + ret7 := kinesis.GetRecordsOutput{Records: nil} + m7.On("GetRecords", mock.Anything, mock.Anything, mock.Anything).Return(&ret7, testGetRecordsError) + psc7 := PollingShardConsumer{ + commonShardConsumer: commonShardConsumer{kc: &m7}, + callsLeft: 2, + bytesRead: 0, + } + rateLimitTimeSince = func(t time.Time) time.Duration { + return 2 * time.Second + } + out7, checkSleepVal7, err7 := psc7.callGetRecordsAPI(&gri) + assert.Equal(t, err7, testGetRecordsError) + assert.Equal(t, checkSleepVal7, 0) + assert.Equal(t, out7, &ret7) + m7.AssertExpectations(t) + // restore original func rateLimitTimeNow = time.Now rateLimitTimeSince = time.Since