diff --git a/clientlibrary/checkpoint/checkpointer.go b/clientlibrary/checkpoint/checkpointer.go
index 1af66ba..3f5d179 100644
--- a/clientlibrary/checkpoint/checkpointer.go
+++ b/clientlibrary/checkpoint/checkpointer.go
@@ -79,6 +79,9 @@ type Checkpointer interface {
 	// RemoveLeaseOwner to remove lease owner for the shard entry to make the shard available for reassignment
 	RemoveLeaseOwner(string) error
 
+	// GetLeaseOwner returns the current lease owner of the given shard
+	GetLeaseOwner(string) (string, error)
+
 	// ListActiveWorkers returns active workers and their shards (New Lease Stealing Methods)
 	ListActiveWorkers(map[string]*par.ShardStatus) (map[string][]*par.ShardStatus, error)
 
diff --git a/clientlibrary/checkpoint/dynamodb-checkpointer.go b/clientlibrary/checkpoint/dynamodb-checkpointer.go
index b8f12d8..ee14fee 100644
--- a/clientlibrary/checkpoint/dynamodb-checkpointer.go
+++ b/clientlibrary/checkpoint/dynamodb-checkpointer.go
@@ -51,6 +51,11 @@ const (
 	NumMaxRetries = 10
 )
 
+var (
+	// NoLeaseOwnerErr is returned by GetLeaseOwner when the shard entry has no LeaseOwner attribute.
+	NoLeaseOwnerErr = errors.New("no LeaseOwner in checkpoints table")
+)
+
 // DynamoCheckpoint implements the Checkpoint interface using DynamoDB as a backend
 type DynamoCheckpoint struct {
 	log logger.Logger
@@ -336,6 +341,30 @@ func (checkpointer *DynamoCheckpoint) RemoveLeaseOwner(shardID string) error {
 	return err
 }
 
+// GetLeaseOwner returns the current lease owner of the given shard as recorded
+// in the checkpoints table. It returns NoLeaseOwnerErr when the shard entry has
+// no LeaseOwner attribute.
+func (checkpointer *DynamoCheckpoint) GetLeaseOwner(shardID string) (string, error) {
+	currentCheckpoint, err := checkpointer.getItem(shardID)
+	if err != nil {
+		return "", err
+	}
+
+	assignedVar, assignedToOk := currentCheckpoint[LeaseOwnerKey]
+	if !assignedToOk {
+		return "", NoLeaseOwnerErr
+	}
+
+	// Use the checked form of the type assertion: an attribute of any other
+	// DynamoDB type must surface as an error, not as a panic in the caller.
+	owner, isString := assignedVar.(*types.AttributeValueMemberS)
+	if !isString {
+		return "", errors.New("LeaseOwner in checkpoints table is not a string attribute")
+	}
+
+	return owner.Value, nil
+}
+
 // ListActiveWorkers returns a map of workers and their shards
 func (checkpointer *DynamoCheckpoint) ListActiveWorkers(shardStatus map[string]*par.ShardStatus) (map[string][]*par.ShardStatus, error) {
 	err := checkpointer.syncLeases(shardStatus)
diff --git a/clientlibrary/interfaces/record-processor.go b/clientlibrary/interfaces/record-processor.go
index 1c41d56..a4897d4 100644
--- a/clientlibrary/interfaces/record-processor.go
+++ b/clientlibrary/interfaces/record-processor.go
@@ -59,7 +59,8 @@ type (
 		 * @param processRecordsInput Provides the records to be processed as well as information and capabilities related
 		 * to them (eg checkpointing).
+		 * @return error A non-nil error makes the polling shard consumer stop its getRecords loop and return that error.
 		 */
-		ProcessRecords(processRecordsInput *ProcessRecordsInput)
+		ProcessRecords(processRecordsInput *ProcessRecordsInput) error
 
 		// Shutdown
 		/*
diff --git a/clientlibrary/worker/common-shard-consumer.go b/clientlibrary/worker/common-shard-consumer.go
index 68ec1b3..36ddb77 100644
--- a/clientlibrary/worker/common-shard-consumer.go
+++ b/clientlibrary/worker/common-shard-consumer.go
@@ -136,7 +136,7 @@ func (sc *commonShardConsumer) waitOnParentShard() error {
 	}
 }
 
-func (sc *commonShardConsumer) processRecords(getRecordsStartTime time.Time, records []types.Record, millisBehindLatest *int64, recordCheckpointer kcl.IRecordProcessorCheckpointer) {
+func (sc *commonShardConsumer) processRecords(getRecordsStartTime time.Time, records []types.Record, millisBehindLatest *int64, recordCheckpointer kcl.IRecordProcessorCheckpointer) error {
 	log := sc.kclConfig.Logger
 
 	getRecordsTime := time.Since(getRecordsStartTime).Milliseconds()
@@ -172,7 +172,10 @@ func (sc *commonShardConsumer) processRecords(getRecordsStartTime time.Time, rec
 	// Delivery the events to the record processor
 	input.CacheEntryTime = &getRecordsStartTime
 	input.CacheExitTime = &processRecordsStartTime
-	sc.recordProcessor.ProcessRecords(input)
+	err := sc.recordProcessor.ProcessRecords(input)
+	if err != nil {
+		return err
+	}
 
 	processedRecordsTiming := time.Since(processRecordsStartTime).Milliseconds()
 	sc.mService.RecordProcessRecordsTime(sc.shard.ID, float64(processedRecordsTiming))
@@ -181,4 +184,5 @@ func (sc *commonShardConsumer) processRecords(getRecordsStartTime time.Time, rec
 	sc.mService.IncrRecordsProcessed(sc.shard.ID, recordLength)
 	sc.mService.IncrBytesProcessed(sc.shard.ID, recordBytes)
 	sc.mService.MillisBehindLatest(sc.shard.ID, float64(*millisBehindLatest))
+	return nil
 }
diff --git a/clientlibrary/worker/polling-shard-consumer.go b/clientlibrary/worker/polling-shard-consumer.go
index e0998ec..7211842 100644
--- a/clientlibrary/worker/polling-shard-consumer.go
+++ b/clientlibrary/worker/polling-shard-consumer.go
@@ -214,7 +214,10 @@ func (sc *PollingShardConsumer) getRecords() error {
 		// reset the retry count after success
 		retriedErrors = 0
 
-		sc.processRecords(getRecordsStartTime, getResp.Records, getResp.MillisBehindLatest, recordCheckpointer)
+		err = sc.processRecords(getRecordsStartTime, getResp.Records, getResp.MillisBehindLatest, recordCheckpointer)
+		if err != nil {
+			return err
+		}
 
 		// The shard has been closed, so no new records can be read from it
 		if getResp.NextShardIterator == nil {
diff --git a/clientlibrary/worker/record-processor-checkpointer.go b/clientlibrary/worker/record-processor-checkpointer.go
index 5544a86..101137f 100644
--- a/clientlibrary/worker/record-processor-checkpointer.go
+++ b/clientlibrary/worker/record-processor-checkpointer.go
@@ -21,11 +21,21 @@
 package worker
 
 import (
+	"errors"
+	"time"
+
 	"github.com/aws/aws-sdk-go-v2/aws"
 
 	chk "github.com/vmware/vmware-go-kcl-v2/clientlibrary/checkpoint"
 	kcl "github.com/vmware/vmware-go-kcl-v2/clientlibrary/interfaces"
 	par "github.com/vmware/vmware-go-kcl-v2/clientlibrary/partition"
 )
 
+var (
+	// ShutdownError is returned by Checkpoint when the stored lease owner of the shard differs from the locally assigned owner.
+	ShutdownError = errors.New("another instance may have started processing some of these records already")
+	// LeaseExpiredError is returned by Checkpoint when the local lease timeout of the shard has already passed.
+	LeaseExpiredError = errors.New("the lease on the shard has expired")
+)
+
 type (
@@ -69,6 +79,17 @@ func (pc *PreparedCheckpointer) Checkpoint() error {
 }
 
 func (rc *RecordProcessorCheckpointer) Checkpoint(sequenceNumber *string) error {
+	// return shutdown error if lease is expired or another worker has started processing records for this shard
+	currLeaseOwner, err := rc.checkpoint.GetLeaseOwner(rc.shard.ID)
+	if err != nil {
+		return err
+	}
+	if rc.shard.AssignedTo != currLeaseOwner {
+		return ShutdownError
+	}
+	if time.Now().After(rc.shard.LeaseTimeout) {
+		return LeaseExpiredError
+	}
 	// checkpoint the last sequence of a closed shard
 	if sequenceNumber == nil {
 		rc.shard.SetCheckpoint(chk.ShardEnd)