diff --git a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/RequestedShutdownCoordinator.java b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/RequestedShutdownCoordinator.java index 76139149..5770c6ee 100644 --- a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/RequestedShutdownCoordinator.java +++ b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/RequestedShutdownCoordinator.java @@ -1,47 +1,63 @@ package com.amazonaws.services.kinesis.clientlibrary.lib.worker; +import java.util.concurrent.Callable; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Future; +import java.util.concurrent.FutureTask; +import java.util.concurrent.TimeUnit; + import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import java.util.concurrent.Callable; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; +class RequestedShutdownCoordinator { -public class RequestedShutdownCoordinator { + static Future startRequestedShutdown(Callable shutdownCallable) { + FutureTask task = new FutureTask<>(shutdownCallable); + Thread shutdownThread = new Thread(task, "RequestedShutdownThread"); + shutdownThread.start(); + return task; - private final ExecutorService executorService; - - RequestedShutdownCoordinator(ExecutorService executorService) { - this.executorService = executorService; } - static class RequestedShutdownCallable implements Callable { + static Callable createRequestedShutdownCallable(CountDownLatch shutdownCompleteLatch, + CountDownLatch notificationCompleteLatch, Worker worker) { + return new RequestedShutdownCallable(shutdownCompleteLatch, notificationCompleteLatch, worker); + } + + static class RequestedShutdownCallable implements Callable { private static final Log log = LogFactory.getLog(RequestedShutdownCallable.class); private final CountDownLatch shutdownCompleteLatch; private final CountDownLatch notificationCompleteLatch; private final Worker worker; - private final ExecutorService shutdownExecutor; - RequestedShutdownCallable(CountDownLatch shutdownCompleteLatch, CountDownLatch notificationCompleteLatch, Worker worker, ExecutorService shutdownExecutor) { + RequestedShutdownCallable(CountDownLatch shutdownCompleteLatch, CountDownLatch notificationCompleteLatch, + Worker worker) { this.shutdownCompleteLatch = shutdownCompleteLatch; this.notificationCompleteLatch = notificationCompleteLatch; this.worker = worker; - this.shutdownExecutor = shutdownExecutor; } private boolean isWorkerShutdownComplete() { return worker.isShutdownComplete() || worker.getShardInfoShardConsumerMap().isEmpty(); } - private long outstandingRecordProcessors(long timeout, TimeUnit unit) - throws InterruptedException, ExecutionException, TimeoutException { + private String awaitingLogMessage() { + long awaitingNotification = notificationCompleteLatch.getCount(); + long awaitingFinalShutdown = shutdownCompleteLatch.getCount(); - final long startNanos = System.nanoTime(); + return String.format( + "Waiting for %d record process to complete shutdown notification, and %d record processor to complete final shutdown ", + awaitingNotification, awaitingFinalShutdown); + } + + private String awaitingFinalShutdownMessage() { + long outstanding = shutdownCompleteLatch.getCount(); + return String.format("Waiting for %d record processors to complete final shutdown", outstanding); + } + + private boolean waitForRecordProcessors() { // // Awaiting for all ShardConsumer/RecordProcessors to be notified that a shutdown has been requested. @@ -49,54 +65,59 @@ public class RequestedShutdownCoordinator { // notification is started, but before the ShardConsumer is sent the notification. In this case the // ShardConsumer would start the lease loss shutdown, and may never call the notification methods. // - if (!notificationCompleteLatch.await(timeout, unit)) { - long awaitingNotification = notificationCompleteLatch.getCount(); - long awaitingFinalShutdown = shutdownCompleteLatch.getCount(); - log.info("Awaiting " + awaitingNotification + " record processors to complete shutdown notification, and " - + awaitingFinalShutdown + " awaiting final shutdown"); - if (awaitingFinalShutdown != 0) { - // - // The number of record processor awaiting final shutdown should be a superset of the those awaiting - // notification - // - return checkWorkerShutdownMiss(awaitingFinalShutdown); + try { + while (!notificationCompleteLatch.await(1, TimeUnit.SECONDS)) { + if (Thread.interrupted()) { + throw new InterruptedException(); + } + log.info(awaitingLogMessage()); + if (workerShutdownWithRemaining(shutdownCompleteLatch.getCount())) { + return false; + } } + } catch (InterruptedException ie) { + log.warn("Interrupted while waiting for notification complete, terminating shutdown. " + + awaitingLogMessage()); + return false; } - long remaining = remainingTimeout(timeout, unit, startNanos); - throwTimeoutMessageIfExceeded(remaining, "Notification hasn't completed within timeout time."); + if (Thread.interrupted()) { + log.warn("Interrupted before worker shutdown, terminating shutdown"); + return false; + } // // Once all record processors have been notified of the shutdown it is safe to allow the worker to // start its shutdown behavior. Once shutdown starts it will stop renewer, and drop any remaining leases. // worker.shutdown(); - remaining = remainingTimeout(timeout, unit, startNanos); - throwTimeoutMessageIfExceeded(remaining, "Shutdown hasn't completed within timeout time."); + + if (Thread.interrupted()) { + log.warn("Interrupted after worker shutdown, terminating shutdown"); + return false; + } // // Want to wait for all the remaining ShardConsumers/RecordProcessor's to complete their final shutdown // processing. This should really be a no-op since as part of the notification completion the lease for // ShardConsumer is terminated. // - if (!shutdownCompleteLatch.await(remaining, TimeUnit.NANOSECONDS)) { - long outstanding = shutdownCompleteLatch.getCount(); - log.info("Awaiting " + outstanding + " record processors to complete final shutdown"); - - return checkWorkerShutdownMiss(outstanding); - } - return 0; - } - - private long remainingTimeout(long timeout, TimeUnit unit, long startNanos) { - long checkNanos = System.nanoTime() - startNanos; - return unit.toNanos(timeout) - checkNanos; - } - - private void throwTimeoutMessageIfExceeded(long remainingNanos, String message) throws TimeoutException { - if (remainingNanos <= 0) { - throw new TimeoutException(message); + try { + while (!shutdownCompleteLatch.await(1, TimeUnit.SECONDS)) { + if (Thread.interrupted()) { + throw new InterruptedException(); + } + log.info(awaitingFinalShutdownMessage()); + if (workerShutdownWithRemaining(shutdownCompleteLatch.getCount())) { + return false; + } + } + } catch (InterruptedException ie) { + log.warn("Interrupted while waiting for shutdown completion, terminating shutdown. " + + awaitingFinalShutdownMessage()); + return false; } + return true; } /** @@ -106,24 +127,23 @@ public class RequestedShutdownCoordinator { * * @param outstanding * the number of record processor still awaiting shutdown. - * @return the number of record processors awaiting shutdown, or 0 if the worker believes it's shutdown already. */ - private long checkWorkerShutdownMiss(long outstanding) { + private boolean workerShutdownWithRemaining(long outstanding) { if (isWorkerShutdownComplete()) { if (outstanding != 0) { log.info("Shutdown completed, but shutdownCompleteLatch still had outstanding " + outstanding + " with a current value of " + shutdownCompleteLatch.getCount() + ". shutdownComplete: " + worker.isShutdownComplete() + " -- Consumer Map: " + worker.getShardInfoShardConsumerMap().size()); + return true; } - return 0; } - return outstanding; + return false; } @Override - public Void call() throws Exception { - return null; + public Boolean call() throws Exception { + return waitForRecordProcessors(); } } } diff --git a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShutdownFuture.java b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShutdownFuture.java deleted file mode 100644 index 8e530df5..00000000 --- a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShutdownFuture.java +++ /dev/null @@ -1,157 +0,0 @@ -package com.amazonaws.services.kinesis.clientlibrary.lib.worker; - -import java.util.concurrent.Callable; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.concurrent.Future; -import java.util.concurrent.ThreadFactory; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; - -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; - -import com.google.common.util.concurrent.ThreadFactoryBuilder; - -/** - * Used as a response from the {@link Worker#requestShutdown()} to allow callers to wait until shutdown is complete. - */ -class ShutdownFuture { - - private static final Log log = LogFactory.getLog(ShutdownFuture.class); - - private final CountDownLatch shutdownCompleteLatch; - private final CountDownLatch notificationCompleteLatch; - private final Worker worker; - private final ExecutorService shutdownExecutor; - - ShutdownFuture(CountDownLatch shutdownCompleteLatch, CountDownLatch notificationCompleteLatch, Worker worker) { - this(shutdownCompleteLatch, notificationCompleteLatch, worker, makeExecutor()); - } - - ShutdownFuture(CountDownLatch shutdownCompleteLatch, CountDownLatch notificationCompleteLatch, Worker worker, - ExecutorService shutdownExecutor) { - this.shutdownCompleteLatch = shutdownCompleteLatch; - this.notificationCompleteLatch = notificationCompleteLatch; - this.worker = worker; - this.shutdownExecutor = shutdownExecutor; - } - - private static ExecutorService makeExecutor() { - ThreadFactory threadFactory = new ThreadFactoryBuilder().setDaemon(true).setNameFormat("RequestShutdown-%04d") - .build(); - return Executors.newSingleThreadExecutor(threadFactory); - } - - private boolean isWorkerShutdownComplete() { - return worker.isShutdownComplete() || worker.getShardInfoShardConsumerMap().isEmpty(); - } - - private long outstandingRecordProcessors(long timeout, TimeUnit unit) - throws InterruptedException, ExecutionException, TimeoutException { - - final long startNanos = System.nanoTime(); - - // - // Awaiting for all ShardConsumer/RecordProcessors to be notified that a shutdown has been requested. - // There is the possibility of a race condition where a lease is terminated after the shutdown request - // notification is started, but before the ShardConsumer is sent the notification. In this case the - // ShardConsumer would start the lease loss shutdown, and may never call the notification methods. - // - if (!notificationCompleteLatch.await(timeout, unit)) { - long awaitingNotification = notificationCompleteLatch.getCount(); - long awaitingFinalShutdown = shutdownCompleteLatch.getCount(); - log.info("Awaiting " + awaitingNotification + " record processors to complete shutdown notification, and " - + awaitingFinalShutdown + " awaiting final shutdown"); - if (awaitingFinalShutdown != 0) { - // - // The number of record processor awaiting final shutdown should be a superset of the those awaiting - // notification - // - return checkWorkerShutdownMiss(awaitingFinalShutdown); - } - } - - long remaining = remainingTimeout(timeout, unit, startNanos); - throwTimeoutMessageIfExceeded(remaining, "Notification hasn't completed within timeout time."); - - // - // Once all record processors have been notified of the shutdown it is safe to allow the worker to - // start its shutdown behavior. Once shutdown starts it will stop renewer, and drop any remaining leases. - // - worker.shutdown(); - remaining = remainingTimeout(timeout, unit, startNanos); - throwTimeoutMessageIfExceeded(remaining, "Shutdown hasn't completed within timeout time."); - - // - // Want to wait for all the remaining ShardConsumers/RecordProcessor's to complete their final shutdown - // processing. This should really be a no-op since as part of the notification completion the lease for - // ShardConsumer is terminated. - // - if (!shutdownCompleteLatch.await(remaining, TimeUnit.NANOSECONDS)) { - long outstanding = shutdownCompleteLatch.getCount(); - log.info("Awaiting " + outstanding + " record processors to complete final shutdown"); - - return checkWorkerShutdownMiss(outstanding); - } - return 0; - } - - private long remainingTimeout(long timeout, TimeUnit unit, long startNanos) { - long checkNanos = System.nanoTime() - startNanos; - return unit.toNanos(timeout) - checkNanos; - } - - private void throwTimeoutMessageIfExceeded(long remainingNanos, String message) throws TimeoutException { - if (remainingNanos <= 0) { - throw new TimeoutException(message); - } - } - - /** - * This checks to see if the worker has already hit it's shutdown target, while there is outstanding record - * processors. This maybe a little racy due to when the value of outstanding is retrieved. In general though the - * latch should be decremented before the shutdown completion. - * - * @param outstanding - * the number of record processor still awaiting shutdown. - * @return the number of record processors awaiting shutdown, or 0 if the worker believes it's shutdown already. - */ - private long checkWorkerShutdownMiss(long outstanding) { - if (isWorkerShutdownComplete()) { - if (outstanding != 0) { - log.info("Shutdown completed, but shutdownCompleteLatch still had outstanding " + outstanding - + " with a current value of " + shutdownCompleteLatch.getCount() + ". shutdownComplete: " - + worker.isShutdownComplete() + " -- Consumer Map: " - + worker.getShardInfoShardConsumerMap().size()); - } - return 0; - } - return outstanding; - } - - Future startShutdown() { - return shutdownExecutor.submit(new ShutdownCallable()); - } - - private class ShutdownCallable implements Callable { - @Override - public Void call() throws Exception { - boolean complete = false; - do { - try { - long outstanding = outstandingRecordProcessors(1, TimeUnit.SECONDS); - complete = outstanding == 0; - log.info("Awaiting " + outstanding + " consumer(s) to finish shutdown."); - } catch (TimeoutException te) { - log.info("Timeout while waiting for completion: " + te.getMessage()); - } - - } while (!complete); - return null; - } - } - -} diff --git a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/Worker.java b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/Worker.java index 03901c60..563b4432 100644 --- a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/Worker.java +++ b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/Worker.java @@ -22,6 +22,7 @@ import java.util.concurrent.Callable; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; @@ -29,6 +30,7 @@ import java.util.concurrent.SynchronousQueue; import java.util.concurrent.ThreadFactory; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -540,6 +542,44 @@ public class Worker implements Runnable { */ public Future requestShutdown() { + Future requestedShutdownFuture = requestCancellableShutdown(); + + return new Future() { + + @Override + public boolean cancel(boolean mayInterruptIfRunning) { + return requestedShutdownFuture.cancel(mayInterruptIfRunning); + } + + @Override + public boolean isCancelled() { + return requestedShutdownFuture.isCancelled(); + } + + @Override + public boolean isDone() { + return requestedShutdownFuture.isDone(); + } + + @Override + public Void get() throws InterruptedException, ExecutionException { + requestedShutdownFuture.get(); + return null; + } + + @Override + public Void get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException { + requestedShutdownFuture.get(timeout, unit); + return null; + } + }; + } + + public Future requestCancellableShutdown() { + return RequestedShutdownCoordinator.startRequestedShutdown(requestShutdownCallable()); + } + + public Callable requestShutdownCallable() { // // Stop accepting new leases. Once we do this we can be sure that // no more leases will be acquired. @@ -552,7 +592,7 @@ public class Worker implements Runnable { // If there are no leases notification is already completed, but we still need to shutdown the worker. // this.shutdown(); - return Futures.immediateFuture(null); + return () -> true; } CountDownLatch shutdownCompleteLatch = new CountDownLatch(leases.size()); CountDownLatch notificationCompleteLatch = new CountDownLatch(leases.size()); @@ -573,8 +613,7 @@ public class Worker implements Runnable { shutdownCompleteLatch.countDown(); } } - - return new ShutdownFuture(shutdownCompleteLatch, notificationCompleteLatch, this).startShutdown(); + return RequestedShutdownCoordinator.createRequestedShutdownCallable(shutdownCompleteLatch, notificationCompleteLatch, this); } boolean isShutdownComplete() { diff --git a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/RequestedShutdownCoordinatorTest.java b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/RequestedShutdownCoordinatorTest.java new file mode 100644 index 00000000..d56e2bd4 --- /dev/null +++ b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/RequestedShutdownCoordinatorTest.java @@ -0,0 +1,295 @@ +package com.amazonaws.services.kinesis.clientlibrary.lib.worker; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.junit.Assert.assertThat; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyLong; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.util.concurrent.Callable; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.runners.MockitoJUnitRunner; +import org.mockito.stubbing.Answer; +import org.mockito.verification.VerificationMode; + +@RunWith(MockitoJUnitRunner.class) +public class RequestedShutdownCoordinatorTest { + + @Mock + private CountDownLatch shutdownCompleteLatch; + @Mock + private CountDownLatch notificationCompleteLatch; + @Mock + private Worker worker; + @Mock + private ConcurrentMap shardInfoConsumerMap; + + @Test + public void testAllShutdownCompletedAlready() throws Exception { + Callable requestedShutdownCallable = buildRequestedShutdownCallable(); + + when(shutdownCompleteLatch.await(anyLong(), any(TimeUnit.class))).thenReturn(true); + when(notificationCompleteLatch.await(anyLong(), any(TimeUnit.class))).thenReturn(true); + + assertThat(requestedShutdownCallable.call(), equalTo(true)); + verify(shutdownCompleteLatch).await(anyLong(), any(TimeUnit.class)); + verify(notificationCompleteLatch).await(anyLong(), any(TimeUnit.class)); + verify(worker).shutdown(); + } + + @Test + public void testNotificationNotCompletedYet() throws Exception { + Callable requestedShutdownCallable = buildRequestedShutdownCallable(); + + mockLatchAwait(notificationCompleteLatch, false, true); + when(notificationCompleteLatch.getCount()).thenReturn(1L, 0L); + mockLatchAwait(shutdownCompleteLatch, true); + when(shutdownCompleteLatch.getCount()).thenReturn(1L, 1L, 0L); + + when(worker.isShutdownComplete()).thenReturn(false, true); + mockShardInfoConsumerMap(1, 0); + + assertThat(requestedShutdownCallable.call(), equalTo(true)); + verify(notificationCompleteLatch, times(2)).await(anyLong(), any(TimeUnit.class)); + verify(notificationCompleteLatch).getCount(); + + verify(shutdownCompleteLatch).await(anyLong(), any(TimeUnit.class)); + verify(shutdownCompleteLatch, times(2)).getCount(); + + verify(worker).shutdown(); + } + + @Test + public void testShutdownNotCompletedYet() throws Exception { + Callable requestedShutdownCallable = buildRequestedShutdownCallable(); + + mockLatchAwait(notificationCompleteLatch, true); + mockLatchAwait(shutdownCompleteLatch, false, true); + when(shutdownCompleteLatch.getCount()).thenReturn(1L, 0L); + + when(worker.isShutdownComplete()).thenReturn(false, true); + mockShardInfoConsumerMap(1, 0); + + assertThat(requestedShutdownCallable.call(), equalTo(true)); + verify(notificationCompleteLatch).await(anyLong(), any(TimeUnit.class)); + verify(notificationCompleteLatch, never()).getCount(); + + verify(shutdownCompleteLatch, times(2)).await(anyLong(), any(TimeUnit.class)); + verify(shutdownCompleteLatch, times(2)).getCount(); + + verify(worker).shutdown(); + } + + @Test + public void testMultipleAttemptsForNotification() throws Exception { + Callable requestedShutdownCallable = buildRequestedShutdownCallable(); + + mockLatchAwait(notificationCompleteLatch, false, false, true); + when(notificationCompleteLatch.getCount()).thenReturn(2L, 1L, 0L); + + mockLatchAwait(shutdownCompleteLatch, true); + when(shutdownCompleteLatch.getCount()).thenReturn(2L, 2L, 1L, 1L, 0L); + + when(worker.isShutdownComplete()).thenReturn(false, false, false, true); + mockShardInfoConsumerMap(2, 1, 0); + + assertThat(requestedShutdownCallable.call(), equalTo(true)); + + verifyLatchAwait(notificationCompleteLatch, 3); + verify(notificationCompleteLatch, times(2)).getCount(); + + verifyLatchAwait(shutdownCompleteLatch, 1); + verify(shutdownCompleteLatch, times(4)).getCount(); + } + + @Test + public void testWorkerAlreadyShutdownAtNotification() throws Exception { + Callable requestedShutdownCallable = buildRequestedShutdownCallable(); + + mockLatchAwait(notificationCompleteLatch, false, true); + when(notificationCompleteLatch.getCount()).thenReturn(1L, 0L); + + mockLatchAwait(shutdownCompleteLatch, true); + when(shutdownCompleteLatch.getCount()).thenReturn(1L, 1L, 0L); + + when(worker.isShutdownComplete()).thenReturn(true); + mockShardInfoConsumerMap(0); + + assertThat(requestedShutdownCallable.call(), equalTo(false)); + + verifyLatchAwait(notificationCompleteLatch); + verify(notificationCompleteLatch).getCount(); + + verifyLatchAwait(shutdownCompleteLatch, never()); + verify(shutdownCompleteLatch, times(3)).getCount(); + } + + @Test + public void testWorkerAlreadyShutdownAtComplete() throws Exception { + Callable requestedShutdownCallable = buildRequestedShutdownCallable(); + + mockLatchAwait(notificationCompleteLatch, true); + + mockLatchAwait(shutdownCompleteLatch, false, true); + when(shutdownCompleteLatch.getCount()).thenReturn(1L, 1L, 1L); + + when(worker.isShutdownComplete()).thenReturn(true); + mockShardInfoConsumerMap(0); + + assertThat(requestedShutdownCallable.call(), equalTo(false)); + + verifyLatchAwait(notificationCompleteLatch); + verify(notificationCompleteLatch, never()).getCount(); + + verifyLatchAwait(shutdownCompleteLatch); + verify(shutdownCompleteLatch, times(3)).getCount(); + } + + @Test + public void testNotificationInterrupted() throws Exception { + Callable requestedShutdownCallable = buildRequestedShutdownCallable(); + + when(notificationCompleteLatch.await(anyLong(), any(TimeUnit.class))).thenThrow(new InterruptedException()); + when(notificationCompleteLatch.getCount()).thenReturn(1L); + + when(shutdownCompleteLatch.getCount()).thenReturn(1L); + + assertThat(requestedShutdownCallable.call(), equalTo(false)); + verifyLatchAwait(notificationCompleteLatch); + verifyLatchAwait(shutdownCompleteLatch, never()); + verify(worker, never()).shutdown(); + } + + @Test + public void testShutdownInterrupted() throws Exception { + Callable requestedShutdownCallable = buildRequestedShutdownCallable(); + + when(notificationCompleteLatch.await(anyLong(), any(TimeUnit.class))).thenReturn(true); + + when(shutdownCompleteLatch.await(anyLong(), any(TimeUnit.class))).thenThrow(new InterruptedException()); + when(shutdownCompleteLatch.getCount()).thenReturn(1L); + + assertThat(requestedShutdownCallable.call(), equalTo(false)); + verifyLatchAwait(notificationCompleteLatch); + verifyLatchAwait(shutdownCompleteLatch); + verify(worker).shutdown(); + } + + @Test + public void testInterruptedAfterNotification() throws Exception { + Callable requestedShutdownCallable = buildRequestedShutdownCallable(); + + when(notificationCompleteLatch.await(anyLong(), any(TimeUnit.class))).thenAnswer(invocation -> { + Thread.currentThread().interrupt(); + return true; + }); + + assertThat(requestedShutdownCallable.call(), equalTo(false)); + verifyLatchAwait(notificationCompleteLatch); + verifyLatchAwait(shutdownCompleteLatch, never()); + verify(worker, never()).shutdown(); + } + + @Test + public void testInterruptedAfterWorkerShutdown() throws Exception { + Callable requestedShutdownCallable = buildRequestedShutdownCallable(); + + when(notificationCompleteLatch.await(anyLong(), any(TimeUnit.class))).thenReturn(true); + + doAnswer(invocation -> { + Thread.currentThread().interrupt(); + return true; + }).when(worker).shutdown(); + + assertThat(requestedShutdownCallable.call(), equalTo(false)); + verifyLatchAwait(notificationCompleteLatch); + verifyLatchAwait(shutdownCompleteLatch, never()); + verify(worker).shutdown(); + } + + @Test + public void testInterruptedDuringNotification() throws Exception { + Callable requestedShutdownCallable = buildRequestedShutdownCallable(); + + when(notificationCompleteLatch.await(anyLong(), any(TimeUnit.class))).thenAnswer(invocation -> { + Thread.currentThread().interrupt(); + return false; + }); + when(notificationCompleteLatch.getCount()).thenReturn(1L); + + when(shutdownCompleteLatch.getCount()).thenReturn(1L); + + assertThat(requestedShutdownCallable.call(), equalTo(false)); + verifyLatchAwait(notificationCompleteLatch); + verify(notificationCompleteLatch).getCount(); + + verifyLatchAwait(shutdownCompleteLatch, never()); + verify(shutdownCompleteLatch).getCount(); + + verify(worker, never()).shutdown(); + } + + @Test + public void testInterruptedDuringShutdown() throws Exception { + Callable requestedShutdownCallable = buildRequestedShutdownCallable(); + + when(notificationCompleteLatch.await(anyLong(), any(TimeUnit.class))).thenReturn(true); + + when(shutdownCompleteLatch.await(anyLong(), any(TimeUnit.class))).thenAnswer(invocation -> { + Thread.currentThread().interrupt(); + return false; + }); + when(shutdownCompleteLatch.getCount()).thenReturn(1L); + + assertThat(requestedShutdownCallable.call(), equalTo(false)); + verifyLatchAwait(notificationCompleteLatch); + verify(notificationCompleteLatch, never()).getCount(); + + verifyLatchAwait(shutdownCompleteLatch); + verify(shutdownCompleteLatch).getCount(); + + verify(worker).shutdown(); + } + + private void verifyLatchAwait(CountDownLatch latch) throws Exception { + verifyLatchAwait(latch, times(1)); + } + private void verifyLatchAwait(CountDownLatch latch, int times) throws Exception { + verifyLatchAwait(latch, times(times)); + } + + private void verifyLatchAwait(CountDownLatch latch, VerificationMode verificationMode) throws Exception { + verify(latch, verificationMode).await(anyLong(), any(TimeUnit.class)); + } + + private void mockLatchAwait(CountDownLatch latch, Boolean initial, Boolean... remaining) throws Exception { + when(latch.await(anyLong(), any(TimeUnit.class))).thenReturn(initial, remaining); + } + + private Callable buildRequestedShutdownCallable() { + return RequestedShutdownCoordinator.createRequestedShutdownCallable(shutdownCompleteLatch, + notificationCompleteLatch, worker); + } + + private void mockShardInfoConsumerMap(Integer initialItemCount, Integer... additionalItemCounts) { + when(worker.getShardInfoShardConsumerMap()).thenReturn(shardInfoConsumerMap); + Boolean additionalEmptyStates[] = new Boolean[additionalItemCounts.length]; + for (int i = 0; i < additionalItemCounts.length; ++i) { + additionalEmptyStates[i] = additionalItemCounts[i] == 0; + } + when(shardInfoConsumerMap.size()).thenReturn(initialItemCount, additionalItemCounts); + when(shardInfoConsumerMap.isEmpty()).thenReturn(initialItemCount == 0, additionalEmptyStates); + } + +} \ No newline at end of file diff --git a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShutdownFutureTest.java b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShutdownFutureTest.java deleted file mode 100644 index 3e93a8cd..00000000 --- a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShutdownFutureTest.java +++ /dev/null @@ -1,258 +0,0 @@ -package com.amazonaws.services.kinesis.clientlibrary.lib.worker; - -import static org.hamcrest.MatcherAssert.assertThat; -import static org.mockito.Matchers.any; -import static org.mockito.Matchers.anyLong; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import java.util.concurrent.Callable; -import java.util.concurrent.ConcurrentMap; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.concurrent.Future; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; - -import com.google.common.util.concurrent.Futures; -import com.google.common.util.concurrent.MoreExecutors; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.mockito.Mock; -import org.mockito.invocation.InvocationOnMock; -import org.mockito.runners.MockitoJUnitRunner; -import org.mockito.stubbing.Answer; -import org.mockito.stubbing.OngoingStubbing; - -@RunWith(MockitoJUnitRunner.class) -public class ShutdownFutureTest { - - @Mock - private CountDownLatch shutdownCompleteLatch; - @Mock - private CountDownLatch notificationCompleteLatch; - @Mock - private Worker worker; - @Mock - private ConcurrentMap shardInfoConsumerMap; - @Mock - private ExecutorService executorService; - - @Test - public void testSimpleGetAlreadyCompleted() throws Exception { - - - mockNotificationComplete(true); - mockShutdownComplete(true); - - Future future = new ShutdownFuture(shutdownCompleteLatch, notificationCompleteLatch, worker, executorService).startShutdown(); - - future.get(); - - verify(notificationCompleteLatch).await(anyLong(), any(TimeUnit.class)); - verify(worker).shutdown(); - verify(shutdownCompleteLatch).await(anyLong(), any(TimeUnit.class)); - verify(executorService.shutdownNow()); - } - - @Test - public void testNotificationNotCompleted() throws Exception { - ShutdownFuture future = new ShutdownFuture(shutdownCompleteLatch, notificationCompleteLatch, worker, executorService); - - mockNotificationComplete(false, true); - mockShutdownComplete(true); - - when(worker.getShardInfoShardConsumerMap()).thenReturn(shardInfoConsumerMap); - when(shardInfoConsumerMap.isEmpty()).thenReturn(false); - when(worker.isShutdownComplete()).thenReturn(false); - - when(notificationCompleteLatch.getCount()).thenReturn(1L); - when(shutdownCompleteLatch.getCount()).thenReturn(1L); - - expectedTimeoutException(future); - - verify(worker, never()).shutdown(); - - awaitFuture(future); - - verify(notificationCompleteLatch).getCount(); - verifyLatchAwait(notificationCompleteLatch, 2); - - verify(shutdownCompleteLatch).getCount(); - verifyLatchAwait(shutdownCompleteLatch); - - verify(worker).shutdown(); - - } - - @Test - public void testShutdownNotCompleted() throws Exception { - ShutdownFuture future = new ShutdownFuture(shutdownCompleteLatch, notificationCompleteLatch, worker); - mockNotificationComplete(true); - mockShutdownComplete(false, true); - - when(shutdownCompleteLatch.getCount()).thenReturn(1L); - when(worker.isShutdownComplete()).thenReturn(false); - - mockShardInfoConsumerMap(1); - - expectedTimeoutException(future); - verify(worker).shutdown(); - awaitFuture(future); - - verifyLatchAwait(notificationCompleteLatch, 2); - verifyLatchAwait(shutdownCompleteLatch, 2); - - verify(worker).isShutdownComplete(); - verify(worker).getShardInfoShardConsumerMap(); - - } - - @Test - public void testShutdownNotCompleteButWorkerShutdown() throws Exception { - ShutdownFuture future = create(); - - mockNotificationComplete(true); - mockShutdownComplete(false); - - when(shutdownCompleteLatch.getCount()).thenReturn(1L); - when(worker.isShutdownComplete()).thenReturn(true); - mockShardInfoConsumerMap(1); - - awaitFuture(future); - verify(worker).shutdown(); - verifyLatchAwait(notificationCompleteLatch); - verifyLatchAwait(shutdownCompleteLatch); - - verify(worker, times(2)).isShutdownComplete(); - verify(worker).getShardInfoShardConsumerMap(); - verify(shardInfoConsumerMap).size(); - } - - @Test - public void testShutdownNotCompleteButShardConsumerEmpty() throws Exception { - ShutdownFuture future = create(); - mockNotificationComplete(true); - mockShutdownComplete(false); - - mockOutstanding(shutdownCompleteLatch, 1L); - - when(worker.isShutdownComplete()).thenReturn(false); - mockShardInfoConsumerMap(0); - - awaitFuture(future); - verify(worker).shutdown(); - verifyLatchAwait(notificationCompleteLatch); - verifyLatchAwait(shutdownCompleteLatch); - - verify(worker, times(2)).isShutdownComplete(); - verify(worker, times(2)).getShardInfoShardConsumerMap(); - - verify(shardInfoConsumerMap).isEmpty(); - verify(shardInfoConsumerMap).size(); - } - - @Test - public void testNotificationNotCompleteButShardConsumerEmpty() throws Exception { - ShutdownFuture future = create(); - mockNotificationComplete(false); - mockShutdownComplete(false); - - mockOutstanding(notificationCompleteLatch, 1L); - mockOutstanding(shutdownCompleteLatch, 1L); - - when(worker.isShutdownComplete()).thenReturn(false); - mockShardInfoConsumerMap(0); - - awaitFuture(future); - verify(worker, never()).shutdown(); - verifyLatchAwait(notificationCompleteLatch); - verify(shutdownCompleteLatch, never()).await(); - - verify(worker, times(2)).isShutdownComplete(); - verify(worker, times(2)).getShardInfoShardConsumerMap(); - - verify(shardInfoConsumerMap).isEmpty(); - verify(shardInfoConsumerMap).size(); - } - - @Test(expected = TimeoutException.class) - public void testTimeExceededException() throws Exception { - ShutdownFuture future = create(); - mockNotificationComplete(false); - mockOutstanding(notificationCompleteLatch, 1L); - when(worker.isShutdownComplete()).thenReturn(false); - mockShardInfoConsumerMap(1); - - future.get(1, TimeUnit.NANOSECONDS); - } - - private ShutdownFuture create() { - return new ShutdownFuture(shutdownCompleteLatch, notificationCompleteLatch, worker); - } - - private void mockShardInfoConsumerMap(Integer initialItemCount, Integer ... additionalItemCounts) { - when(worker.getShardInfoShardConsumerMap()).thenReturn(shardInfoConsumerMap); - Boolean additionalEmptyStates[] = new Boolean[additionalItemCounts.length]; - for(int i = 0; i < additionalItemCounts.length; ++i) { - additionalEmptyStates[i] = additionalItemCounts[i] == 0; - } - when(shardInfoConsumerMap.size()).thenReturn(initialItemCount, additionalItemCounts); - when(shardInfoConsumerMap.isEmpty()).thenReturn(initialItemCount == 0, additionalEmptyStates); - } - - private void verifyLatchAwait(CountDownLatch latch) throws Exception { - verifyLatchAwait(latch, 1); - } - - private void verifyLatchAwait(CountDownLatch latch, int times) throws Exception { - verify(latch, times(times)).await(anyLong(), any(TimeUnit.class)); - } - - private void expectedTimeoutException(ShutdownFuture future) throws Exception { - boolean gotTimeout = false; - try { - awaitFuture(future); - } catch (TimeoutException te) { - gotTimeout = true; - } - assertThat("Expected a timeout exception to occur", gotTimeout); - } - - private void awaitFuture(Future future) throws Exception { - future.get(1, TimeUnit.SECONDS); - } - - private void mockNotificationComplete(Boolean initial, Boolean... states) throws Exception { - mockLatch(notificationCompleteLatch, initial, states); - - } - - private void mockShutdownComplete(Boolean initial, Boolean... states) throws Exception { - mockLatch(shutdownCompleteLatch, initial, states); - } - - private void mockLatch(CountDownLatch latch, Boolean initial, Boolean... states) throws Exception { - when(latch.await(anyLong(), any(TimeUnit.class))).thenReturn(initial, states); - } - - private void mockOutstanding(CountDownLatch latch, Long remaining, Long ... additionalRemaining) throws Exception { - when(latch.getCount()).thenReturn(remaining, additionalRemaining); - } - - private void mockExecutor() { - when(executorService.submit(any(Callable.class))).thenAnswer(new Answer>() { - @Override - public Future answer(InvocationOnMock invocation) throws Throwable { - Callable callable = (Callable)invocation.getArgumentAt(0, Callable.class); - return Futures.immediateFuture(callable.call()); - } - }) - } - -} \ No newline at end of file