Merge pull request #23 from vmware/spentakota_filNilError
fix: add check for GetRecords error within callGetRecordsAPI
This commit is contained in:
commit
fb17ec8bc6
2 changed files with 28 additions and 1 deletions
|
|
@ -287,9 +287,13 @@ func (sc *PollingShardConsumer) callGetRecordsAPI(gri *kinesis.GetRecordsInput)
|
||||||
if sc.callsLeft < 1 {
|
if sc.callsLeft < 1 {
|
||||||
return nil, 0, localTPSExceededError
|
return nil, 0, localTPSExceededError
|
||||||
}
|
}
|
||||||
|
|
||||||
getResp, err := sc.kc.GetRecords(context.TODO(), gri)
|
getResp, err := sc.kc.GetRecords(context.TODO(), gri)
|
||||||
sc.callsLeft--
|
sc.callsLeft--
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return getResp, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
// Calculate size of records from read transaction
|
// Calculate size of records from read transaction
|
||||||
sc.bytesRead = 0
|
sc.bytesRead = 0
|
||||||
for _, record := range getResp.Records {
|
for _, record := range getResp.Records {
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,7 @@ package worker
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
|
@ -30,6 +31,10 @@ import (
|
||||||
"github.com/stretchr/testify/mock"
|
"github.com/stretchr/testify/mock"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
testGetRecordsError = errors.New("GetRecords Error")
|
||||||
|
)
|
||||||
|
|
||||||
func TestCallGetRecordsAPI(t *testing.T) {
|
func TestCallGetRecordsAPI(t *testing.T) {
|
||||||
// basic happy path
|
// basic happy path
|
||||||
m1 := MockKinesisSubscriberGetter{}
|
m1 := MockKinesisSubscriberGetter{}
|
||||||
|
|
@ -150,6 +155,24 @@ func TestCallGetRecordsAPI(t *testing.T) {
|
||||||
t.Errorf("Incorrect Cool Off Period: %v", checkSleepVal4)
|
t.Errorf("Incorrect Cool Off Period: %v", checkSleepVal4)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// case where getRecords throws error
|
||||||
|
m7 := MockKinesisSubscriberGetter{}
|
||||||
|
ret7 := kinesis.GetRecordsOutput{Records: nil}
|
||||||
|
m7.On("GetRecords", mock.Anything, mock.Anything, mock.Anything).Return(&ret7, testGetRecordsError)
|
||||||
|
psc7 := PollingShardConsumer{
|
||||||
|
commonShardConsumer: commonShardConsumer{kc: &m7},
|
||||||
|
callsLeft: 2,
|
||||||
|
bytesRead: 0,
|
||||||
|
}
|
||||||
|
rateLimitTimeSince = func(t time.Time) time.Duration {
|
||||||
|
return 2 * time.Second
|
||||||
|
}
|
||||||
|
out7, checkSleepVal7, err7 := psc7.callGetRecordsAPI(&gri)
|
||||||
|
assert.Equal(t, err7, testGetRecordsError)
|
||||||
|
assert.Equal(t, checkSleepVal7, 0)
|
||||||
|
assert.Equal(t, out7, &ret7)
|
||||||
|
m7.AssertExpectations(t)
|
||||||
|
|
||||||
// restore original func
|
// restore original func
|
||||||
rateLimitTimeNow = time.Now
|
rateLimitTimeNow = time.Now
|
||||||
rateLimitTimeSince = time.Since
|
rateLimitTimeSince = time.Since
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue