Merge 486651702f into f6e79f1a2d
This commit is contained in:
commit
8df326e926
2 changed files with 334 additions and 10 deletions
|
|
@ -145,6 +145,16 @@ func (sc *PollingShardConsumer) getRecords() error {
|
||||||
leaseRenewalErrChan <- sc.renewLease(ctx)
|
leaseRenewalErrChan <- sc.renewLease(ctx)
|
||||||
}()
|
}()
|
||||||
for {
|
for {
|
||||||
|
select {
|
||||||
|
case <-*sc.stop:
|
||||||
|
shutdownInput := &kcl.ShutdownInput{ShutdownReason: kcl.REQUESTED, Checkpointer: recordCheckpointer}
|
||||||
|
sc.recordProcessor.Shutdown(shutdownInput)
|
||||||
|
return nil
|
||||||
|
case leaseRenewalErr := <-leaseRenewalErrChan:
|
||||||
|
return leaseRenewalErr
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
getRecordsStartTime := time.Now()
|
getRecordsStartTime := time.Now()
|
||||||
|
|
||||||
log.Debugf("Trying to read %d record from iterator: %v", sc.kclConfig.MaxRecords, aws.ToString(shardIterator))
|
log.Debugf("Trying to read %d record from iterator: %v", sc.kclConfig.MaxRecords, aws.ToString(shardIterator))
|
||||||
|
|
@ -226,15 +236,6 @@ func (sc *PollingShardConsumer) getRecords() error {
|
||||||
time.Sleep(time.Duration(sc.kclConfig.IdleTimeBetweenReadsInMillis) * time.Millisecond)
|
time.Sleep(time.Duration(sc.kclConfig.IdleTimeBetweenReadsInMillis) * time.Millisecond)
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
|
||||||
case <-*sc.stop:
|
|
||||||
shutdownInput := &kcl.ShutdownInput{ShutdownReason: kcl.REQUESTED, Checkpointer: recordCheckpointer}
|
|
||||||
sc.recordProcessor.Shutdown(shutdownInput)
|
|
||||||
return nil
|
|
||||||
case leaseRenewalErr := <-leaseRenewalErrChan:
|
|
||||||
return leaseRenewalErr
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -22,17 +22,26 @@ package worker
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/aws/aws-sdk-go-v2/aws"
|
"github.com/aws/aws-sdk-go-v2/aws"
|
||||||
"github.com/aws/aws-sdk-go-v2/service/kinesis"
|
"github.com/aws/aws-sdk-go-v2/service/kinesis"
|
||||||
|
"github.com/aws/aws-sdk-go-v2/service/kinesis/types"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/mock"
|
"github.com/stretchr/testify/mock"
|
||||||
|
chk "github.com/vmware/vmware-go-kcl-v2/clientlibrary/checkpoint"
|
||||||
|
"github.com/vmware/vmware-go-kcl-v2/clientlibrary/config"
|
||||||
|
kcl "github.com/vmware/vmware-go-kcl-v2/clientlibrary/interfaces"
|
||||||
|
"github.com/vmware/vmware-go-kcl-v2/clientlibrary/metrics"
|
||||||
|
par "github.com/vmware/vmware-go-kcl-v2/clientlibrary/partition"
|
||||||
|
"github.com/vmware/vmware-go-kcl-v2/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
testGetRecordsError = errors.New("GetRecords Error")
|
testGetRecordsError = errors.New("GetRecords Error")
|
||||||
|
getLeaseTestFailure = errors.New("GetLease test failure")
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestCallGetRecordsAPI(t *testing.T) {
|
func TestCallGetRecordsAPI(t *testing.T) {
|
||||||
|
|
@ -194,7 +203,8 @@ func (m *MockKinesisSubscriberGetter) GetRecords(ctx context.Context, params *ki
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockKinesisSubscriberGetter) GetShardIterator(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) {
|
func (m *MockKinesisSubscriberGetter) GetShardIterator(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) {
|
||||||
return nil, nil
|
ret := m.Called(ctx, params, optFns)
|
||||||
|
return ret.Get(0).(*kinesis.GetShardIteratorOutput), ret.Error(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockKinesisSubscriberGetter) SubscribeToShard(ctx context.Context, params *kinesis.SubscribeToShardInput, optFns ...func(*kinesis.Options)) (*kinesis.SubscribeToShardOutput, error) {
|
func (m *MockKinesisSubscriberGetter) SubscribeToShard(ctx context.Context, params *kinesis.SubscribeToShardInput, optFns ...func(*kinesis.Options)) (*kinesis.SubscribeToShardOutput, error) {
|
||||||
|
|
@ -353,3 +363,316 @@ func TestPollingShardConsumer_checkCoolOffPeriod(t *testing.T) {
|
||||||
// restore original time.Now
|
// restore original time.Now
|
||||||
rateLimitTimeNow = time.Now
|
rateLimitTimeNow = time.Now
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPollingShardConsumer_renewLease(t *testing.T) {
|
||||||
|
type fields struct {
|
||||||
|
checkpointer chk.Checkpointer
|
||||||
|
kclConfig *config.KinesisClientLibConfiguration
|
||||||
|
mService metrics.MonitoringService
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
testMillis time.Duration
|
||||||
|
expRenewalCalls int
|
||||||
|
expRenewals int
|
||||||
|
expErr error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"renew once",
|
||||||
|
fields{
|
||||||
|
&mockCheckpointer{},
|
||||||
|
&config.KinesisClientLibConfiguration{
|
||||||
|
LeaseRefreshWaitTime: 10,
|
||||||
|
},
|
||||||
|
&mockMetrics{},
|
||||||
|
},
|
||||||
|
15,
|
||||||
|
1,
|
||||||
|
1,
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"renew some",
|
||||||
|
fields{
|
||||||
|
&mockCheckpointer{},
|
||||||
|
&config.KinesisClientLibConfiguration{
|
||||||
|
LeaseRefreshWaitTime: 50,
|
||||||
|
},
|
||||||
|
&mockMetrics{},
|
||||||
|
},
|
||||||
|
50*5 + 10,
|
||||||
|
5,
|
||||||
|
5,
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"renew twice every 2.5 seconds",
|
||||||
|
fields{
|
||||||
|
&mockCheckpointer{},
|
||||||
|
&config.KinesisClientLibConfiguration{
|
||||||
|
LeaseRefreshWaitTime: 2500,
|
||||||
|
},
|
||||||
|
&mockMetrics{},
|
||||||
|
},
|
||||||
|
5100,
|
||||||
|
2,
|
||||||
|
2,
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"lease error",
|
||||||
|
fields{
|
||||||
|
&mockCheckpointer{fail: true},
|
||||||
|
&config.KinesisClientLibConfiguration{
|
||||||
|
LeaseRefreshWaitTime: 500,
|
||||||
|
},
|
||||||
|
&mockMetrics{},
|
||||||
|
},
|
||||||
|
1100,
|
||||||
|
1,
|
||||||
|
0,
|
||||||
|
getLeaseTestFailure,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
sc := &PollingShardConsumer{
|
||||||
|
commonShardConsumer: commonShardConsumer{
|
||||||
|
shard: &par.ShardStatus{},
|
||||||
|
checkpointer: tt.fields.checkpointer,
|
||||||
|
kclConfig: tt.fields.kclConfig,
|
||||||
|
},
|
||||||
|
mService: tt.fields.mService,
|
||||||
|
}
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
leaseRenewalErrChan := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
leaseRenewalErrChan <- sc.renewLease(ctx)
|
||||||
|
}()
|
||||||
|
time.Sleep(tt.testMillis * time.Millisecond)
|
||||||
|
cancel()
|
||||||
|
err := <-leaseRenewalErrChan
|
||||||
|
assert.Equal(t, tt.expErr, err)
|
||||||
|
assert.Equal(t, tt.expRenewalCalls, sc.checkpointer.(*mockCheckpointer).getLeaseCalledTimes)
|
||||||
|
assert.Equal(t, tt.expRenewals, sc.mService.(*mockMetrics).leaseRenewedCalledTimes)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPollingShardConsumer_getRecordsRenewLease(t *testing.T) {
|
||||||
|
log := logger.GetDefaultLogger()
|
||||||
|
type fields struct {
|
||||||
|
checkpointer chk.Checkpointer
|
||||||
|
kclConfig *config.KinesisClientLibConfiguration
|
||||||
|
mService metrics.MonitoringService
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
|
||||||
|
// testMillis must be at least 200ms or you'll trigger the localTPSExceededError
|
||||||
|
testMillis time.Duration
|
||||||
|
expRenewalCalls int
|
||||||
|
expRenewals int
|
||||||
|
shardClosed bool
|
||||||
|
expErr error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"renew once",
|
||||||
|
fields{
|
||||||
|
&mockCheckpointer{},
|
||||||
|
&config.KinesisClientLibConfiguration{
|
||||||
|
LeaseRefreshWaitTime: 200,
|
||||||
|
Logger: log,
|
||||||
|
InitialPositionInStream: config.LATEST,
|
||||||
|
},
|
||||||
|
&mockMetrics{},
|
||||||
|
},
|
||||||
|
250,
|
||||||
|
1,
|
||||||
|
1,
|
||||||
|
false,
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"renew some",
|
||||||
|
fields{
|
||||||
|
&mockCheckpointer{},
|
||||||
|
&config.KinesisClientLibConfiguration{
|
||||||
|
LeaseRefreshWaitTime: 50,
|
||||||
|
Logger: log,
|
||||||
|
InitialPositionInStream: config.LATEST,
|
||||||
|
},
|
||||||
|
&mockMetrics{},
|
||||||
|
},
|
||||||
|
50*5 + 10,
|
||||||
|
5,
|
||||||
|
5,
|
||||||
|
false,
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"renew twice every 2.5 seconds",
|
||||||
|
fields{
|
||||||
|
&mockCheckpointer{},
|
||||||
|
&config.KinesisClientLibConfiguration{
|
||||||
|
LeaseRefreshWaitTime: 2500,
|
||||||
|
Logger: log,
|
||||||
|
InitialPositionInStream: config.LATEST,
|
||||||
|
},
|
||||||
|
&mockMetrics{},
|
||||||
|
},
|
||||||
|
5100,
|
||||||
|
2,
|
||||||
|
2,
|
||||||
|
false,
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"lease error",
|
||||||
|
fields{
|
||||||
|
&mockCheckpointer{fail: true},
|
||||||
|
&config.KinesisClientLibConfiguration{
|
||||||
|
LeaseRefreshWaitTime: 500,
|
||||||
|
Logger: log,
|
||||||
|
InitialPositionInStream: config.LATEST,
|
||||||
|
},
|
||||||
|
&mockMetrics{},
|
||||||
|
},
|
||||||
|
1100,
|
||||||
|
1,
|
||||||
|
0,
|
||||||
|
false,
|
||||||
|
getLeaseTestFailure,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
iterator := "test-iterator"
|
||||||
|
nextIt := "test-next-iterator"
|
||||||
|
millisBehind := int64(0)
|
||||||
|
stopChan := make(chan struct{})
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
mk := MockKinesisSubscriberGetter{}
|
||||||
|
gro := kinesis.GetRecordsOutput{
|
||||||
|
Records: []types.Record{
|
||||||
|
{
|
||||||
|
Data: []byte{},
|
||||||
|
PartitionKey: new(string),
|
||||||
|
SequenceNumber: new(string),
|
||||||
|
ApproximateArrivalTimestamp: &time.Time{},
|
||||||
|
EncryptionType: "",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
MillisBehindLatest: &millisBehind,
|
||||||
|
}
|
||||||
|
if !tt.shardClosed {
|
||||||
|
gro.NextShardIterator = &nextIt
|
||||||
|
}
|
||||||
|
mk.On("GetRecords", mock.Anything, mock.Anything, mock.Anything).Return(&gro, nil)
|
||||||
|
mk.On("GetShardIterator", mock.Anything, mock.Anything, mock.Anything).Return(&kinesis.GetShardIteratorOutput{ShardIterator: &iterator}, nil)
|
||||||
|
rp := mockRecordProcessor{
|
||||||
|
processDurationMillis: tt.testMillis,
|
||||||
|
}
|
||||||
|
sc := &PollingShardConsumer{
|
||||||
|
commonShardConsumer: commonShardConsumer{
|
||||||
|
shard: &par.ShardStatus{
|
||||||
|
ID: "test-shard-id",
|
||||||
|
Mux: &sync.RWMutex{},
|
||||||
|
},
|
||||||
|
checkpointer: tt.fields.checkpointer,
|
||||||
|
kclConfig: tt.fields.kclConfig,
|
||||||
|
kc: &mk,
|
||||||
|
recordProcessor: &rp,
|
||||||
|
mService: tt.fields.mService,
|
||||||
|
},
|
||||||
|
stop: &stopChan,
|
||||||
|
mService: tt.fields.mService,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send the stop signal a little before the total time it should
|
||||||
|
// take to get records and process them. This prevents test time
|
||||||
|
// errors due to the threads running longer than the test case
|
||||||
|
// expects.
|
||||||
|
go func() {
|
||||||
|
time.Sleep((tt.testMillis - 1) * time.Millisecond)
|
||||||
|
stopChan <- struct{}{}
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := sc.getRecords()
|
||||||
|
|
||||||
|
assert.Equal(t, tt.expErr, err)
|
||||||
|
assert.Equal(t, tt.expRenewalCalls, sc.checkpointer.(*mockCheckpointer).readGetLeaseCalledTimes())
|
||||||
|
assert.Equal(t, tt.expRenewals, sc.mService.(*mockMetrics).readLeaseRenewedCalledTimes())
|
||||||
|
mk.AssertExpectations(t)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockCheckpointer struct {
|
||||||
|
getLeaseCalledTimes int
|
||||||
|
gLCTMu sync.Mutex
|
||||||
|
fail bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockCheckpointer) readGetLeaseCalledTimes() int {
|
||||||
|
m.gLCTMu.Lock()
|
||||||
|
defer m.gLCTMu.Unlock()
|
||||||
|
return m.getLeaseCalledTimes
|
||||||
|
}
|
||||||
|
func (m *mockCheckpointer) Init() error { return nil }
|
||||||
|
func (m *mockCheckpointer) GetLease(*par.ShardStatus, string) error {
|
||||||
|
m.gLCTMu.Lock()
|
||||||
|
defer m.gLCTMu.Unlock()
|
||||||
|
m.getLeaseCalledTimes++
|
||||||
|
if m.fail {
|
||||||
|
return getLeaseTestFailure
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (m *mockCheckpointer) CheckpointSequence(*par.ShardStatus) error { return nil }
|
||||||
|
func (m *mockCheckpointer) FetchCheckpoint(*par.ShardStatus) error { return nil }
|
||||||
|
func (m *mockCheckpointer) RemoveLeaseInfo(string) error { return nil }
|
||||||
|
func (m *mockCheckpointer) RemoveLeaseOwner(string) error { return nil }
|
||||||
|
func (m *mockCheckpointer) GetLeaseOwner(string) (string, error) { return "", nil }
|
||||||
|
func (m *mockCheckpointer) ListActiveWorkers(map[string]*par.ShardStatus) (map[string][]*par.ShardStatus, error) {
|
||||||
|
return map[string][]*par.ShardStatus{}, nil
|
||||||
|
}
|
||||||
|
func (m *mockCheckpointer) ClaimShard(*par.ShardStatus, string) error { return nil }
|
||||||
|
|
||||||
|
type mockRecordProcessor struct {
|
||||||
|
processDurationMillis time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m mockRecordProcessor) Initialize(initializationInput *kcl.InitializationInput) {}
|
||||||
|
func (m mockRecordProcessor) ProcessRecords(processRecordsInput *kcl.ProcessRecordsInput) {
|
||||||
|
time.Sleep(time.Millisecond * m.processDurationMillis)
|
||||||
|
}
|
||||||
|
func (m mockRecordProcessor) Shutdown(shutdownInput *kcl.ShutdownInput) {}
|
||||||
|
|
||||||
|
type mockMetrics struct {
|
||||||
|
leaseRenewedCalledTimes int
|
||||||
|
lRCTMu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockMetrics) readLeaseRenewedCalledTimes() int {
|
||||||
|
m.lRCTMu.Lock()
|
||||||
|
defer m.lRCTMu.Unlock()
|
||||||
|
return m.leaseRenewedCalledTimes
|
||||||
|
}
|
||||||
|
func (m *mockMetrics) Init(appName, streamName, workerID string) error { return nil }
|
||||||
|
func (m *mockMetrics) Start() error { return nil }
|
||||||
|
func (m *mockMetrics) IncrRecordsProcessed(shard string, count int) {}
|
||||||
|
func (m *mockMetrics) IncrBytesProcessed(shard string, count int64) {}
|
||||||
|
func (m *mockMetrics) MillisBehindLatest(shard string, milliSeconds float64) {}
|
||||||
|
func (m *mockMetrics) DeleteMetricMillisBehindLatest(shard string) {}
|
||||||
|
func (m *mockMetrics) LeaseGained(shard string) {}
|
||||||
|
func (m *mockMetrics) LeaseLost(shard string) {}
|
||||||
|
func (m *mockMetrics) LeaseRenewed(shard string) {
|
||||||
|
m.lRCTMu.Lock()
|
||||||
|
defer m.lRCTMu.Unlock()
|
||||||
|
m.leaseRenewedCalledTimes++
|
||||||
|
}
|
||||||
|
func (m *mockMetrics) RecordGetRecordsTime(shard string, time float64) {}
|
||||||
|
func (m *mockMetrics) RecordProcessRecordsTime(shard string, time float64) {}
|
||||||
|
func (m *mockMetrics) Shutdown() {}
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue