diff --git a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ConsumerStates.java b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ConsumerStates.java index c8678974..2d92d7d7 100644 --- a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ConsumerStates.java +++ b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ConsumerStates.java @@ -355,7 +355,7 @@ class ConsumerStates { @Override public ITask createTask(ShardConsumer consumer) { return new ShutdownNotificationTask(consumer.getRecordProcessor(), consumer.getRecordProcessorCheckpointer(), - consumer.getShutdownNotification()); + consumer.getShutdownNotification(), consumer.getShardInfo()); } @Override diff --git a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShutdownNotificationTask.java b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShutdownNotificationTask.java index 13711f23..a689ee43 100644 --- a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShutdownNotificationTask.java +++ b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShutdownNotificationTask.java @@ -12,11 +12,13 @@ class ShutdownNotificationTask implements ITask { private final IRecordProcessor recordProcessor; private final IRecordProcessorCheckpointer recordProcessorCheckpointer; private final ShutdownNotification shutdownNotification; + private final ShardInfo shardInfo; - ShutdownNotificationTask(IRecordProcessor recordProcessor, IRecordProcessorCheckpointer recordProcessorCheckpointer, ShutdownNotification shutdownNotification) { + ShutdownNotificationTask(IRecordProcessor recordProcessor, IRecordProcessorCheckpointer recordProcessorCheckpointer, ShutdownNotification shutdownNotification, ShardInfo shardInfo) { this.recordProcessor = recordProcessor; this.recordProcessorCheckpointer = recordProcessorCheckpointer; this.shutdownNotification = shutdownNotification; + 
this.shardInfo = shardInfo; } @Override diff --git a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/Worker.java b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/Worker.java index 2a1e5484..7ebfc3f3 100644 --- a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/Worker.java +++ b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/Worker.java @@ -538,15 +538,18 @@ public class Worker implements Runnable { */ public Future requestShutdown() { + // + // Stop accepting new leases. Once we do this we can be sure that + // no more leases will be acquired. + // leaseCoordinator.stopLeaseTaker(); - // - // Stop accepting new leases - // + Collection leases = leaseCoordinator.getAssignments(); if (leases == null || leases.isEmpty()) { // - // If there are no leases shutdown is already completed. + // If there are no leases notification is already completed, but we still need to shutdown the worker. // + this.shutdown(); return Futures.immediateFuture(null); } CountDownLatch shutdownCompleteLatch = new CountDownLatch(leases.size()); @@ -555,7 +558,18 @@ public class Worker implements Runnable { ShutdownNotification shutdownNotification = new ShardConsumerShutdownNotification(leaseCoordinator, lease, notificationCompleteLatch, shutdownCompleteLatch); ShardInfo shardInfo = KinesisClientLibLeaseCoordinator.convertLeaseToAssignment(lease); - shardInfoShardConsumerMap.get(shardInfo).notifyShutdownRequested(shutdownNotification); + ShardConsumer consumer = shardInfoShardConsumerMap.get(shardInfo); + if (consumer != null) { + consumer.notifyShutdownRequested(shutdownNotification); + } else { + // + // There is a race condition between retrieving the current assignments, and creating the + // notification. If a lease is lost in between these two points, we explicitly decrement the + // notification latches to clear the shutdown. 
+ // + notificationCompleteLatch.countDown(); + shutdownCompleteLatch.countDown(); + } } return new ShutdownFuture(shutdownCompleteLatch, notificationCompleteLatch, this); @@ -622,9 +636,11 @@ public class Worker implements Runnable { /** * Returns whether worker can shutdown immediately. Note that this method is called from Worker's {{@link #run()} * method before every loop run, so method must do minimum amount of work to not impact shard processing timings. + * * @return Whether worker should shutdown immediately. */ - private boolean shouldShutdown() { + @VisibleForTesting + boolean shouldShutdown() { if (executorService.isShutdown()) { LOG.error("Worker executor service has been shutdown, so record processors cannot be shutdown."); return true; diff --git a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/WorkerTest.java b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/WorkerTest.java index baafa447..daf58165 100644 --- a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/WorkerTest.java +++ b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/WorkerTest.java @@ -19,6 +19,7 @@ import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.isA; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.junit.Assert.fail; import static org.mockito.Matchers.any; import static org.mockito.Matchers.argThat; import static org.mockito.Matchers.eq; @@ -27,6 +28,7 @@ import static org.mockito.Mockito.*; import java.io.File; import java.lang.Thread.State; +import java.lang.reflect.Field; import java.math.BigInteger; import java.util.ArrayList; import java.util.Collections; @@ -47,6 +49,7 @@ import java.util.concurrent.Semaphore; import java.util.concurrent.ThreadFactory; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; +import 
java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicBoolean; import org.apache.commons.logging.Log; @@ -96,6 +99,8 @@ import com.amazonaws.services.kinesis.model.Shard; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ThreadFactoryBuilder; +import lombok.RequiredArgsConstructor; + /** * Unit tests of Worker. */ @@ -776,6 +781,276 @@ public class WorkerTest { } + @Test + public void testRequestShutdownNoLeases() throws Exception { + + + IRecordProcessorFactory recordProcessorFactory = mock(IRecordProcessorFactory.class); + StreamConfig streamConfig = mock(StreamConfig.class); + IMetricsFactory metricsFactory = mock(IMetricsFactory.class); + + + final List leases = new ArrayList<>(); + final List currentAssignments = new ArrayList<>(); + + when(leaseCoordinator.getAssignments()).thenAnswer(new Answer>() { + @Override + public List answer(InvocationOnMock invocation) throws Throwable { + return leases; + } + }); + when(leaseCoordinator.getCurrentAssignments()).thenAnswer(new Answer>() { + @Override + public List answer(InvocationOnMock invocation) throws Throwable { + return currentAssignments; + } + }); + + IRecordProcessor processor = mock(IRecordProcessor.class); + when(recordProcessorFactory.createProcessor()).thenReturn(processor); + + + Worker worker = new Worker("testRequestShutdown", recordProcessorFactory, streamConfig, + INITIAL_POSITION_TRIM_HORIZON, parentShardPollIntervalMillis, shardSyncIntervalMillis, + cleanupLeasesUponShardCompletion, leaseCoordinator, leaseCoordinator, executorService, metricsFactory, + taskBackoffTimeMillis, failoverTimeMillis, false, shardPrioritization); + + when(executorService.submit(Matchers.> any())) + .thenAnswer(new ShutdownHandlingAnswer(taskFuture)); + when(taskFuture.isDone()).thenReturn(true); + when(taskFuture.get()).thenReturn(taskResult); + + worker.runProcessLoop(); + + verify(executorService, 
never()).submit(argThat(both(isA(MetricsCollectingTaskDecorator.class)) + .and(TaskTypeMatcher.isOfType(TaskType.BLOCK_ON_PARENT_SHARDS)))); + + worker.runProcessLoop(); + + verify(executorService, never()).submit(argThat( + both(isA(MetricsCollectingTaskDecorator.class)).and(TaskTypeMatcher.isOfType(TaskType.INITIALIZE)))); + + worker.requestShutdown(); + worker.runProcessLoop(); + + verify(executorService, never()).submit(argThat(both(isA(MetricsCollectingTaskDecorator.class)) + .and(TaskTypeMatcher.isOfType(TaskType.SHUTDOWN_NOTIFICATION)))); + + worker.runProcessLoop(); + verify(executorService, never()).submit(argThat( + both(isA(MetricsCollectingTaskDecorator.class)).and(TaskTypeMatcher.isOfType(TaskType.SHUTDOWN)))); + + assertThat(worker.shouldShutdown(), equalTo(true)); + + } + + @Test + public void testRequestShutdownWithLostLease() throws Exception { + + IRecordProcessorFactory recordProcessorFactory = mock(IRecordProcessorFactory.class); + StreamConfig streamConfig = mock(StreamConfig.class); + IMetricsFactory metricsFactory = mock(IMetricsFactory.class); + + ExtendedSequenceNumber checkpoint = new ExtendedSequenceNumber("123", 0L); + KinesisClientLease lease1 = makeLease(checkpoint, 1); + KinesisClientLease lease2 = makeLease(checkpoint, 2); + final List leases = new ArrayList<>(); + final List currentAssignments = new ArrayList<>(); + leases.add(lease1); + leases.add(lease2); + + ShardInfo shardInfo1 = makeShardInfo(lease1); + currentAssignments.add(shardInfo1); + ShardInfo shardInfo2 = makeShardInfo(lease2); + currentAssignments.add(shardInfo2); + + when(leaseCoordinator.getAssignments()).thenAnswer(new Answer>() { + @Override + public List answer(InvocationOnMock invocation) throws Throwable { + return leases; + } + }); + when(leaseCoordinator.getCurrentAssignments()).thenAnswer(new Answer>() { + @Override + public List answer(InvocationOnMock invocation) throws Throwable { + return currentAssignments; + } + }); + + IRecordProcessor processor = 
mock(IRecordProcessor.class); + when(recordProcessorFactory.createProcessor()).thenReturn(processor); + + Worker worker = new Worker("testRequestShutdown", recordProcessorFactory, streamConfig, + INITIAL_POSITION_TRIM_HORIZON, parentShardPollIntervalMillis, shardSyncIntervalMillis, + cleanupLeasesUponShardCompletion, leaseCoordinator, leaseCoordinator, executorService, metricsFactory, + taskBackoffTimeMillis, failoverTimeMillis, false, shardPrioritization); + + when(executorService.submit(Matchers.> any())) + .thenAnswer(new ShutdownHandlingAnswer(taskFuture)); + when(taskFuture.isDone()).thenReturn(true); + when(taskFuture.get()).thenReturn(taskResult); + + worker.runProcessLoop(); + + verify(executorService).submit(argThat(both(isA(MetricsCollectingTaskDecorator.class)) + .and(TaskTypeMatcher.isOfType(TaskType.BLOCK_ON_PARENT_SHARDS)).and(ReflectionFieldMatcher + .withField(BlockOnParentShardTask.class, "shardInfo", equalTo(shardInfo1))))); + verify(executorService).submit(argThat(both(isA(MetricsCollectingTaskDecorator.class)) + .and(TaskTypeMatcher.isOfType(TaskType.BLOCK_ON_PARENT_SHARDS)).and(ReflectionFieldMatcher + .withField(BlockOnParentShardTask.class, "shardInfo", equalTo(shardInfo2))))); + + worker.runProcessLoop(); + + verify(executorService).submit(argThat(both(isA(MetricsCollectingTaskDecorator.class)) + .and(TaskTypeMatcher.isOfType(TaskType.INITIALIZE)).and(ReflectionFieldMatcher + .withField(InitializeTask.class, "shardInfo", equalTo(shardInfo1))))); + verify(executorService).submit(argThat(both(isA(MetricsCollectingTaskDecorator.class)) + .and(TaskTypeMatcher.isOfType(TaskType.INITIALIZE)).and(ReflectionFieldMatcher + .withField(InitializeTask.class, "shardInfo", equalTo(shardInfo2))))); + + worker.getShardInfoShardConsumerMap().remove(shardInfo2); + worker.requestShutdown(); + leases.remove(1); + currentAssignments.remove(1); + worker.runProcessLoop(); + + + verify(executorService).submit(argThat(both(isA(MetricsCollectingTaskDecorator.class)) 
+ .and(TaskTypeMatcher.isOfType(TaskType.SHUTDOWN_NOTIFICATION)).and(ReflectionFieldMatcher + .withField(ShutdownNotificationTask.class, "shardInfo", equalTo(shardInfo1))))); + + verify(executorService, never()).submit(argThat(both(isA(MetricsCollectingTaskDecorator.class)) + .and(TaskTypeMatcher.isOfType(TaskType.SHUTDOWN_NOTIFICATION)).and(ReflectionFieldMatcher + .withField(ShutdownNotificationTask.class, "shardInfo", equalTo(shardInfo2))))); + + worker.runProcessLoop(); + + verify(leaseCoordinator).dropLease(eq(lease1)); + verify(leaseCoordinator, never()).dropLease(eq(lease2)); + leases.clear(); + currentAssignments.clear(); + + worker.runProcessLoop(); + + verify(executorService, atLeastOnce()).submit(argThat( + both(isA(MetricsCollectingTaskDecorator.class)).and(TaskTypeMatcher.isOfType(TaskType.SHUTDOWN)))); + + verify(executorService).submit(argThat(both(isA(MetricsCollectingTaskDecorator.class)) + .and(TaskTypeMatcher.isOfType(TaskType.SHUTDOWN)).and(ReflectionFieldMatcher + .withField(ShutdownTask.class, "shardInfo", equalTo(shardInfo1))))); + + verify(executorService, never()).submit(argThat(both(isA(MetricsCollectingTaskDecorator.class)) + .and(TaskTypeMatcher.isOfType(TaskType.SHUTDOWN)).and(ReflectionFieldMatcher + .withField(ShutdownTask.class, "shardInfo", equalTo(shardInfo2))))); + + } + + @Test + public void testRequestShutdownWithAllLeasesLost() throws Exception { + + IRecordProcessorFactory recordProcessorFactory = mock(IRecordProcessorFactory.class); + StreamConfig streamConfig = mock(StreamConfig.class); + IMetricsFactory metricsFactory = mock(IMetricsFactory.class); + + ExtendedSequenceNumber checkpoint = new ExtendedSequenceNumber("123", 0L); + KinesisClientLease lease1 = makeLease(checkpoint, 1); + KinesisClientLease lease2 = makeLease(checkpoint, 2); + final List leases = new ArrayList<>(); + final List currentAssignments = new ArrayList<>(); + leases.add(lease1); + leases.add(lease2); + + ShardInfo shardInfo1 = makeShardInfo(lease1); + 
currentAssignments.add(shardInfo1); + ShardInfo shardInfo2 = makeShardInfo(lease2); + currentAssignments.add(shardInfo2); + + when(leaseCoordinator.getAssignments()).thenAnswer(new Answer>() { + @Override + public List answer(InvocationOnMock invocation) throws Throwable { + return leases; + } + }); + when(leaseCoordinator.getCurrentAssignments()).thenAnswer(new Answer>() { + @Override + public List answer(InvocationOnMock invocation) throws Throwable { + return currentAssignments; + } + }); + + IRecordProcessor processor = mock(IRecordProcessor.class); + when(recordProcessorFactory.createProcessor()).thenReturn(processor); + + Worker worker = new Worker("testRequestShutdown", recordProcessorFactory, streamConfig, + INITIAL_POSITION_TRIM_HORIZON, parentShardPollIntervalMillis, shardSyncIntervalMillis, + cleanupLeasesUponShardCompletion, leaseCoordinator, leaseCoordinator, executorService, metricsFactory, + taskBackoffTimeMillis, failoverTimeMillis, false, shardPrioritization); + + when(executorService.submit(Matchers.> any())) + .thenAnswer(new ShutdownHandlingAnswer(taskFuture)); + when(taskFuture.isDone()).thenReturn(true); + when(taskFuture.get()).thenReturn(taskResult); + + worker.runProcessLoop(); + + verify(executorService).submit(argThat(both(isA(MetricsCollectingTaskDecorator.class)) + .and(TaskTypeMatcher.isOfType(TaskType.BLOCK_ON_PARENT_SHARDS)).and(ReflectionFieldMatcher + .withField(BlockOnParentShardTask.class, "shardInfo", equalTo(shardInfo1))))); + verify(executorService).submit(argThat(both(isA(MetricsCollectingTaskDecorator.class)) + .and(TaskTypeMatcher.isOfType(TaskType.BLOCK_ON_PARENT_SHARDS)).and(ReflectionFieldMatcher + .withField(BlockOnParentShardTask.class, "shardInfo", equalTo(shardInfo2))))); + + worker.runProcessLoop(); + + verify(executorService).submit(argThat(both(isA(MetricsCollectingTaskDecorator.class)) + .and(TaskTypeMatcher.isOfType(TaskType.INITIALIZE)).and(ReflectionFieldMatcher + .withField(InitializeTask.class, "shardInfo", 
equalTo(shardInfo1))))); + verify(executorService).submit(argThat(both(isA(MetricsCollectingTaskDecorator.class)) + .and(TaskTypeMatcher.isOfType(TaskType.INITIALIZE)).and(ReflectionFieldMatcher + .withField(InitializeTask.class, "shardInfo", equalTo(shardInfo2))))); + + worker.getShardInfoShardConsumerMap().clear(); + Future future = worker.requestShutdown(); + + leases.clear(); + currentAssignments.clear(); + + try { + future.get(1, TimeUnit.HOURS); + } catch (TimeoutException te) { + fail("Future from requestShutdown should immediately return."); + } + + worker.runProcessLoop(); + verify(executorService, never()).submit(argThat(both(isA(MetricsCollectingTaskDecorator.class)) + .and(TaskTypeMatcher.isOfType(TaskType.SHUTDOWN_NOTIFICATION)).and(ReflectionFieldMatcher + .withField(ShutdownNotificationTask.class, "shardInfo", equalTo(shardInfo1))))); + + verify(executorService, never()).submit(argThat(both(isA(MetricsCollectingTaskDecorator.class)) + .and(TaskTypeMatcher.isOfType(TaskType.SHUTDOWN_NOTIFICATION)).and(ReflectionFieldMatcher + .withField(ShutdownNotificationTask.class, "shardInfo", equalTo(shardInfo2))))); + + worker.runProcessLoop(); + + verify(leaseCoordinator, never()).dropLease(eq(lease1)); + verify(leaseCoordinator, never()).dropLease(eq(lease2)); + + worker.runProcessLoop(); + + verify(executorService, never()).submit(argThat(both(isA(MetricsCollectingTaskDecorator.class)) + .and(TaskTypeMatcher.isOfType(TaskType.SHUTDOWN)).and(ReflectionFieldMatcher + .withField(ShutdownTask.class, "shardInfo", equalTo(shardInfo1))))); + + verify(executorService, never()).submit(argThat(both(isA(MetricsCollectingTaskDecorator.class)) + .and(TaskTypeMatcher.isOfType(TaskType.SHUTDOWN)).and(ReflectionFieldMatcher + .withField(ShutdownTask.class, "shardInfo", equalTo(shardInfo2))))); + + + + assertThat(worker.shouldShutdown(), equalTo(true)); + + } + @Test public void testLeaseCancelledAfterShutdownRequest() throws Exception { @@ -919,6 +1194,17 @@ public class 
WorkerTest { } + private KinesisClientLease makeLease(ExtendedSequenceNumber checkpoint, int shardId) { + return new KinesisClientLeaseBuilder().withCheckpoint(checkpoint).withConcurrencyToken(UUID.randomUUID()) + .withLastCounterIncrementNanos(0L).withLeaseCounter(0L).withOwnerSwitchesSinceCheckpoint(0L) + .withLeaseOwner("Self").withLeaseKey(String.format("shardId-%03d", shardId)).build(); + } + + private ShardInfo makeShardInfo(KinesisClientLease lease) { + return new ShardInfo(lease.getLeaseKey(), lease.getConcurrencyToken().toString(), lease.getParentShardIds(), + lease.getCheckpoint()); + } + private static class ShutdownReasonMatcher extends TypeSafeDiagnosingMatcher { private final Matcher matcher; @@ -1012,9 +1298,9 @@ public class WorkerTest { private static class InnerTaskMatcher extends TypeSafeMatcher { - final Matcher matcher; + final Matcher matcher; - InnerTaskMatcher(Matcher matcher) { + InnerTaskMatcher(Matcher matcher) { this.matcher = matcher; } @@ -1028,10 +1314,60 @@ public class WorkerTest { matcher.describeTo(description); } - static InnerTaskMatcher taskWith(Class clazz, Matcher matcher) { + static InnerTaskMatcher taskWith(Class clazz, Matcher matcher) { return new InnerTaskMatcher<>(matcher); } } + + @RequiredArgsConstructor + private static class ReflectionFieldMatcher + extends TypeSafeDiagnosingMatcher { + + private final Class itemClass; + private final String fieldName; + private final Matcher fieldMatcher; + + @Override + protected boolean matchesSafely(MetricsCollectingTaskDecorator item, Description mismatchDescription) { + if (item.getOther() == null) { + mismatchDescription.appendText("inner task is null"); + return false; + } + ITask inner = item.getOther(); + if (!itemClass.equals(inner.getClass())) { + mismatchDescription.appendText("inner task isn't an instance of ").appendText(itemClass.getName()); + return false; + } + try { + Field field = itemClass.getDeclaredField(fieldName); + field.setAccessible(true); + if 
(!fieldMatcher.matches(field.get(inner))) { + mismatchDescription.appendText("Field '").appendText(fieldName).appendText("' doesn't match: ") + .appendDescriptionOf(fieldMatcher); + return false; + } + return true; + } catch (NoSuchFieldException e) { + mismatchDescription.appendText(itemClass.getName()).appendText(" doesn't have a field named ") + .appendText(fieldName); + } catch (IllegalAccessException e) { + throw new RuntimeException(e); + } + + return false; + } + + @Override + public void describeTo(Description description) { + description.appendText("An item of ").appendText(itemClass.getName()).appendText(" with the field '") + .appendText(fieldName).appendText("' matching ").appendDescriptionOf(fieldMatcher); + } + + static ReflectionFieldMatcher withField(Class itemClass, String fieldName, + Matcher fieldMatcher) { + return new ReflectionFieldMatcher<>(itemClass, fieldName, fieldMatcher); + } + } /** * Returns executor service that will be owned by the worker. This is useful to test the scenario * where worker shuts down the executor service also during shutdown flow.