diff --git a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShardConsumer.java b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShardConsumer.java index 394f9486..a30412ce 100644 --- a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShardConsumer.java +++ b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShardConsumer.java @@ -423,6 +423,8 @@ class ShardConsumer { } if (isShutdownRequested() && taskOutcome != TaskOutcome.FAILURE) { currentState = currentState.shutdownTransition(shutdownReason); + } else if (isShutdownRequested() && ConsumerStates.ShardConsumerState.WAITING_ON_PARENT_SHARDS.equals(currentState.getState())) { + currentState = currentState.shutdownTransition(shutdownReason); } else if (taskOutcome == TaskOutcome.SUCCESSFUL) { if (currentState.getTaskType() == currentTask.getTaskType()) { currentState = currentState.successTransition(); diff --git a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/Worker.java b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/Worker.java index 1a7fa58f..185113fc 100644 --- a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/Worker.java +++ b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/Worker.java @@ -903,16 +903,19 @@ public class Worker implements Runnable { lease, notificationCompleteLatch, shutdownCompleteLatch); ShardInfo shardInfo = KinesisClientLibLeaseCoordinator.convertLeaseToAssignment(lease); ShardConsumer consumer = shardInfoShardConsumerMap.get(shardInfo); - if (consumer != null) { - consumer.notifyShutdownRequested(shutdownNotification); - } else { + + if (consumer == null || ConsumerStates.ShardConsumerState.SHUTDOWN_COMPLETE.equals(consumer.getCurrentState())) { // - // There is a race condition between retrieving the current assignments, and creating the + // CASE1: There is a race condition between retrieving the current assignments, and creating the // notification. If the a lease is lost in between these two points, we explicitly decrement the // notification latches to clear the shutdown. // + // CASE2: The shard consumer is in SHUTDOWN_COMPLETE state and will not decrement the latches by itself. + // notificationCompleteLatch.countDown(); shutdownCompleteLatch.countDown(); + } else { + consumer.notifyShutdownRequested(shutdownNotification); } } return new GracefulShutdownContext(shutdownCompleteLatch, notificationCompleteLatch, this); diff --git a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShardConsumerTest.java b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShardConsumerTest.java index f50069d3..1cf86c4f 100644 --- a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShardConsumerTest.java +++ b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShardConsumerTest.java @@ -25,6 +25,7 @@ import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyInt; import static org.mockito.Matchers.anyString; import static org.mockito.Matchers.argThat; +import static org.mockito.Matchers.eq; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.doThrow; @@ -37,11 +38,13 @@ import static org.mockito.Mockito.when; import java.io.File; import java.math.BigInteger; import java.util.ArrayList; +import java.util.Arrays; import java.util.Date; import java.util.List; import java.util.ListIterator; import java.util.Objects; import java.util.Optional; +import java.util.UUID; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; @@ -446,6 +449,71 @@ public class ShardConsumerTest { file.delete(); } + private String randomShardId() { + return UUID.randomUUID().toString(); + } + + /** + * Test that a consumer can be shut down while it is blocked on parent + */ + @Test + public final void testShardConsumerShutdownWhenBlockedOnParent() throws Exception { + final StreamConfig streamConfig = mock(StreamConfig.class); + final RecordProcessorCheckpointer recordProcessorCheckpointer = mock(RecordProcessorCheckpointer.class); + final GetRecordsCache getRecordsCache = mock(GetRecordsCache.class); + final KinesisDataFetcher dataFetcher = mock(KinesisDataFetcher.class); + when(recordsFetcherFactory.createRecordsFetcher(any(GetRecordsRetrievalStrategy.class), anyString(), + any(IMetricsFactory.class), anyInt())).thenReturn(getRecordsCache); + final String shardId = randomShardId(); + final String parentShardId = randomShardId(); + final ShardInfo shardInfo = mock(ShardInfo.class); + final KinesisClientLease parentLease = mock(KinesisClientLease.class); + when(shardInfo.getShardId()).thenReturn(shardId); + when(shardInfo.getParentShardIds()).thenReturn(Arrays.asList(parentShardId)); + when(leaseCoordinator.getLeaseManager()).thenReturn(leaseManager); + when(leaseManager.getLease(eq(parentShardId))).thenReturn(parentLease); + when(parentLease.getCheckpoint()).thenReturn(ExtendedSequenceNumber.TRIM_HORIZON); + + final ShardConsumer consumer = + new ShardConsumer(shardInfo, + streamConfig, + checkpoint, + processor, + recordProcessorCheckpointer, + leaseCoordinator, + parentShardPollIntervalMillis, + cleanupLeasesOfCompletedShards, + executorService, + metricsFactory, + taskBackoffTimeMillis, + KinesisClientLibConfiguration.DEFAULT_SKIP_SHARD_SYNC_AT_STARTUP_IF_LEASES_EXIST, + dataFetcher, + Optional.empty(), + Optional.empty(), + config, + shardSyncer, + shardSyncStrategy); + + assertThat(consumer.getCurrentState(), is(equalTo(ConsumerStates.ShardConsumerState.WAITING_ON_PARENT_SHARDS))); + verify(parentLease, times(0)).getCheckpoint(); + consumer.consumeShard(); // check on parent shards + Thread.sleep(parentShardPollIntervalMillis * 2); + assertThat(consumer.getCurrentState(), is(equalTo(ConsumerStates.ShardConsumerState.WAITING_ON_PARENT_SHARDS))); + verify(parentLease, times(1)).getCheckpoint(); + consumer.notifyShutdownRequested(shutdownNotification); + verify(shutdownNotification, times(0)).shutdownComplete(); + assertThat(consumer.getShutdownReason(), equalTo(ShutdownReason.REQUESTED)); + assertThat(consumer.getCurrentState(), is(equalTo(ConsumerStates.ShardConsumerState.WAITING_ON_PARENT_SHARDS))); + consumer.consumeShard(); + assertThat(consumer.getCurrentState(), is(equalTo(ConsumerStates.ShardConsumerState.SHUTDOWN_COMPLETE))); + assertThat(consumer.isShutdown(), is(true)); + verify(shutdownNotification, times(1)).shutdownComplete(); + consumer.beginShutdown(); + assertThat(consumer.getShutdownReason(), equalTo(ShutdownReason.ZOMBIE)); + assertThat(consumer.isShutdown(), is(true)); + } + + private static final class TransientShutdownErrorTestStreamlet extends TestStreamlet { private final CountDownLatch errorShutdownLatch = new CountDownLatch(1); diff --git a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/WorkerTest.java b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/WorkerTest.java index 23c91269..438215f1 100644 --- a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/WorkerTest.java +++ b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/WorkerTest.java @@ -21,7 +21,6 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; -import static org.junit.Assert.assertEquals; import static org.junit.Assert.fail; import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyInt; @@ -53,7 +52,6 @@ import java.util.HashSet; import java.util.List; import java.util.ListIterator; import java.util.Map; -import java.util.Optional; import java.util.Set; import java.util.UUID; import java.util.concurrent.Callable; @@ -1100,6 +1098,63 @@ public class WorkerTest { } + private String randomShardId() { + return UUID.randomUUID().toString(); + } + + @Test + public void testShutdownDoesNotBlockOnCompletedLeases() throws Exception { + final String shardId = randomShardId(); + final String parentShardId = randomShardId(); + final KinesisClientLease completedLease = mock(KinesisClientLease.class); + when(completedLease.getLeaseKey()).thenReturn(shardId); + when(completedLease.getParentShardIds()).thenReturn(Collections.singleton(parentShardId)); + when(completedLease.getCheckpoint()).thenReturn(ExtendedSequenceNumber.SHARD_END); + when(completedLease.getConcurrencyToken()).thenReturn(UUID.randomUUID()); + final StreamConfig streamConfig = mock(StreamConfig.class); + final IMetricsFactory metricsFactory = mock(IMetricsFactory.class); + final List leases = Collections.singletonList(completedLease); + final List currentAssignments = new ArrayList<>(); + + when(leaseCoordinator.getAssignments()).thenAnswer((Answer>) invocation -> leases); + when(leaseCoordinator.getCurrentAssignments()).thenAnswer((Answer>) invocation -> currentAssignments); + + final IRecordProcessorFactory recordProcessorFactory = mock(IRecordProcessorFactory.class); + final IRecordProcessor processor = mock(IRecordProcessor.class); + when(recordProcessorFactory.createProcessor()).thenReturn(processor); + + Worker worker = new Worker("testShutdownWithCompletedLeases", + recordProcessorFactory, + config, + streamConfig, + INITIAL_POSITION_TRIM_HORIZON, + parentShardPollIntervalMillis, + shardSyncIntervalMillis, + cleanupLeasesUponShardCompletion, + leaseCoordinator, + leaseCoordinator, + executorService, + metricsFactory, + taskBackoffTimeMillis, + failoverTimeMillis, + false, + shardPrioritization); + + final Map shardInfoShardConsumerMap = worker.getShardInfoShardConsumerMap(); + final ShardInfo completedShardInfo = KinesisClientLibLeaseCoordinator.convertLeaseToAssignment(completedLease); + final ShardConsumer completedShardConsumer = mock(ShardConsumer.class); + shardInfoShardConsumerMap.put(completedShardInfo, completedShardConsumer); + when(completedShardConsumer.getCurrentState()).thenReturn(ConsumerStates.ShardConsumerState.SHUTDOWN_COMPLETE); + + Callable callable = worker.createWorkerShutdownCallable(); + assertThat(worker.hasGracefulShutdownStarted(), equalTo(false)); + + GracefulShutdownContext gracefulShutdownContext = callable.call(); + assertThat(gracefulShutdownContext.getShutdownCompleteLatch().getCount(), equalTo(0L)); + assertThat(gracefulShutdownContext.getNotificationCompleteLatch().getCount(), equalTo(0L)); + assertThat(worker.hasGracefulShutdownStarted(), equalTo(true)); + } + @Test public void testRequestShutdownWithLostLease() throws Exception {