diff --git a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/Worker.java b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/Worker.java index 7b3e8ca1..644f4225 100644 --- a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/Worker.java +++ b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/Worker.java @@ -445,7 +445,6 @@ public class Worker implements Runnable { if (shutdown) { return; } - workerStateChangeListener.onWorkerStateChange(WorkerStateChangeListener.WorkerState.INITIALIZING); try { initialize(); @@ -455,7 +454,6 @@ public class Worker implements Runnable { shutdown(); } - workerStateChangeListener.onWorkerStateChange(WorkerStateChangeListener.WorkerState.STARTED); while (!shouldShutdown()) { runProcessLoop(); } @@ -501,6 +499,7 @@ public class Worker implements Runnable { } private void initialize() { + workerStateChangeListener.onWorkerStateChange(WorkerStateChangeListener.WorkerState.INITIALIZING); boolean isDone = false; Exception lastException = null; @@ -550,6 +549,7 @@ public class Worker implements Runnable { if (!isDone) { throw new RuntimeException(lastException); } + workerStateChangeListener.onWorkerStateChange(WorkerStateChangeListener.WorkerState.STARTED); } /** @@ -786,7 +786,6 @@ public class Worker implements Runnable { LOG.warn("Shutdown requested a second time."); return; } - workerStateChangeListener.onWorkerStateChange(WorkerStateChangeListener.WorkerState.SHUTTING_DOWN); LOG.info("Worker shutdown requested."); // Set shutdown flag, so Worker.run can start shutdown process. @@ -797,6 +796,7 @@ public class Worker implements Runnable { // Lost leases will force Worker to begin shutdown process for all shard consumers in // Worker.run(). leaseCoordinator.stop(); + workerStateChangeListener.onWorkerStateChange(WorkerStateChangeListener.WorkerState.SHUT_DOWN); } /** @@ -813,7 +813,6 @@ public class Worker implements Runnable { if (metricsFactory instanceof WorkerCWMetricsFactory) { ((CWMetricsFactory) metricsFactory).shutdown(); } - workerStateChangeListener.onWorkerStateChange(WorkerStateChangeListener.WorkerState.SHUT_DOWN); shutdownComplete = true; } diff --git a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/WorkerStateChangeListener.java b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/WorkerStateChangeListener.java index 50340af4..36ee39f0 100644 --- a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/WorkerStateChangeListener.java +++ b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/WorkerStateChangeListener.java @@ -9,7 +9,6 @@ public interface WorkerStateChangeListener { CREATED, INITIALIZING, STARTED, - SHUTTING_DOWN, SHUT_DOWN } diff --git a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/WorkerTest.java b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/WorkerTest.java index f71ed0b4..21aaa8ac 100644 --- a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/WorkerTest.java +++ b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/WorkerTest.java @@ -22,6 +22,7 @@ import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.junit.Assert.fail; import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyInt; +import static org.mockito.Matchers.anyLong; import static org.mockito.Matchers.anyString; import static org.mockito.Matchers.argThat; import static org.mockito.Matchers.eq; @@ -89,6 +90,7 @@ import com.amazonaws.services.kinesis.clientlibrary.interfaces.v2.IRecordProcess import com.amazonaws.services.kinesis.clientlibrary.interfaces.v2.IRecordProcessorFactory; import com.amazonaws.services.kinesis.clientlibrary.lib.worker.Worker.WorkerCWMetricsFactory; import com.amazonaws.services.kinesis.clientlibrary.lib.worker.Worker.WorkerThreadPoolExecutor; +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.WorkerStateChangeListener.WorkerState; import com.amazonaws.services.kinesis.clientlibrary.proxies.IKinesisProxy; import com.amazonaws.services.kinesis.clientlibrary.proxies.KinesisProxy; import com.amazonaws.services.kinesis.clientlibrary.proxies.KinesisLocalFileProxy; @@ -170,6 +172,8 @@ public class WorkerTest { private Future taskFuture; @Mock private TaskResult taskResult; + @Mock + private WorkerStateChangeListener workerStateChangeListener; @Before public void setup() { @@ -1510,6 +1514,95 @@ public class WorkerTest { Assert.assertTrue(worker.getWorkerStateChangeListener() instanceof NoOpWorkerStateChangeListener); } + @Test + public void testBuilderWhenWorkerStateListenerIsSet() { + IRecordProcessorFactory recordProcessorFactory = mock(IRecordProcessorFactory.class); + Worker worker = new Worker.Builder() + .recordProcessorFactory(recordProcessorFactory) + .workerStateChangeListener(workerStateChangeListener) + .config(config) + .build(); + Assert.assertSame(workerStateChangeListener, worker.getWorkerStateChangeListener()); + } + + @Test + public void testWorkerStateListenerStatePassesThroughCreatedState() { + IRecordProcessorFactory recordProcessorFactory = mock(IRecordProcessorFactory.class); + new Worker.Builder() + .recordProcessorFactory(recordProcessorFactory) + .workerStateChangeListener(workerStateChangeListener) + .config(config) + .build(); + + verify(workerStateChangeListener, times(1)).onWorkerStateChange(eq(WorkerState.CREATED)); + } + + @Test + public void testWorkerStateChangeListenerGoesThroughStates() throws Exception { + + final CountDownLatch workerInitialized = new CountDownLatch(1); + final CountDownLatch workerStarted = new CountDownLatch(1); + final IRecordProcessorFactory recordProcessorFactory = mock(IRecordProcessorFactory.class); + final IRecordProcessor processor = mock(IRecordProcessor.class); + + ExtendedSequenceNumber checkpoint = new ExtendedSequenceNumber("123", 0L); + KinesisClientLeaseBuilder builder = new KinesisClientLeaseBuilder().withCheckpoint(checkpoint) + .withConcurrencyToken(UUID.randomUUID()).withLastCounterIncrementNanos(0L).withLeaseCounter(0L) + .withOwnerSwitchesSinceCheckpoint(0L).withLeaseOwner("Self"); + final List leases = new ArrayList<>(); + KinesisClientLease lease = builder.withLeaseKey(String.format("shardId-%03d", 1)).build(); + leases.add(lease); + + doAnswer(new Answer() { + @Override + public Boolean answer(InvocationOnMock invocation) throws Throwable { + workerInitialized.countDown(); + return true; + } + }).when(leaseManager).waitUntilLeaseTableExists(anyLong(), anyLong()); + doAnswer(new Answer() { + @Override + public IRecordProcessor answer(InvocationOnMock invocation) throws Throwable { + workerStarted.countDown(); + return processor; + } + }).when(recordProcessorFactory).createProcessor(); + + when(config.getWorkerIdentifier()).thenReturn("Self"); + when(leaseManager.listLeases()).thenReturn(leases); + when(leaseManager.renewLease(leases.get(0))).thenReturn(true); + when(executorService.submit(Matchers.> any())) + .thenAnswer(new ShutdownHandlingAnswer(taskFuture)); + when(taskFuture.isDone()).thenReturn(true); + when(taskFuture.get()).thenReturn(taskResult); + when(taskResult.isShardEndReached()).thenReturn(true); + + Worker worker = new Worker.Builder() + .recordProcessorFactory(recordProcessorFactory) + .config(config) + .leaseManager(leaseManager) + .kinesisProxy(kinesisProxy) + .execService(executorService) + .workerStateChangeListener(workerStateChangeListener) + .build(); + + verify(workerStateChangeListener, times(1)).onWorkerStateChange(eq(WorkerState.CREATED)); + + WorkerThread workerThread = new WorkerThread(worker); + workerThread.start(); + + workerInitialized.await(); + verify(workerStateChangeListener, times(1)).onWorkerStateChange(eq(WorkerState.INITIALIZING)); + + workerStarted.await(); + verify(workerStateChangeListener, times(1)).onWorkerStateChange(eq(WorkerState.STARTED)); + + boolean workerShutdown = worker.createGracefulShutdownCallable() + .call(); + + verify(workerStateChangeListener, times(1)).onWorkerStateChange(eq(WorkerState.SHUT_DOWN)); + } + @Test public void testBuilderWithDefaultLeaseManager() { IRecordProcessorFactory recordProcessorFactory = mock(IRecordProcessorFactory.class);