diff --git a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShardConsumerTest.java b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShardConsumerTest.java index ff7aef75..d0cf98db 100644 --- a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShardConsumerTest.java +++ b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShardConsumerTest.java @@ -25,6 +25,7 @@ import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyInt; import static org.mockito.Matchers.anyString; import static org.mockito.Matchers.argThat; +import static org.mockito.Matchers.eq; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.doThrow; @@ -37,11 +38,13 @@ import static org.mockito.Mockito.when; import java.io.File; import java.math.BigInteger; import java.util.ArrayList; +import java.util.Arrays; import java.util.Date; import java.util.List; import java.util.ListIterator; import java.util.Objects; import java.util.Optional; +import java.util.UUID; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; @@ -439,6 +442,69 @@ public class ShardConsumerTest { file.delete(); } + private String randomShardId() { + return UUID.randomUUID().toString(); + } + + /** + * Test that a consumer can be shut down while it is blocked on parent + */ + @Test + public final void testShardConsumerShutdownWhenBlockedOnParent() throws Exception { + final StreamConfig streamConfig = mock(StreamConfig.class); + final RecordProcessorCheckpointer recordProcessorCheckpointer = mock(RecordProcessorCheckpointer.class); + final GetRecordsCache getRecordsCache = mock(GetRecordsCache.class); + final KinesisDataFetcher dataFetcher = mock(KinesisDataFetcher.class); + when(recordsFetcherFactory.createRecordsFetcher(any(GetRecordsRetrievalStrategy.class), anyString(), + any(IMetricsFactory.class), anyInt())).thenReturn(getRecordsCache); + final String shardId = randomShardId(); + final String parentShardId = randomShardId(); + final ShardInfo shardInfo = mock(ShardInfo.class); + final KinesisClientLease parentLease = mock(KinesisClientLease.class); + when(shardInfo.getShardId()).thenReturn(shardId); + when(shardInfo.getParentShardIds()).thenReturn(Arrays.asList(parentShardId)); + when(leaseManager.getLease(eq(parentShardId))).thenReturn(parentLease); + when(parentLease.getCheckpoint()).thenReturn(ExtendedSequenceNumber.TRIM_HORIZON); + + final ShardConsumer consumer = new ShardConsumer(shardInfo, + streamConfig, + checkpoint, + processor, + recordProcessorCheckpointer, + leaseManager, + parentShardPollIntervalMillis, + cleanupLeasesOfCompletedShards, + executorService, + metricsFactory, + taskBackoffTimeMillis, + KinesisClientLibConfiguration.DEFAULT_SKIP_SHARD_SYNC_AT_STARTUP_IF_LEASES_EXIST, + dataFetcher, + Optional.empty(), + Optional.empty(), + config, + shardSyncer, + shardSyncStrategy); + + assertThat(consumer.getCurrentState(), is(equalTo(ConsumerStates.ShardConsumerState.WAITING_ON_PARENT_SHARDS))); + verify(parentLease, times(0)).getCheckpoint(); + consumer.consumeShard(); // check on parent shards + Thread.sleep(parentShardPollIntervalMillis * 2); + assertThat(consumer.getCurrentState(), is(equalTo(ConsumerStates.ShardConsumerState.WAITING_ON_PARENT_SHARDS))); + verify(parentLease, times(1)).getCheckpoint(); + consumer.notifyShutdownRequested(shutdownNotification); + verify(shutdownNotification, times(0)).shutdownComplete(); + assertThat(consumer.getShutdownReason(), equalTo(ShutdownReason.REQUESTED)); + assertThat(consumer.getCurrentState(), is(equalTo(ConsumerStates.ShardConsumerState.WAITING_ON_PARENT_SHARDS))); + consumer.consumeShard(); + assertThat(consumer.getCurrentState(), is(equalTo(ConsumerStates.ShardConsumerState.SHUTDOWN_COMPLETE))); + assertThat(consumer.isShutdown(), is(true)); + verify(shutdownNotification, times(1)).shutdownComplete(); + consumer.beginShutdown(); + assertThat(consumer.getShutdownReason(), equalTo(ShutdownReason.ZOMBIE)); + assertThat(consumer.isShutdown(), is(true)); + } + + private static final class TransientShutdownErrorTestStreamlet extends TestStreamlet { private final CountDownLatch errorShutdownLatch = new CountDownLatch(1);