From 4482696d955652a955957231d3cb42f162daeefa Mon Sep 17 00:00:00 2001 From: Shiva Pentakota Date: Thu, 6 Apr 2023 17:41:46 -0700 Subject: [PATCH] fix: pass in ctx with cancel for renewLease Signed-off-by: Shiva Pentakota --- .../worker/polling-shard-consumer.go | 43 +++++++++++++------ 1 file changed, 30 insertions(+), 13 deletions(-) diff --git a/clientlibrary/worker/polling-shard-consumer.go b/clientlibrary/worker/polling-shard-consumer.go index d8c81c1..3829850 100644 --- a/clientlibrary/worker/polling-shard-consumer.go +++ b/clientlibrary/worker/polling-shard-consumer.go @@ -99,7 +99,12 @@ func (sc *PollingShardConsumer) getShardIterator() (*string, error) { // getRecords continuously poll one shard for data record // Precondition: it currently has the lease on the shard. func (sc *PollingShardConsumer) getRecords() error { - defer sc.releaseLease(sc.shard.ID) + ctx, cancelFunc := context.WithCancel(context.Background()) + defer func() { + // cancel renewLease() + cancelFunc() + sc.releaseLease(sc.shard.ID) + }() log := sc.kclConfig.Logger @@ -133,10 +138,11 @@ func (sc *PollingShardConsumer) getRecords() error { sc.callsLeft = kinesisReadTPSLimit sc.bytesRead = 0 sc.remBytes = MaxBytes + // starting async lease renewal thread leaseRenewalErrChan := make(chan error, 1) go func() { - leaseRenewalErrChan <- sc.renewLease() + leaseRenewalErrChan <- sc.renewLease(ctx) }() for { getRecordsStartTime := time.Now() @@ -300,18 +306,29 @@ func (sc *PollingShardConsumer) callGetRecordsAPI(gri *kinesis.GetRecordsInput) return getResp, 0, err } -func (sc *PollingShardConsumer) renewLease() error { +func (sc *PollingShardConsumer) renewLease(ctx context.Context) error { + renewDuration := time.Duration(sc.kclConfig.LeaseRefreshWaitTime) * time.Millisecond for { - time.Sleep(time.Duration(sc.kclConfig.LeaseRefreshWaitTime) * time.Millisecond) - log.Debugf("Refreshing lease on shard: %s for worker: %s", sc.shard.ID, sc.consumerID) - err := sc.checkpointer.GetLease(sc.shard, sc.consumerID) - if err != nil { - // log and return error - log.Errorf("Error in refreshing lease on shard: %s for worker: %s. Error: %+v", - sc.shard.ID, sc.consumerID, err) - return err + timer := time.NewTimer(renewDuration) + select { + case <-timer.C: + log.Debugf("Refreshing lease on shard: %s for worker: %s", sc.shard.ID, sc.consumerID) + err := sc.checkpointer.GetLease(sc.shard, sc.consumerID) + if err != nil { + // log and return error + log.Errorf("Error in refreshing lease on shard: %s for worker: %s. Error: %+v", + sc.shard.ID, sc.consumerID, err) + return err + } + // log metric for renewed lease for worker + sc.mService.LeaseRenewed(sc.shard.ID) + case <-ctx.Done(): + // clean up timer resources + if !timer.Stop() { + <-timer.C + } + log.Debugf("renewLease was canceled") + return nil } - // log metric for renewed lease for worker - sc.mService.LeaseRenewed(sc.shard.ID) } }