diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/coordinator/Scheduler.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/coordinator/Scheduler.java index f8596419..df7fdda4 100644 --- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/coordinator/Scheduler.java +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/coordinator/Scheduler.java @@ -582,7 +582,8 @@ public class Scheduler implements Runnable { aggregatorUtil, hierarchicalShardSyncer, metricsFactory); - return new ShardConsumer(cache, executorService, shardInfo, lifecycleConfig.logWarningForTaskAfterMillis(), argument); + return new ShardConsumer(cache, executorService, shardInfo, lifecycleConfig.logWarningForTaskAfterMillis(), + argument, lifecycleConfig.taskExecutionListener()); } /** diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/LifecycleConfig.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/LifecycleConfig.java index b91376dd..b04d75ce 100644 --- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/LifecycleConfig.java +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/LifecycleConfig.java @@ -46,4 +46,10 @@ public class LifecycleConfig { */ private AggregatorUtil aggregatorUtil = new AggregatorUtil(); + /** + * TaskExecutionListener to be used to handle events during task execution lifecycle for a shard. + * + *

Default value: {@link NoOpTaskExecutionListener}

+ */ + private TaskExecutionListener taskExecutionListener = new NoOpTaskExecutionListener(); } diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/NoOpTaskExecutionListener.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/NoOpTaskExecutionListener.java new file mode 100644 index 00000000..95d225fa --- /dev/null +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/NoOpTaskExecutionListener.java @@ -0,0 +1,31 @@ +/* + * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Amazon Software License (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/asl/ + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ +package software.amazon.kinesis.lifecycle; + +import software.amazon.kinesis.lifecycle.events.TaskExecutionListenerInput; + +/** + * NoOp implementation of {@link TaskExecutionListener} interface that takes no action on task execution. + */ +public class NoOpTaskExecutionListener implements TaskExecutionListener { + @Override + public void beforeTaskExecution(TaskExecutionListenerInput input) { + } + + @Override + public void afterTaskExecution(TaskExecutionListenerInput input) { + } +} + diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ShardConsumer.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ShardConsumer.java index 8be5ec82..f386d48c 100644 --- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ShardConsumer.java +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ShardConsumer.java @@ -40,6 +40,7 @@ import software.amazon.kinesis.annotations.KinesisClientInternalApi; import software.amazon.kinesis.exceptions.internal.BlockedOnParentShardException; import software.amazon.kinesis.leases.ShardInfo; import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput; +import software.amazon.kinesis.lifecycle.events.TaskExecutionListenerInput; import software.amazon.kinesis.metrics.MetricsCollectingTaskDecorator; import software.amazon.kinesis.metrics.MetricsFactory; import software.amazon.kinesis.retrieval.RecordsPublisher; @@ -66,6 +67,7 @@ public class ShardConsumer { private final Optional logWarningForTaskAfterMillis; private final Function taskMetricsDecorator; private final int bufferSize; + private final TaskExecutionListener taskExecutionListener; private ConsumerTask currentTask; private TaskOutcome taskOutcome; @@ -95,10 +97,11 @@ public class ShardConsumer { private final InternalSubscriber subscriber; public ShardConsumer(RecordsPublisher recordsPublisher, ExecutorService executorService, ShardInfo shardInfo, - Optional logWarningForTaskAfterMillis, ShardConsumerArgument shardConsumerArgument) { + Optional logWarningForTaskAfterMillis, ShardConsumerArgument shardConsumerArgument, + TaskExecutionListener taskExecutionListener) { this(recordsPublisher, executorService, shardInfo, logWarningForTaskAfterMillis, shardConsumerArgument, ConsumerStates.INITIAL_STATE, - ShardConsumer.metricsWrappingFunction(shardConsumerArgument.metricsFactory()), 8); + ShardConsumer.metricsWrappingFunction(shardConsumerArgument.metricsFactory()), 8, taskExecutionListener); } // @@ -106,12 +109,14 @@ public class ShardConsumer { // public ShardConsumer(RecordsPublisher recordsPublisher, ExecutorService executorService, ShardInfo shardInfo, Optional logWarningForTaskAfterMillis, ShardConsumerArgument shardConsumerArgument, - ConsumerState initialState, Function taskMetricsDecorator, int bufferSize) { + ConsumerState initialState, Function taskMetricsDecorator, + int bufferSize, TaskExecutionListener taskExecutionListener) { this.recordsPublisher = recordsPublisher; this.executorService = executorService; this.shardInfo = shardInfo; this.shardConsumerArgument = shardConsumerArgument; this.logWarningForTaskAfterMillis = logWarningForTaskAfterMillis; + this.taskExecutionListener = taskExecutionListener; this.currentState = initialState; this.taskMetricsDecorator = taskMetricsDecorator; scheduler = Schedulers.from(executorService); @@ -379,6 +384,11 @@ public class ShardConsumer { } private synchronized void executeTask(ProcessRecordsInput input) { + TaskExecutionListenerInput taskExecutionListenerInput = TaskExecutionListenerInput.builder() + .shardInfo(shardInfo) + .taskType(currentState.taskType()) + .build(); + taskExecutionListener.beforeTaskExecution(taskExecutionListenerInput); ConsumerTask task = currentState.createTask(shardConsumerArgument, ShardConsumer.this, input); if (task != null) { taskDispatchedAt = Instant.now(); @@ -391,7 +401,9 @@ public class ShardConsumer { taskIsRunning = false; } taskOutcome = resultToOutcome(result); + taskExecutionListenerInput = taskExecutionListenerInput.toBuilder().taskOutcome(taskOutcome).build(); } + taskExecutionListener.afterTaskExecution(taskExecutionListenerInput); } private TaskOutcome resultToOutcome(TaskResult result) { @@ -435,10 +447,6 @@ public class ShardConsumer { return nextState; } - private enum TaskOutcome { - SUCCESSFUL, END_OF_SHARD, FAILURE - } - private void logTaskException(TaskResult taskResult) { if (log.isDebugEnabled()) { Exception taskException = taskResult.getException(); diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/TaskExecutionListener.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/TaskExecutionListener.java new file mode 100644 index 00000000..b70a6103 --- /dev/null +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/TaskExecutionListener.java @@ -0,0 +1,31 @@ +/* + * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Amazon Software License (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/asl/ + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ +package software.amazon.kinesis.lifecycle; + +import software.amazon.kinesis.lifecycle.events.TaskExecutionListenerInput; + +/** + * A listener for callbacks on task execution lifecycle for for a shard. + * + * Note: Recommended not to have a blocking implementation since these methods are + * called around the ShardRecordProcessor. A blocking call would result in slowing + * down the ShardConsumer. + */ +public interface TaskExecutionListener { + + void beforeTaskExecution(TaskExecutionListenerInput input); + + void afterTaskExecution(TaskExecutionListenerInput input); +} diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/TaskOutcome.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/TaskOutcome.java new file mode 100644 index 00000000..832137fc --- /dev/null +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/TaskOutcome.java @@ -0,0 +1,33 @@ +/* + * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Amazon Software License (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/asl/ + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ +package software.amazon.kinesis.lifecycle; + +/** + * Enumerates types of outcome of tasks executed as part of processing a shard. + */ +public enum TaskOutcome { + /** + * Denotes a successful task outcome. + */ + SUCCESSFUL, + /** + * Denotes that the last record from the shard has been read/consumed. + */ + END_OF_SHARD, + /** + * Denotes a failure or exception during processing of the shard. + */ + FAILURE +} diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/events/TaskExecutionListenerInput.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/events/TaskExecutionListenerInput.java new file mode 100644 index 00000000..b64addb5 --- /dev/null +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/events/TaskExecutionListenerInput.java @@ -0,0 +1,48 @@ +/* + * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Amazon Software License (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/asl/ + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ +package software.amazon.kinesis.lifecycle.events; + +import lombok.Builder; +import lombok.Data; +import lombok.experimental.Accessors; +import software.amazon.kinesis.leases.ShardInfo; +import software.amazon.kinesis.lifecycle.TaskOutcome; +import software.amazon.kinesis.lifecycle.TaskType; +import software.amazon.kinesis.lifecycle.TaskExecutionListener; + +/** + * Container for the parameters to the TaskExecutionListener's + * {@link TaskExecutionListener#beforeTaskExecution(TaskExecutionListenerInput)} method. + * {@link TaskExecutionListener#afterTaskExecution(TaskExecutionListenerInput)} method. + */ +@Data +@Builder(toBuilder = true) +@Accessors(fluent = true) +public class TaskExecutionListenerInput { + /** + * Detailed information about the shard whose progress is monitored by TaskExecutionListener. + */ + private final ShardInfo shardInfo; + /** + * The type of task being executed for the shard. + * + * This corresponds to the state the shard is in. + */ + private final TaskType taskType; + /** + * Outcome of the task execution for the shard. + */ + private final TaskOutcome taskOutcome; +} diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ConsumerStatesTest.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ConsumerStatesTest.java index f41d773b..9382b491 100644 --- a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ConsumerStatesTest.java +++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ConsumerStatesTest.java @@ -92,6 +92,8 @@ public class ConsumerStatesTest { private MetricsFactory metricsFactory; @Mock private ProcessRecordsInput processRecordsInput; + @Mock + private TaskExecutionListener taskExecutionListener; private long parentShardPollIntervalMillis = 0xCAFE; private boolean cleanupLeasesOfCompletedShards = true; @@ -115,7 +117,7 @@ public class ConsumerStatesTest { cleanupLeasesOfCompletedShards, ignoreUnexpectedChildShards, shardDetector, new AggregatorUtil(), hierarchicalShardSyncer, metricsFactory); consumer = spy( - new ShardConsumer(recordsPublisher, executorService, shardInfo, logWarningForTaskAfterMillis, argument)); + new ShardConsumer(recordsPublisher, executorService, shardInfo, logWarningForTaskAfterMillis, argument, taskExecutionListener)); when(shardInfo.shardId()).thenReturn("shardId-000000000000"); when(recordProcessorCheckpointer.checkpointer()).thenReturn(checkpointer); diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ShardConsumerTest.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ShardConsumerTest.java index 42f7a522..114d4d47 100644 --- a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ShardConsumerTest.java +++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ShardConsumerTest.java @@ -30,6 +30,8 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.verifyZeroInteractions; import static org.mockito.Mockito.when; import java.time.Instant; @@ -66,6 +68,7 @@ import lombok.extern.slf4j.Slf4j; import software.amazon.kinesis.common.InitialPositionInStreamExtended; import software.amazon.kinesis.leases.ShardInfo; import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput; +import software.amazon.kinesis.lifecycle.events.TaskExecutionListenerInput; import software.amazon.kinesis.retrieval.RecordsPublisher; import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber; @@ -79,6 +82,11 @@ public class ShardConsumerTest { private final String shardId = "shardId-0-0"; private final String concurrencyToken = "TestToken"; private ShardInfo shardInfo; + private TaskExecutionListenerInput initialTaskInput; + private TaskExecutionListenerInput processTaskInput; + private TaskExecutionListenerInput shutdownTaskInput; + private TaskExecutionListenerInput shutdownRequestedTaskInput; + private TaskExecutionListenerInput shutdownRequestedAwaitTaskInput; private ExecutorService executorService; @Mock @@ -111,6 +119,8 @@ public class ShardConsumerTest { private ConsumerTask shutdownRequestedTask; @Mock private ConsumerState shutdownRequestedAwaitState; + @Mock + private TaskExecutionListener taskExecutionListener; private ProcessRecordsInput processRecordsInput; @@ -128,6 +138,16 @@ public class ShardConsumerTest { processRecordsInput = ProcessRecordsInput.builder().isAtShardEnd(false).cacheEntryTime(Instant.now()) .millisBehindLatest(1000L).records(Collections.emptyList()).build(); + initialTaskInput = TaskExecutionListenerInput.builder().shardInfo(shardInfo) + .taskType(TaskType.INITIALIZE).build(); + processTaskInput = TaskExecutionListenerInput.builder().shardInfo(shardInfo) + .taskType(TaskType.PROCESS).build(); + shutdownRequestedTaskInput = TaskExecutionListenerInput.builder().shardInfo(shardInfo) + .taskType(TaskType.SHUTDOWN_NOTIFICATION).build(); + shutdownRequestedAwaitTaskInput = TaskExecutionListenerInput.builder().shardInfo(shardInfo) + .taskType(TaskType.SHUTDOWN_COMPLETE).build(); + shutdownTaskInput = TaskExecutionListenerInput.builder().shardInfo(shardInfo) + .taskType(TaskType.SHUTDOWN).build(); } @After @@ -219,7 +239,7 @@ public class ShardConsumerTest { TestPublisher cache = new TestPublisher(); ShardConsumer consumer = new ShardConsumer(cache, executorService, shardInfo, logWarningForTaskAfterMillis, - shardConsumerArgument, initialState, Function.identity(), 1); + shardConsumerArgument, initialState, Function.identity(), 1, taskExecutionListener); boolean initComplete = false; while (!initComplete) { @@ -246,9 +266,22 @@ public class ShardConsumerTest { verify(cache.subscription, times(3)).request(anyLong()); verify(cache.subscription).cancel(); verify(processingState, times(2)).createTask(eq(shardConsumerArgument), eq(consumer), any()); + verify(taskExecutionListener, times(1)).beforeTaskExecution(initialTaskInput); + verify(taskExecutionListener, times(2)).beforeTaskExecution(processTaskInput); + verify(taskExecutionListener, times(1)).beforeTaskExecution(shutdownTaskInput); + initialTaskInput = initialTaskInput.toBuilder().taskOutcome(TaskOutcome.SUCCESSFUL).build(); + processTaskInput = processTaskInput.toBuilder().taskOutcome(TaskOutcome.SUCCESSFUL).build(); + shutdownTaskInput = shutdownTaskInput.toBuilder().taskOutcome(TaskOutcome.SUCCESSFUL).build(); + + verify(taskExecutionListener, times(1)).afterTaskExecution(initialTaskInput); + verify(taskExecutionListener, times(2)).afterTaskExecution(processTaskInput); + verify(taskExecutionListener, times(1)).afterTaskExecution(shutdownTaskInput); + verifyNoMoreInteractions(taskExecutionListener); } + + @Test public void testDataArrivesAfterProcessing2() throws Exception { @@ -262,7 +295,7 @@ public class ShardConsumerTest { TestPublisher cache = new TestPublisher(); ShardConsumer consumer = new ShardConsumer(cache, executorService, shardInfo, logWarningForTaskAfterMillis, - shardConsumerArgument, initialState, Function.identity(), 1); + shardConsumerArgument, initialState, Function.identity(), 1, taskExecutionListener); boolean initComplete = false; while (!initComplete) { @@ -302,6 +335,18 @@ public class ShardConsumerTest { verify(processingTask, times(3)).call(); verify(processingState).shutdownTransition(eq(ShutdownReason.LEASE_LOST)); verify(shutdownState).shutdownTransition(eq(ShutdownReason.LEASE_LOST)); + verify(taskExecutionListener, times(1)).beforeTaskExecution(initialTaskInput); + verify(taskExecutionListener, times(3)).beforeTaskExecution(processTaskInput); + verify(taskExecutionListener, times(1)).beforeTaskExecution(shutdownTaskInput); + + initialTaskInput = initialTaskInput.toBuilder().taskOutcome(TaskOutcome.SUCCESSFUL).build(); + processTaskInput = processTaskInput.toBuilder().taskOutcome(TaskOutcome.SUCCESSFUL).build(); + shutdownTaskInput = shutdownTaskInput.toBuilder().taskOutcome(TaskOutcome.SUCCESSFUL).build(); + + verify(taskExecutionListener, times(1)).afterTaskExecution(initialTaskInput); + verify(taskExecutionListener, times(3)).afterTaskExecution(processTaskInput); + verify(taskExecutionListener, times(1)).afterTaskExecution(shutdownTaskInput); + verifyNoMoreInteractions(taskExecutionListener); } @SuppressWarnings("unchecked") @@ -309,7 +354,7 @@ public class ShardConsumerTest { @Ignore public final void testInitializationStateUponFailure() throws Exception { ShardConsumer consumer = new ShardConsumer(recordsPublisher, executorService, shardInfo, - logWarningForTaskAfterMillis, shardConsumerArgument, initialState, Function.identity(), 1); + logWarningForTaskAfterMillis, shardConsumerArgument, initialState, Function.identity(), 1, taskExecutionListener); when(initialState.createTask(eq(shardConsumerArgument), eq(consumer), any())).thenReturn(initializeTask); when(initializeTask.call()).thenReturn(new TaskResult(new Exception("Bad"))); @@ -342,7 +387,7 @@ public class ShardConsumerTest { ExecutorService failingService = mock(ExecutorService.class); ShardConsumer consumer = new ShardConsumer(recordsPublisher, failingService, shardInfo, - logWarningForTaskAfterMillis, shardConsumerArgument, initialState, t -> t, 1); + logWarningForTaskAfterMillis, shardConsumerArgument, initialState, t -> t, 1, taskExecutionListener); doThrow(new RejectedExecutionException()).when(failingService).execute(any()); @@ -350,15 +395,16 @@ public class ShardConsumerTest { do { initComplete = consumer.initializeComplete().get(); } while (!initComplete); - + verifyZeroInteractions(taskExecutionListener); } @Test public void testErrorThrowableInInitialization() throws Exception { ShardConsumer consumer = new ShardConsumer(recordsPublisher, executorService, shardInfo, - logWarningForTaskAfterMillis, shardConsumerArgument, initialState, t -> t, 1); + logWarningForTaskAfterMillis, shardConsumerArgument, initialState, t -> t, 1, taskExecutionListener); when(initialState.createTask(any(), any(), any())).thenReturn(initializeTask); + when(initialState.taskType()).thenReturn(TaskType.INITIALIZE); when(initializeTask.call()).thenAnswer(i -> { throw new Error("Error"); }); @@ -368,6 +414,8 @@ public class ShardConsumerTest { } catch (ExecutionException ee) { assertThat(ee.getCause(), instanceOf(Error.class)); } + verify(taskExecutionListener, times(1)).beforeTaskExecution(initialTaskInput); + verifyNoMoreInteractions(taskExecutionListener); } @Test @@ -377,7 +425,7 @@ public class ShardConsumerTest { TestPublisher cache = new TestPublisher(); ShardConsumer consumer = new ShardConsumer(cache, executorService, shardInfo, logWarningForTaskAfterMillis, - shardConsumerArgument, initialState, t -> t, 1); + shardConsumerArgument, initialState, t -> t, 1, taskExecutionListener); mockSuccessfulInitialize(null); @@ -386,6 +434,7 @@ public class ShardConsumerTest { when(processingState.shutdownTransition(eq(ShutdownReason.REQUESTED))).thenReturn(shutdownRequestedState); when(shutdownRequestedState.requiresDataAvailability()).thenReturn(false); when(shutdownRequestedState.createTask(any(), any(), any())).thenReturn(shutdownRequestedTask); + when(shutdownRequestedState.taskType()).thenReturn(TaskType.SHUTDOWN_NOTIFICATION); when(shutdownRequestedTask.call()).thenReturn(new TaskResult(null)); when(shutdownRequestedState.shutdownTransition(eq(ShutdownReason.REQUESTED))) @@ -396,6 +445,7 @@ public class ShardConsumerTest { when(shutdownRequestedAwaitState.shutdownTransition(eq(ShutdownReason.REQUESTED))) .thenReturn(shutdownRequestedState); when(shutdownRequestedAwaitState.shutdownTransition(eq(ShutdownReason.LEASE_LOST))).thenReturn(shutdownState); + when(shutdownRequestedAwaitState.taskType()).thenReturn(TaskType.SHUTDOWN_COMPLETE); mockSuccessfulShutdown(null); @@ -433,7 +483,24 @@ public class ShardConsumerTest { verify(shutdownRequestedState).shutdownTransition(eq(ShutdownReason.REQUESTED)); verify(shutdownRequestedAwaitState).createTask(any(), any(), any()); verify(shutdownRequestedAwaitState).shutdownTransition(eq(ShutdownReason.LEASE_LOST)); + verify(taskExecutionListener, times(1)).beforeTaskExecution(initialTaskInput); + verify(taskExecutionListener, times(2)).beforeTaskExecution(processTaskInput); + verify(taskExecutionListener, times(1)).beforeTaskExecution(shutdownRequestedTaskInput); + verify(taskExecutionListener, times(1)).beforeTaskExecution(shutdownRequestedAwaitTaskInput); + verify(taskExecutionListener, times(1)).beforeTaskExecution(shutdownTaskInput); + initialTaskInput = initialTaskInput.toBuilder().taskOutcome(TaskOutcome.SUCCESSFUL).build(); + processTaskInput = processTaskInput.toBuilder().taskOutcome(TaskOutcome.SUCCESSFUL).build(); + shutdownRequestedTaskInput = shutdownRequestedTaskInput.toBuilder().taskOutcome(TaskOutcome.SUCCESSFUL).build(); + shutdownTaskInput = shutdownTaskInput.toBuilder().taskOutcome(TaskOutcome.SUCCESSFUL).build(); + // No task is created/run for this shutdownRequestedAwaitState, so there's no task outcome. + + verify(taskExecutionListener, times(1)).afterTaskExecution(initialTaskInput); + verify(taskExecutionListener, times(2)).afterTaskExecution(processTaskInput); + verify(taskExecutionListener, times(1)).afterTaskExecution(shutdownRequestedTaskInput); + verify(taskExecutionListener, times(1)).afterTaskExecution(shutdownRequestedAwaitTaskInput); + verify(taskExecutionListener, times(1)).afterTaskExecution(shutdownTaskInput); + verifyNoMoreInteractions(taskExecutionListener); } @Test @@ -441,7 +508,7 @@ public class ShardConsumerTest { TestPublisher cache = new TestPublisher(); ShardConsumer consumer = new ShardConsumer(cache, executorService, shardInfo, Optional.of(1L), - shardConsumerArgument, initialState, Function.identity(), 1); + shardConsumerArgument, initialState, Function.identity(), 1, taskExecutionListener); mockSuccessfulInitialize(null); mockSuccessfulProcessing(null); @@ -473,6 +540,13 @@ public class ShardConsumerTest { assertThat(healthCheckOutcome, equalTo(expectedException)); verify(cache.subscription, times(2)).request(anyLong()); + verify(taskExecutionListener, times(1)).beforeTaskExecution(initialTaskInput); + verify(taskExecutionListener, times(1)).beforeTaskExecution(processTaskInput); + + initialTaskInput = initialTaskInput.toBuilder().taskOutcome(TaskOutcome.SUCCESSFUL).build(); + + verify(taskExecutionListener, times(1)).afterTaskExecution(initialTaskInput); + verifyNoMoreInteractions(taskExecutionListener); } @Test @@ -481,7 +555,7 @@ public class ShardConsumerTest { TestPublisher cache = new TestPublisher(); ShardConsumer consumer = new ShardConsumer(cache, executorService, shardInfo, Optional.of(1L), - shardConsumerArgument, initialState, Function.identity(), 1); + shardConsumerArgument, initialState, Function.identity(), 1, taskExecutionListener); CyclicBarrier taskArriveBarrier = new CyclicBarrier(2); CyclicBarrier taskDepartBarrier = new CyclicBarrier(2); @@ -551,6 +625,19 @@ public class ShardConsumerTest { assertThat(consumer.taskRunningTime(), nullValue()); consumer.healthCheck(); + + verify(taskExecutionListener, times(1)).beforeTaskExecution(initialTaskInput); + verify(taskExecutionListener, times(2)).beforeTaskExecution(processTaskInput); + verify(taskExecutionListener, times(1)).beforeTaskExecution(shutdownTaskInput); + + initialTaskInput = initialTaskInput.toBuilder().taskOutcome(TaskOutcome.SUCCESSFUL).build(); + processTaskInput = processTaskInput.toBuilder().taskOutcome(TaskOutcome.SUCCESSFUL).build(); + shutdownTaskInput = shutdownTaskInput.toBuilder().taskOutcome(TaskOutcome.SUCCESSFUL).build(); + + verify(taskExecutionListener, times(1)).afterTaskExecution(initialTaskInput); + verify(taskExecutionListener, times(2)).afterTaskExecution(processTaskInput); + verify(taskExecutionListener, times(1)).afterTaskExecution(shutdownTaskInput); + verifyNoMoreInteractions(taskExecutionListener); } private void mockSuccessfulShutdown(CyclicBarrier taskCallBarrier) { @@ -559,6 +646,7 @@ public class ShardConsumerTest { private void mockSuccessfulShutdown(CyclicBarrier taskArriveBarrier, CyclicBarrier taskDepartBarrier) { when(shutdownState.createTask(eq(shardConsumerArgument), any(), any())).thenReturn(shutdownTask); + when(shutdownState.taskType()).thenReturn(TaskType.SHUTDOWN); when(shutdownTask.taskType()).thenReturn(TaskType.SHUTDOWN); when(shutdownTask.call()).thenAnswer(i -> { awaitBarrier(taskArriveBarrier); @@ -578,6 +666,7 @@ public class ShardConsumerTest { private void mockSuccessfulProcessing(CyclicBarrier taskCallBarrier, CyclicBarrier taskInterlockBarrier) { when(processingState.createTask(eq(shardConsumerArgument), any(), any())).thenReturn(processingTask); when(processingState.requiresDataAvailability()).thenReturn(true); + when(processingState.taskType()).thenReturn(TaskType.PROCESS); when(processingTask.taskType()).thenReturn(TaskType.PROCESS); when(processingTask.call()).thenAnswer(i -> { awaitBarrier(taskCallBarrier); @@ -597,6 +686,7 @@ public class ShardConsumerTest { private void mockSuccessfulInitialize(CyclicBarrier taskCallBarrier, CyclicBarrier taskInterlockBarrier) { when(initialState.createTask(eq(shardConsumerArgument), any(), any())).thenReturn(initializeTask); + when(initialState.taskType()).thenReturn(TaskType.INITIALIZE); when(initializeTask.taskType()).thenReturn(TaskType.INITIALIZE); when(initializeTask.call()).thenAnswer(i -> { awaitBarrier(taskCallBarrier);