vmwjc 2024-06-20 04:08:02 +00:00 committed by GitHub
commit 8df326e926
2 changed files with 334 additions and 10 deletions


@@ -145,6 +145,16 @@ func (sc *PollingShardConsumer) getRecords() error {
 		leaseRenewalErrChan <- sc.renewLease(ctx)
 	}()
 	for {
+		select {
+		case <-*sc.stop:
+			shutdownInput := &kcl.ShutdownInput{ShutdownReason: kcl.REQUESTED, Checkpointer: recordCheckpointer}
+			sc.recordProcessor.Shutdown(shutdownInput)
+			return nil
+		case leaseRenewalErr := <-leaseRenewalErrChan:
+			return leaseRenewalErr
+		default:
+		}
+
 		getRecordsStartTime := time.Now()
 		log.Debugf("Trying to read %d record from iterator: %v", sc.kclConfig.MaxRecords, aws.ToString(shardIterator))
@@ -226,15 +236,6 @@ func (sc *PollingShardConsumer) getRecords() error {
 			time.Sleep(time.Duration(sc.kclConfig.IdleTimeBetweenReadsInMillis) * time.Millisecond)
 		}
-		select {
-		case <-*sc.stop:
-			shutdownInput := &kcl.ShutdownInput{ShutdownReason: kcl.REQUESTED, Checkpointer: recordCheckpointer}
-			sc.recordProcessor.Shutdown(shutdownInput)
-			return nil
-		case leaseRenewalErr := <-leaseRenewalErrChan:
-			return leaseRenewalErr
-		default:
-		}
 	}
 }
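The two hunks above move the stop/lease-renewal check from the bottom of the polling loop to the top, so a requested shutdown or a background lease-renewal failure is observed before the next GetRecords call rather than only after a full read-and-process pass. A minimal, runnable sketch of that pattern follows; stop, errs, and poll are stand-ins for the consumer's fields, not names from this library:

package main

import (
	"errors"
	"fmt"
	"time"
)

// pollLoop checks for shutdown or a background error before every unit of
// work; the default case keeps the check non-blocking.
func pollLoop(stop <-chan struct{}, errs <-chan error, poll func() error) error {
	for {
		select {
		case <-stop:
			return nil // graceful shutdown requested
		case err := <-errs:
			return err // background goroutine (e.g. lease renewal) failed
		default:
		}
		if err := poll(); err != nil {
			return err
		}
	}
}

func main() {
	stop := make(chan struct{})
	errs := make(chan error, 1)
	go func() {
		time.Sleep(50 * time.Millisecond)
		errs <- errors.New("lease lost")
	}()
	err := pollLoop(stop, errs, func() error {
		time.Sleep(10 * time.Millisecond) // stand-in for one GetRecords pass
		return nil
	})
	fmt.Println("loop exited:", err) // loop exited: lease lost
}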


@@ -22,17 +22,26 @@ package worker
 import (
 	"context"
 	"errors"
+	"sync"
 	"testing"
 	"time"

 	"github.com/aws/aws-sdk-go-v2/aws"
 	"github.com/aws/aws-sdk-go-v2/service/kinesis"
 	"github.com/aws/aws-sdk-go-v2/service/kinesis/types"
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/mock"
 	chk "github.com/vmware/vmware-go-kcl-v2/clientlibrary/checkpoint"
 	"github.com/vmware/vmware-go-kcl-v2/clientlibrary/config"
 	kcl "github.com/vmware/vmware-go-kcl-v2/clientlibrary/interfaces"
 	"github.com/vmware/vmware-go-kcl-v2/clientlibrary/metrics"
 	par "github.com/vmware/vmware-go-kcl-v2/clientlibrary/partition"
 	"github.com/vmware/vmware-go-kcl-v2/logger"
 )

+var (
+	testGetRecordsError = errors.New("GetRecords Error")
+	getLeaseTestFailure = errors.New("GetLease test failure")
+)
+
 func TestCallGetRecordsAPI(t *testing.T) {
@@ -194,7 +203,8 @@ func (m *MockKinesisSubscriberGetter) GetRecords(ctx context.Context, params *ki
 }

 func (m *MockKinesisSubscriberGetter) GetShardIterator(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) {
-	return nil, nil
+	ret := m.Called(ctx, params, optFns)
+	return ret.Get(0).(*kinesis.GetShardIteratorOutput), ret.Error(1)
 }

 func (m *MockKinesisSubscriberGetter) SubscribeToShard(ctx context.Context, params *kinesis.SubscribeToShardInput, optFns ...func(*kinesis.Options)) (*kinesis.SubscribeToShardOutput, error) {
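The GetShardIterator stub above now routes through testify/mock: Called records the invocation and hands back whatever the test configured with On(...).Return(...), which is what lets the new tests program a canned iterator and later verify the call with AssertExpectations. A self-contained sketch of that same pattern, using a hypothetical Greeter rather than anything from this repository:

package main

import (
	"fmt"

	"github.com/stretchr/testify/mock"
)

// MockGreeter embeds mock.Mock; Called records the invocation and returns
// the values the caller configured with On(...).Return(...).
type MockGreeter struct {
	mock.Mock
}

func (m *MockGreeter) Greet(name string) (string, error) {
	ret := m.Called(name)
	return ret.Get(0).(string), ret.Error(1)
}

func main() {
	m := &MockGreeter{}
	m.On("Greet", "world").Return("hello, world", nil)
	msg, err := m.Greet("world")
	fmt.Println(msg, err) // hello, world <nil>
}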
@@ -353,3 +363,316 @@ func TestPollingShardConsumer_checkCoolOffPeriod(t *testing.T) {
// restore original time.Now
rateLimitTimeNow = time.Now
}
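The restore step above is the tail end of a clock monkey-patch: rateLimitTimeNow is a package-level variable that defaults to time.Now, the test swaps it for a fixed clock, and it must be put back so later tests see real time. A minimal sketch of the technique; secondsIntoWindow is an invented helper for illustration, not the worker's actual rate-limit arithmetic:

package main

import (
	"fmt"
	"time"
)

// rateLimitTimeNow mirrors the worker package's seam: a variable holding
// time.Now that tests can overwrite with a fixed clock.
var rateLimitTimeNow = time.Now

// secondsIntoWindow is a made-up stand-in for code that depends on the clock.
func secondsIntoWindow() float64 {
	return float64(rateLimitTimeNow().Nanosecond()) / float64(time.Second)
}

func main() {
	// Pin the clock to a known instant for a deterministic result.
	fixed := time.Date(2024, 6, 20, 4, 8, 2, 500_000_000, time.UTC)
	rateLimitTimeNow = func() time.Time { return fixed }
	fmt.Println(secondsIntoWindow()) // 0.5
	// Restore the real clock, as the test above does, so nothing else
	// observes the frozen time.
	rateLimitTimeNow = time.Now
}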
func TestPollingShardConsumer_renewLease(t *testing.T) {
type fields struct {
checkpointer chk.Checkpointer
kclConfig *config.KinesisClientLibConfiguration
mService metrics.MonitoringService
}
tests := []struct {
name string
fields fields
testMillis time.Duration
expRenewalCalls int
expRenewals int
expErr error
}{
{
"renew once",
fields{
&mockCheckpointer{},
&config.KinesisClientLibConfiguration{
LeaseRefreshWaitTime: 10,
},
&mockMetrics{},
},
15,
1,
1,
nil,
},
{
"renew some",
fields{
&mockCheckpointer{},
&config.KinesisClientLibConfiguration{
LeaseRefreshWaitTime: 50,
},
&mockMetrics{},
},
50*5 + 10,
5,
5,
nil,
},
{
"renew twice every 2.5 seconds",
fields{
&mockCheckpointer{},
&config.KinesisClientLibConfiguration{
LeaseRefreshWaitTime: 2500,
},
&mockMetrics{},
},
5100,
2,
2,
nil,
},
{
"lease error",
fields{
&mockCheckpointer{fail: true},
&config.KinesisClientLibConfiguration{
LeaseRefreshWaitTime: 500,
},
&mockMetrics{},
},
1100,
1,
0,
getLeaseTestFailure,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
sc := &PollingShardConsumer{
commonShardConsumer: commonShardConsumer{
shard: &par.ShardStatus{},
checkpointer: tt.fields.checkpointer,
kclConfig: tt.fields.kclConfig,
},
mService: tt.fields.mService,
}
ctx, cancel := context.WithCancel(context.Background())
leaseRenewalErrChan := make(chan error, 1)
go func() {
leaseRenewalErrChan <- sc.renewLease(ctx)
}()
time.Sleep(tt.testMillis * time.Millisecond)
cancel()
err := <-leaseRenewalErrChan
assert.Equal(t, tt.expErr, err)
assert.Equal(t, tt.expRenewalCalls, sc.checkpointer.(*mockCheckpointer).getLeaseCalledTimes)
assert.Equal(t, tt.expRenewals, sc.mService.(*mockMetrics).leaseRenewedCalledTimes)
})
}
}
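The test drives renewLease on a background goroutine, sleeps for testMillis, cancels the context, and reads the result from a buffered channel; the expected renewal count is roughly testMillis divided by LeaseRefreshWaitTime. A plausible shape for such a loop is sketched below; this is an illustration of the contract the assertions imply, not the library's actual renewLease implementation:

package main

import (
	"context"
	"fmt"
	"time"
)

// renewLoop renews on a fixed interval until the context is cancelled,
// returning nil on cancellation and the first renewal error otherwise.
func renewLoop(ctx context.Context, wait time.Duration, renew func() error) error {
	ticker := time.NewTicker(wait)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return nil
		case <-ticker.C:
			if err := renew(); err != nil {
				return err
			}
		}
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	go func() {
		time.Sleep(55 * time.Millisecond)
		cancel()
	}()
	calls := 0
	err := renewLoop(ctx, 10*time.Millisecond, func() error {
		calls++ // only touched from this goroutine, so no lock is needed here
		return nil
	})
	fmt.Println(calls, err) // roughly 5 renewals, then <nil> on cancel
}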
func TestPollingShardConsumer_getRecordsRenewLease(t *testing.T) {
log := logger.GetDefaultLogger()
type fields struct {
checkpointer chk.Checkpointer
kclConfig *config.KinesisClientLibConfiguration
mService metrics.MonitoringService
}
tests := []struct {
name string
fields fields
// testMillis must be at least 200ms or you'll trigger the localTPSExceededError
testMillis time.Duration
expRenewalCalls int
expRenewals int
shardClosed bool
expErr error
}{
{
"renew once",
fields{
&mockCheckpointer{},
&config.KinesisClientLibConfiguration{
LeaseRefreshWaitTime: 200,
Logger: log,
InitialPositionInStream: config.LATEST,
},
&mockMetrics{},
},
250,
1,
1,
false,
nil,
},
{
"renew some",
fields{
&mockCheckpointer{},
&config.KinesisClientLibConfiguration{
LeaseRefreshWaitTime: 50,
Logger: log,
InitialPositionInStream: config.LATEST,
},
&mockMetrics{},
},
50*5 + 10,
5,
5,
false,
nil,
},
{
"renew twice every 2.5 seconds",
fields{
&mockCheckpointer{},
&config.KinesisClientLibConfiguration{
LeaseRefreshWaitTime: 2500,
Logger: log,
InitialPositionInStream: config.LATEST,
},
&mockMetrics{},
},
5100,
2,
2,
false,
nil,
},
{
"lease error",
fields{
&mockCheckpointer{fail: true},
&config.KinesisClientLibConfiguration{
LeaseRefreshWaitTime: 500,
Logger: log,
InitialPositionInStream: config.LATEST,
},
&mockMetrics{},
},
1100,
1,
0,
false,
getLeaseTestFailure,
},
}
iterator := "test-iterator"
nextIt := "test-next-iterator"
millisBehind := int64(0)
stopChan := make(chan struct{})
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mk := MockKinesisSubscriberGetter{}
gro := kinesis.GetRecordsOutput{
Records: []types.Record{
{
Data: []byte{},
PartitionKey: new(string),
SequenceNumber: new(string),
ApproximateArrivalTimestamp: &time.Time{},
EncryptionType: "",
},
},
MillisBehindLatest: &millisBehind,
}
if !tt.shardClosed {
gro.NextShardIterator = &nextIt
}
mk.On("GetRecords", mock.Anything, mock.Anything, mock.Anything).Return(&gro, nil)
mk.On("GetShardIterator", mock.Anything, mock.Anything, mock.Anything).Return(&kinesis.GetShardIteratorOutput{ShardIterator: &iterator}, nil)
rp := mockRecordProcessor{
processDurationMillis: tt.testMillis,
}
sc := &PollingShardConsumer{
commonShardConsumer: commonShardConsumer{
shard: &par.ShardStatus{
ID: "test-shard-id",
Mux: &sync.RWMutex{},
},
checkpointer: tt.fields.checkpointer,
kclConfig: tt.fields.kclConfig,
kc: &mk,
recordProcessor: &rp,
mService: tt.fields.mService,
},
stop: &stopChan,
mService: tt.fields.mService,
}
// Send the stop signal a little before the total time it should
// take to get records and process them. This prevents timing
// errors caused by the goroutines running longer than the test
// case expects.
go func() {
time.Sleep((tt.testMillis - 1) * time.Millisecond)
stopChan <- struct{}{}
}()
err := sc.getRecords()
assert.Equal(t, tt.expErr, err)
assert.Equal(t, tt.expRenewalCalls, sc.checkpointer.(*mockCheckpointer).readGetLeaseCalledTimes())
assert.Equal(t, tt.expRenewals, sc.mService.(*mockMetrics).readLeaseRenewedCalledTimes())
mk.AssertExpectations(t)
})
}
}
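Because stopChan is unbuffered, the send in the goroutine above blocks until getRecords reaches the non-blocking select at the top of its loop, so firing it just before the last processing pass finishes makes the shutdown deterministic. A small sketch of that handshake, with invented worker/batch names in place of the consumer's internals:

package main

import (
	"fmt"
	"time"
)

// worker drains batches until a value arrives on stop; the non-blocking
// select mirrors the shape of the shutdown check in getRecords.
func worker(stop <-chan struct{}, batch time.Duration) int {
	processed := 0
	for {
		select {
		case <-stop:
			return processed
		default:
		}
		time.Sleep(batch) // stand-in for one fetch-and-process pass
		processed++
	}
}

func main() {
	stop := make(chan struct{}) // unbuffered, like stopChan in the test
	go func() {
		// Fire slightly before the third batch completes; the send blocks
		// until the worker's next select is ready to receive it.
		time.Sleep(25 * time.Millisecond)
		stop <- struct{}{}
	}()
	fmt.Println("batches:", worker(stop, 10*time.Millisecond)) // batches: 3
}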
type mockCheckpointer struct {
getLeaseCalledTimes int
gLCTMu sync.Mutex
fail bool
}
func (m *mockCheckpointer) readGetLeaseCalledTimes() int {
m.gLCTMu.Lock()
defer m.gLCTMu.Unlock()
return m.getLeaseCalledTimes
}
func (m *mockCheckpointer) Init() error { return nil }
func (m *mockCheckpointer) GetLease(*par.ShardStatus, string) error {
m.gLCTMu.Lock()
defer m.gLCTMu.Unlock()
m.getLeaseCalledTimes++
if m.fail {
return getLeaseTestFailure
}
return nil
}
func (m *mockCheckpointer) CheckpointSequence(*par.ShardStatus) error { return nil }
func (m *mockCheckpointer) FetchCheckpoint(*par.ShardStatus) error { return nil }
func (m *mockCheckpointer) RemoveLeaseInfo(string) error { return nil }
func (m *mockCheckpointer) RemoveLeaseOwner(string) error { return nil }
func (m *mockCheckpointer) GetLeaseOwner(string) (string, error) { return "", nil }
func (m *mockCheckpointer) ListActiveWorkers(map[string]*par.ShardStatus) (map[string][]*par.ShardStatus, error) {
return map[string][]*par.ShardStatus{}, nil
}
func (m *mockCheckpointer) ClaimShard(*par.ShardStatus, string) error { return nil }
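mockCheckpointer guards getLeaseCalledTimes with a mutex because GetLease fires on the lease-renewal goroutine while the test reads the counter from the main goroutine. For contrast, here is the same thread-safe counter written with sync/atomic; this is a sketch of the alternative, not a suggested change to the mock (atomic.Int64 requires Go 1.19+):

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// counter is a lock-free equivalent of a mutex-guarded call counter.
type counter struct{ n atomic.Int64 }

func (c *counter) hit()        { c.n.Add(1) }
func (c *counter) read() int64 { return c.n.Load() }

func main() {
	var c counter
	var wg sync.WaitGroup
	for i := 0; i < 100; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			c.hit()
		}()
	}
	wg.Wait()
	fmt.Println(c.read()) // always 100, with no data race
}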
type mockRecordProcessor struct {
processDurationMillis time.Duration
}
func (m mockRecordProcessor) Initialize(initializationInput *kcl.InitializationInput) {}
func (m mockRecordProcessor) ProcessRecords(processRecordsInput *kcl.ProcessRecordsInput) {
// processDurationMillis holds a bare millisecond count in a time.Duration,
// so multiplying by time.Millisecond converts it to a proper duration.
time.Sleep(time.Millisecond * m.processDurationMillis)
}
func (m mockRecordProcessor) Shutdown(shutdownInput *kcl.ShutdownInput) {}
type mockMetrics struct {
leaseRenewedCalledTimes int
lRCTMu sync.Mutex
}
func (m *mockMetrics) readLeaseRenewedCalledTimes() int {
m.lRCTMu.Lock()
defer m.lRCTMu.Unlock()
return m.leaseRenewedCalledTimes
}
func (m *mockMetrics) Init(appName, streamName, workerID string) error { return nil }
func (m *mockMetrics) Start() error { return nil }
func (m *mockMetrics) IncrRecordsProcessed(shard string, count int) {}
func (m *mockMetrics) IncrBytesProcessed(shard string, count int64) {}
func (m *mockMetrics) MillisBehindLatest(shard string, milliSeconds float64) {}
func (m *mockMetrics) DeleteMetricMillisBehindLatest(shard string) {}
func (m *mockMetrics) LeaseGained(shard string) {}
func (m *mockMetrics) LeaseLost(shard string) {}
func (m *mockMetrics) LeaseRenewed(shard string) {
m.lRCTMu.Lock()
defer m.lRCTMu.Unlock()
m.leaseRenewedCalledTimes++
}
func (m *mockMetrics) RecordGetRecordsTime(shard string, time float64) {}
func (m *mockMetrics) RecordProcessRecordsTime(shard string, time float64) {}
func (m *mockMetrics) Shutdown() {}