From 486651702f1de09a6b925933b811854d0bfc333a Mon Sep 17 00:00:00 2001
From: John Calixto
Date: Sat, 8 Apr 2023 11:24:33 -0700
Subject: [PATCH] fix: Move getRecords stop condition checking to top of loop

Added tests to validate the coordination of lease renewal with the
record-fetching and record-processing routines. These tests brought to
light the value of checking the loop's "stop conditions" at the top of
the loop instead of at the end: a stop request or a failed lease
renewal is now honored before another batch of records is fetched.

Signed-off-by: John Calixto
---
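Note (not part of the commit; git am drops everything between the "---"
above and the diff): below is a minimal, runnable sketch of the pattern
this change adopts. A non-blocking select at the top of each iteration
guarantees that no branch of the loop body can reach the next iteration
without passing the stop checks. The names (workLoop, work) are
illustrative only and do not come from this codebase.

    package main

    import (
        "fmt"
        "time"
    )

    // workLoop checks its stop conditions at the top of every iteration,
    // before doing any work, mirroring the getRecords change below.
    func workLoop(stop <-chan struct{}, errs <-chan error, work func()) error {
        for {
            select {
            case <-stop:
                return nil // shutdown requested; quit before doing more work
            case err := <-errs:
                return err // e.g. a failed lease renewal
            default: // neither condition fired; fall through and do the work
            }
            work()
        }
    }

    func main() {
        stop := make(chan struct{})
        errs := make(chan error)
        go func() {
            time.Sleep(50 * time.Millisecond)
            stop <- struct{}{}
        }()
        err := workLoop(stop, errs, func() { time.Sleep(10 * time.Millisecond) })
        fmt.Println("loop exited with:", err)
    }

The default case is what keeps the select non-blocking: without it the
loop would stall whenever neither channel had fired, instead of falling
through to fetch and process the next batch of records.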
 .../worker/polling-shard-consumer.go          |  19 +-
 .../worker/polling-shard-consumer_test.go     | 220 ++++++++++++++++--
 2 files changed, 210 insertions(+), 29 deletions(-)

diff --git a/clientlibrary/worker/polling-shard-consumer.go b/clientlibrary/worker/polling-shard-consumer.go
index 3829850..f3e23f8 100644
--- a/clientlibrary/worker/polling-shard-consumer.go
+++ b/clientlibrary/worker/polling-shard-consumer.go
@@ -145,6 +145,16 @@ func (sc *PollingShardConsumer) getRecords() error {
     leaseRenewalErrChan <- sc.renewLease(ctx)
   }()
   for {
+    select {
+    case <-*sc.stop:
+      shutdownInput := &kcl.ShutdownInput{ShutdownReason: kcl.REQUESTED, Checkpointer: recordCheckpointer}
+      sc.recordProcessor.Shutdown(shutdownInput)
+      return nil
+    case leaseRenewalErr := <-leaseRenewalErrChan:
+      return leaseRenewalErr
+    default:
+    }
+
     getRecordsStartTime := time.Now()

     log.Debugf("Trying to read %d record from iterator: %v", sc.kclConfig.MaxRecords, aws.ToString(shardIterator))
@@ -226,15 +236,6 @@ func (sc *PollingShardConsumer) getRecords() error {
       time.Sleep(time.Duration(sc.kclConfig.IdleTimeBetweenReadsInMillis) * time.Millisecond)
     }

-    select {
-    case <-*sc.stop:
-      shutdownInput := &kcl.ShutdownInput{ShutdownReason: kcl.REQUESTED, Checkpointer: recordCheckpointer}
-      sc.recordProcessor.Shutdown(shutdownInput)
-      return nil
-    case leaseRenewalErr := <-leaseRenewalErrChan:
-      return leaseRenewalErr
-    default:
-    }
   }
 }

diff --git a/clientlibrary/worker/polling-shard-consumer_test.go b/clientlibrary/worker/polling-shard-consumer_test.go
index 7a8d09e..16f3454 100644
--- a/clientlibrary/worker/polling-shard-consumer_test.go
+++ b/clientlibrary/worker/polling-shard-consumer_test.go
@@ -22,17 +22,21 @@ package worker
 import (
   "context"
   "errors"
+  "sync"
   "testing"
   "time"

   "github.com/aws/aws-sdk-go-v2/aws"
   "github.com/aws/aws-sdk-go-v2/service/kinesis"
+  "github.com/aws/aws-sdk-go-v2/service/kinesis/types"
   "github.com/stretchr/testify/assert"
   "github.com/stretchr/testify/mock"
   chk "github.com/vmware/vmware-go-kcl-v2/clientlibrary/checkpoint"
   "github.com/vmware/vmware-go-kcl-v2/clientlibrary/config"
+  kcl "github.com/vmware/vmware-go-kcl-v2/clientlibrary/interfaces"
   "github.com/vmware/vmware-go-kcl-v2/clientlibrary/metrics"
   par "github.com/vmware/vmware-go-kcl-v2/clientlibrary/partition"
+  "github.com/vmware/vmware-go-kcl-v2/logger"
 )

 var (
@@ -199,7 +203,8 @@ func (m *MockKinesisSubscriberGetter) GetRecords(ctx context.Context, params *ki
 }

 func (m *MockKinesisSubscriberGetter) GetShardIterator(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) {
-  return nil, nil
+  ret := m.Called(ctx, params, optFns)
+  return ret.Get(0).(*kinesis.GetShardIteratorOutput), ret.Error(1)
 }

 func (m *MockKinesisSubscriberGetter) SubscribeToShard(ctx context.Context, params *kinesis.SubscribeToShardInput, optFns ...func(*kinesis.Options)) (*kinesis.SubscribeToShardOutput, error) {
@@ -455,44 +460,219 @@ func TestPollingShardConsumer_renewLease(t *testing.T) {
   }
 }

+func TestPollingShardConsumer_getRecordsRenewLease(t *testing.T) {
+  log := logger.GetDefaultLogger()
+  type fields struct {
+    checkpointer chk.Checkpointer
+    kclConfig    *config.KinesisClientLibConfiguration
+    mService     metrics.MonitoringService
+  }
+  tests := []struct {
+    name   string
+    fields fields
+
+    // testMillis must be at least 200ms or you'll trigger the localTPSExceededError
+    testMillis      time.Duration
+    expRenewalCalls int
+    expRenewals     int
+    shardClosed     bool
+    expErr          error
+  }{
+    {
+      "renew once",
+      fields{
+        &mockCheckpointer{},
+        &config.KinesisClientLibConfiguration{
+          LeaseRefreshWaitTime:    200,
+          Logger:                  log,
+          InitialPositionInStream: config.LATEST,
+        },
+        &mockMetrics{},
+      },
+      250,
+      1,
+      1,
+      false,
+      nil,
+    },
+    {
+      "renew some",
+      fields{
+        &mockCheckpointer{},
+        &config.KinesisClientLibConfiguration{
+          LeaseRefreshWaitTime:    50,
+          Logger:                  log,
+          InitialPositionInStream: config.LATEST,
+        },
+        &mockMetrics{},
+      },
+      50*5 + 10,
+      5,
+      5,
+      false,
+      nil,
+    },
+    {
+      "renew twice every 2.5 seconds",
+      fields{
+        &mockCheckpointer{},
+        &config.KinesisClientLibConfiguration{
+          LeaseRefreshWaitTime:    2500,
+          Logger:                  log,
+          InitialPositionInStream: config.LATEST,
+        },
+        &mockMetrics{},
+      },
+      5100,
+      2,
+      2,
+      false,
+      nil,
+    },
+    {
+      "lease error",
+      fields{
+        &mockCheckpointer{fail: true},
+        &config.KinesisClientLibConfiguration{
+          LeaseRefreshWaitTime:    500,
+          Logger:                  log,
+          InitialPositionInStream: config.LATEST,
+        },
+        &mockMetrics{},
+      },
+      1100,
+      1,
+      0,
+      false,
+      getLeaseTestFailure,
+    },
+  }
+  iterator := "test-iterator"
+  nextIt := "test-next-iterator"
+  millisBehind := int64(0)
+  stopChan := make(chan struct{})
+  for _, tt := range tests {
+    t.Run(tt.name, func(t *testing.T) {
+      mk := MockKinesisSubscriberGetter{}
+      gro := kinesis.GetRecordsOutput{
+        Records: []types.Record{
+          {
+            Data:                        []byte{},
+            PartitionKey:                new(string),
+            SequenceNumber:              new(string),
+            ApproximateArrivalTimestamp: &time.Time{},
+            EncryptionType:              "",
+          },
+        },
+        MillisBehindLatest: &millisBehind,
+      }
+      if !tt.shardClosed {
+        gro.NextShardIterator = &nextIt
+      }
+      mk.On("GetRecords", mock.Anything, mock.Anything, mock.Anything).Return(&gro, nil)
+      mk.On("GetShardIterator", mock.Anything, mock.Anything, mock.Anything).Return(&kinesis.GetShardIteratorOutput{ShardIterator: &iterator}, nil)
+      rp := mockRecordProcessor{
+        processDurationMillis: tt.testMillis,
+      }
+      sc := &PollingShardConsumer{
+        commonShardConsumer: commonShardConsumer{
+          shard: &par.ShardStatus{
+            ID:  "test-shard-id",
+            Mux: &sync.RWMutex{},
+          },
+          checkpointer:    tt.fields.checkpointer,
+          kclConfig:       tt.fields.kclConfig,
+          kc:              &mk,
+          recordProcessor: &rp,
+          mService:        tt.fields.mService,
+        },
+        stop:     &stopChan,
+        mService: tt.fields.mService,
+      }
+
+      // Send the stop signal a little before the total time it should
+      // take to get and process the records. This prevents flaky
+      // timing failures caused by the goroutines running longer than
+      // the test case expects.
+      go func() {
+        time.Sleep((tt.testMillis - 1) * time.Millisecond)
+        stopChan <- struct{}{}
+      }()
+
+      err := sc.getRecords()
+
+      assert.Equal(t, tt.expErr, err)
+      assert.Equal(t, tt.expRenewalCalls, sc.checkpointer.(*mockCheckpointer).readGetLeaseCalledTimes())
+      assert.Equal(t, tt.expRenewals, sc.mService.(*mockMetrics).readLeaseRenewedCalledTimes())
+      mk.AssertExpectations(t)
+    })
+  }
+}
+
 type mockCheckpointer struct {
   getLeaseCalledTimes int
+  gLCTMu              sync.Mutex
   fail                bool
 }

-func (m mockCheckpointer) Init() error { return nil }
+func (m *mockCheckpointer) readGetLeaseCalledTimes() int {
+  m.gLCTMu.Lock()
+  defer m.gLCTMu.Unlock()
+  return m.getLeaseCalledTimes
+}
+func (m *mockCheckpointer) Init() error { return nil }
 func (m *mockCheckpointer) GetLease(*par.ShardStatus, string) error {
+  m.gLCTMu.Lock()
+  defer m.gLCTMu.Unlock()
   m.getLeaseCalledTimes++
   if m.fail {
     return getLeaseTestFailure
   }
   return nil
 }
-func (m mockCheckpointer) CheckpointSequence(*par.ShardStatus) error { return nil }
-func (m mockCheckpointer) FetchCheckpoint(*par.ShardStatus) error    { return nil }
-func (m mockCheckpointer) RemoveLeaseInfo(string) error              { return nil }
-func (m mockCheckpointer) RemoveLeaseOwner(string) error             { return nil }
-func (m mockCheckpointer) GetLeaseOwner(string) (string, error)      { return "", nil }
-func (m mockCheckpointer) ListActiveWorkers(map[string]*par.ShardStatus) (map[string][]*par.ShardStatus, error) {
+func (m *mockCheckpointer) CheckpointSequence(*par.ShardStatus) error { return nil }
+func (m *mockCheckpointer) FetchCheckpoint(*par.ShardStatus) error    { return nil }
+func (m *mockCheckpointer) RemoveLeaseInfo(string) error              { return nil }
+func (m *mockCheckpointer) RemoveLeaseOwner(string) error             { return nil }
+func (m *mockCheckpointer) GetLeaseOwner(string) (string, error)      { return "", nil }
+func (m *mockCheckpointer) ListActiveWorkers(map[string]*par.ShardStatus) (map[string][]*par.ShardStatus, error) {
   return map[string][]*par.ShardStatus{}, nil
 }
-func (m mockCheckpointer) ClaimShard(*par.ShardStatus, string) error { return nil }
+func (m *mockCheckpointer) ClaimShard(*par.ShardStatus, string) error { return nil }
+
+type mockRecordProcessor struct {
+  processDurationMillis time.Duration
+}
+
+func (m mockRecordProcessor) Initialize(initializationInput *kcl.InitializationInput) {}
+func (m mockRecordProcessor) ProcessRecords(processRecordsInput *kcl.ProcessRecordsInput) {
+  time.Sleep(time.Millisecond * m.processDurationMillis)
+}
+func (m mockRecordProcessor) Shutdown(shutdownInput *kcl.ShutdownInput) {}

 type mockMetrics struct {
   leaseRenewedCalledTimes int
+  lRCTMu                  sync.Mutex
 }

-func (m mockMetrics) Init(appName, streamName, workerID string) error { return nil }
-func (m mockMetrics) Start() error                                    { return nil }
-func (m mockMetrics) IncrRecordsProcessed(shard string, count int)    {}
-func (m mockMetrics) IncrBytesProcessed(shard string, count int64)    {}
-func (m mockMetrics) MillisBehindLatest(shard string, milliSeconds float64) {}
-func (m mockMetrics) DeleteMetricMillisBehindLatest(shard string)           {}
-func (m mockMetrics) LeaseGained(shard string)                              {}
-func (m mockMetrics) LeaseLost(shard string)                                {}
+func (m *mockMetrics) readLeaseRenewedCalledTimes() int {
+  m.lRCTMu.Lock()
+  defer m.lRCTMu.Unlock()
+  return m.leaseRenewedCalledTimes
+}
+func (m *mockMetrics) Init(appName, streamName, workerID string) error { return nil }
+func (m *mockMetrics) Start() error                                    { return nil }
+func (m *mockMetrics) IncrRecordsProcessed(shard string, count int)    {}
+func (m *mockMetrics) IncrBytesProcessed(shard string, count int64)   {}
+func (m *mockMetrics) MillisBehindLatest(shard string, milliSeconds float64) {}
+func (m *mockMetrics) DeleteMetricMillisBehindLatest(shard string)           {}
+func (m *mockMetrics) LeaseGained(shard string)                              {}
+func (m *mockMetrics) LeaseLost(shard string)                                {}
 func (m *mockMetrics) LeaseRenewed(shard string) {
+  m.lRCTMu.Lock()
+  defer m.lRCTMu.Unlock()
   m.leaseRenewedCalledTimes++
 }
-func (m mockMetrics) RecordGetRecordsTime(shard string, time float64)     {}
-func (m mockMetrics) RecordProcessRecordsTime(shard string, time float64) {}
-func (m mockMetrics) Shutdown()                                           {}
+func (m *mockMetrics) RecordGetRecordsTime(shard string, time float64)     {}
+func (m *mockMetrics) RecordProcessRecordsTime(shard string, time float64) {}
+func (m *mockMetrics) Shutdown()                                           {}