diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/coordinator/Scheduler.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/coordinator/Scheduler.java
index f8596419..df7fdda4 100644
--- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/coordinator/Scheduler.java
+++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/coordinator/Scheduler.java
@@ -582,7 +582,8 @@ public class Scheduler implements Runnable {
aggregatorUtil,
hierarchicalShardSyncer,
metricsFactory);
- return new ShardConsumer(cache, executorService, shardInfo, lifecycleConfig.logWarningForTaskAfterMillis(), argument);
+ return new ShardConsumer(cache, executorService, shardInfo, lifecycleConfig.logWarningForTaskAfterMillis(),
+ argument, lifecycleConfig.taskExecutionListener());
}
/**
diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/LifecycleConfig.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/LifecycleConfig.java
index b91376dd..b04d75ce 100644
--- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/LifecycleConfig.java
+++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/LifecycleConfig.java
@@ -46,4 +46,10 @@ public class LifecycleConfig {
*/
private AggregatorUtil aggregatorUtil = new AggregatorUtil();
+ /**
+ * TaskExecutionListener to be used to handle events during task execution lifecycle for a shard.
+ *
+ *
Default value: {@link NoOpTaskExecutionListener}
+ */
+ private TaskExecutionListener taskExecutionListener = new NoOpTaskExecutionListener();
}
diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/NoOpTaskExecutionListener.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/NoOpTaskExecutionListener.java
new file mode 100644
index 00000000..95d225fa
--- /dev/null
+++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/NoOpTaskExecutionListener.java
@@ -0,0 +1,31 @@
+/*
+ * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Amazon Software License (the "License").
+ * You may not use this file except in compliance with the License.
+ * A copy of the License is located at
+ *
+ * http://aws.amazon.com/asl/
+ *
+ * or in the "license" file accompanying this file. This file is distributed
+ * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+ * express or implied. See the License for the specific language governing
+ * permissions and limitations under the License.
+ */
+package software.amazon.kinesis.lifecycle;
+
+import software.amazon.kinesis.lifecycle.events.TaskExecutionListenerInput;
+
+/**
+ * NoOp implementation of {@link TaskExecutionListener} interface that takes no action on task execution.
+ */
+public class NoOpTaskExecutionListener implements TaskExecutionListener {
+ @Override
+ public void beforeTaskExecution(TaskExecutionListenerInput input) {
+ }
+
+ @Override
+ public void afterTaskExecution(TaskExecutionListenerInput input) {
+ }
+}
+
diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ShardConsumer.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ShardConsumer.java
index 8be5ec82..f386d48c 100644
--- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ShardConsumer.java
+++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ShardConsumer.java
@@ -40,6 +40,7 @@ import software.amazon.kinesis.annotations.KinesisClientInternalApi;
import software.amazon.kinesis.exceptions.internal.BlockedOnParentShardException;
import software.amazon.kinesis.leases.ShardInfo;
import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput;
+import software.amazon.kinesis.lifecycle.events.TaskExecutionListenerInput;
import software.amazon.kinesis.metrics.MetricsCollectingTaskDecorator;
import software.amazon.kinesis.metrics.MetricsFactory;
import software.amazon.kinesis.retrieval.RecordsPublisher;
@@ -66,6 +67,7 @@ public class ShardConsumer {
private final Optional logWarningForTaskAfterMillis;
private final Function taskMetricsDecorator;
private final int bufferSize;
+ private final TaskExecutionListener taskExecutionListener;
private ConsumerTask currentTask;
private TaskOutcome taskOutcome;
@@ -95,10 +97,11 @@ public class ShardConsumer {
private final InternalSubscriber subscriber;
public ShardConsumer(RecordsPublisher recordsPublisher, ExecutorService executorService, ShardInfo shardInfo,
- Optional logWarningForTaskAfterMillis, ShardConsumerArgument shardConsumerArgument) {
+ Optional logWarningForTaskAfterMillis, ShardConsumerArgument shardConsumerArgument,
+ TaskExecutionListener taskExecutionListener) {
this(recordsPublisher, executorService, shardInfo, logWarningForTaskAfterMillis, shardConsumerArgument,
ConsumerStates.INITIAL_STATE,
- ShardConsumer.metricsWrappingFunction(shardConsumerArgument.metricsFactory()), 8);
+ ShardConsumer.metricsWrappingFunction(shardConsumerArgument.metricsFactory()), 8, taskExecutionListener);
}
//
@@ -106,12 +109,14 @@ public class ShardConsumer {
//
public ShardConsumer(RecordsPublisher recordsPublisher, ExecutorService executorService, ShardInfo shardInfo,
Optional logWarningForTaskAfterMillis, ShardConsumerArgument shardConsumerArgument,
- ConsumerState initialState, Function taskMetricsDecorator, int bufferSize) {
+ ConsumerState initialState, Function taskMetricsDecorator,
+ int bufferSize, TaskExecutionListener taskExecutionListener) {
this.recordsPublisher = recordsPublisher;
this.executorService = executorService;
this.shardInfo = shardInfo;
this.shardConsumerArgument = shardConsumerArgument;
this.logWarningForTaskAfterMillis = logWarningForTaskAfterMillis;
+ this.taskExecutionListener = taskExecutionListener;
this.currentState = initialState;
this.taskMetricsDecorator = taskMetricsDecorator;
scheduler = Schedulers.from(executorService);
@@ -379,6 +384,11 @@ public class ShardConsumer {
}
private synchronized void executeTask(ProcessRecordsInput input) {
+ TaskExecutionListenerInput taskExecutionListenerInput = TaskExecutionListenerInput.builder()
+ .shardInfo(shardInfo)
+ .taskType(currentState.taskType())
+ .build();
+ taskExecutionListener.beforeTaskExecution(taskExecutionListenerInput);
ConsumerTask task = currentState.createTask(shardConsumerArgument, ShardConsumer.this, input);
if (task != null) {
taskDispatchedAt = Instant.now();
@@ -391,7 +401,9 @@ public class ShardConsumer {
taskIsRunning = false;
}
taskOutcome = resultToOutcome(result);
+ taskExecutionListenerInput = taskExecutionListenerInput.toBuilder().taskOutcome(taskOutcome).build();
}
+ taskExecutionListener.afterTaskExecution(taskExecutionListenerInput);
}
private TaskOutcome resultToOutcome(TaskResult result) {
@@ -435,10 +447,6 @@ public class ShardConsumer {
return nextState;
}
- private enum TaskOutcome {
- SUCCESSFUL, END_OF_SHARD, FAILURE
- }
-
private void logTaskException(TaskResult taskResult) {
if (log.isDebugEnabled()) {
Exception taskException = taskResult.getException();
diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/TaskExecutionListener.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/TaskExecutionListener.java
new file mode 100644
index 00000000..b70a6103
--- /dev/null
+++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/TaskExecutionListener.java
@@ -0,0 +1,31 @@
+/*
+ * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Amazon Software License (the "License").
+ * You may not use this file except in compliance with the License.
+ * A copy of the License is located at
+ *
+ * http://aws.amazon.com/asl/
+ *
+ * or in the "license" file accompanying this file. This file is distributed
+ * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+ * express or implied. See the License for the specific language governing
+ * permissions and limitations under the License.
+ */
+package software.amazon.kinesis.lifecycle;
+
+import software.amazon.kinesis.lifecycle.events.TaskExecutionListenerInput;
+
+/**
+ * A listener for callbacks on task execution lifecycle for for a shard.
+ *
+ * Note: Recommended not to have a blocking implementation since these methods are
+ * called around the ShardRecordProcessor. A blocking call would result in slowing
+ * down the ShardConsumer.
+ */
+public interface TaskExecutionListener {
+
+ void beforeTaskExecution(TaskExecutionListenerInput input);
+
+ void afterTaskExecution(TaskExecutionListenerInput input);
+}
diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/TaskOutcome.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/TaskOutcome.java
new file mode 100644
index 00000000..832137fc
--- /dev/null
+++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/TaskOutcome.java
@@ -0,0 +1,33 @@
+/*
+ * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Amazon Software License (the "License").
+ * You may not use this file except in compliance with the License.
+ * A copy of the License is located at
+ *
+ * http://aws.amazon.com/asl/
+ *
+ * or in the "license" file accompanying this file. This file is distributed
+ * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+ * express or implied. See the License for the specific language governing
+ * permissions and limitations under the License.
+ */
+package software.amazon.kinesis.lifecycle;
+
+/**
+ * Enumerates types of outcome of tasks executed as part of processing a shard.
+ */
+public enum TaskOutcome {
+ /**
+ * Denotes a successful task outcome.
+ */
+ SUCCESSFUL,
+ /**
+ * Denotes that the last record from the shard has been read/consumed.
+ */
+ END_OF_SHARD,
+ /**
+ * Denotes a failure or exception during processing of the shard.
+ */
+ FAILURE
+}
diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/events/TaskExecutionListenerInput.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/events/TaskExecutionListenerInput.java
new file mode 100644
index 00000000..b64addb5
--- /dev/null
+++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/events/TaskExecutionListenerInput.java
@@ -0,0 +1,48 @@
+/*
+ * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Amazon Software License (the "License").
+ * You may not use this file except in compliance with the License.
+ * A copy of the License is located at
+ *
+ * http://aws.amazon.com/asl/
+ *
+ * or in the "license" file accompanying this file. This file is distributed
+ * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+ * express or implied. See the License for the specific language governing
+ * permissions and limitations under the License.
+ */
+package software.amazon.kinesis.lifecycle.events;
+
+import lombok.Builder;
+import lombok.Data;
+import lombok.experimental.Accessors;
+import software.amazon.kinesis.leases.ShardInfo;
+import software.amazon.kinesis.lifecycle.TaskOutcome;
+import software.amazon.kinesis.lifecycle.TaskType;
+import software.amazon.kinesis.lifecycle.TaskExecutionListener;
+
+/**
+ * Container for the parameters to the TaskExecutionListener's
+ * {@link TaskExecutionListener#beforeTaskExecution(TaskExecutionListenerInput)} method.
+ * {@link TaskExecutionListener#afterTaskExecution(TaskExecutionListenerInput)} method.
+ */
+@Data
+@Builder(toBuilder = true)
+@Accessors(fluent = true)
+public class TaskExecutionListenerInput {
+ /**
+ * Detailed information about the shard whose progress is monitored by TaskExecutionListener.
+ */
+ private final ShardInfo shardInfo;
+ /**
+ * The type of task being executed for the shard.
+ *
+ * This corresponds to the state the shard is in.
+ */
+ private final TaskType taskType;
+ /**
+ * Outcome of the task execution for the shard.
+ */
+ private final TaskOutcome taskOutcome;
+}
diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ConsumerStatesTest.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ConsumerStatesTest.java
index f41d773b..9382b491 100644
--- a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ConsumerStatesTest.java
+++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ConsumerStatesTest.java
@@ -92,6 +92,8 @@ public class ConsumerStatesTest {
private MetricsFactory metricsFactory;
@Mock
private ProcessRecordsInput processRecordsInput;
+ @Mock
+ private TaskExecutionListener taskExecutionListener;
private long parentShardPollIntervalMillis = 0xCAFE;
private boolean cleanupLeasesOfCompletedShards = true;
@@ -115,7 +117,7 @@ public class ConsumerStatesTest {
cleanupLeasesOfCompletedShards, ignoreUnexpectedChildShards, shardDetector, new AggregatorUtil(),
hierarchicalShardSyncer, metricsFactory);
consumer = spy(
- new ShardConsumer(recordsPublisher, executorService, shardInfo, logWarningForTaskAfterMillis, argument));
+ new ShardConsumer(recordsPublisher, executorService, shardInfo, logWarningForTaskAfterMillis, argument, taskExecutionListener));
when(shardInfo.shardId()).thenReturn("shardId-000000000000");
when(recordProcessorCheckpointer.checkpointer()).thenReturn(checkpointer);
diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ShardConsumerTest.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ShardConsumerTest.java
index 42f7a522..114d4d47 100644
--- a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ShardConsumerTest.java
+++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ShardConsumerTest.java
@@ -30,6 +30,8 @@ import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoMoreInteractions;
+import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when;
import java.time.Instant;
@@ -66,6 +68,7 @@ import lombok.extern.slf4j.Slf4j;
import software.amazon.kinesis.common.InitialPositionInStreamExtended;
import software.amazon.kinesis.leases.ShardInfo;
import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput;
+import software.amazon.kinesis.lifecycle.events.TaskExecutionListenerInput;
import software.amazon.kinesis.retrieval.RecordsPublisher;
import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber;
@@ -79,6 +82,11 @@ public class ShardConsumerTest {
private final String shardId = "shardId-0-0";
private final String concurrencyToken = "TestToken";
private ShardInfo shardInfo;
+ private TaskExecutionListenerInput initialTaskInput;
+ private TaskExecutionListenerInput processTaskInput;
+ private TaskExecutionListenerInput shutdownTaskInput;
+ private TaskExecutionListenerInput shutdownRequestedTaskInput;
+ private TaskExecutionListenerInput shutdownRequestedAwaitTaskInput;
private ExecutorService executorService;
@Mock
@@ -111,6 +119,8 @@ public class ShardConsumerTest {
private ConsumerTask shutdownRequestedTask;
@Mock
private ConsumerState shutdownRequestedAwaitState;
+ @Mock
+ private TaskExecutionListener taskExecutionListener;
private ProcessRecordsInput processRecordsInput;
@@ -128,6 +138,16 @@ public class ShardConsumerTest {
processRecordsInput = ProcessRecordsInput.builder().isAtShardEnd(false).cacheEntryTime(Instant.now())
.millisBehindLatest(1000L).records(Collections.emptyList()).build();
+ initialTaskInput = TaskExecutionListenerInput.builder().shardInfo(shardInfo)
+ .taskType(TaskType.INITIALIZE).build();
+ processTaskInput = TaskExecutionListenerInput.builder().shardInfo(shardInfo)
+ .taskType(TaskType.PROCESS).build();
+ shutdownRequestedTaskInput = TaskExecutionListenerInput.builder().shardInfo(shardInfo)
+ .taskType(TaskType.SHUTDOWN_NOTIFICATION).build();
+ shutdownRequestedAwaitTaskInput = TaskExecutionListenerInput.builder().shardInfo(shardInfo)
+ .taskType(TaskType.SHUTDOWN_COMPLETE).build();
+ shutdownTaskInput = TaskExecutionListenerInput.builder().shardInfo(shardInfo)
+ .taskType(TaskType.SHUTDOWN).build();
}
@After
@@ -219,7 +239,7 @@ public class ShardConsumerTest {
TestPublisher cache = new TestPublisher();
ShardConsumer consumer = new ShardConsumer(cache, executorService, shardInfo, logWarningForTaskAfterMillis,
- shardConsumerArgument, initialState, Function.identity(), 1);
+ shardConsumerArgument, initialState, Function.identity(), 1, taskExecutionListener);
boolean initComplete = false;
while (!initComplete) {
@@ -246,9 +266,22 @@ public class ShardConsumerTest {
verify(cache.subscription, times(3)).request(anyLong());
verify(cache.subscription).cancel();
verify(processingState, times(2)).createTask(eq(shardConsumerArgument), eq(consumer), any());
+ verify(taskExecutionListener, times(1)).beforeTaskExecution(initialTaskInput);
+ verify(taskExecutionListener, times(2)).beforeTaskExecution(processTaskInput);
+ verify(taskExecutionListener, times(1)).beforeTaskExecution(shutdownTaskInput);
+ initialTaskInput = initialTaskInput.toBuilder().taskOutcome(TaskOutcome.SUCCESSFUL).build();
+ processTaskInput = processTaskInput.toBuilder().taskOutcome(TaskOutcome.SUCCESSFUL).build();
+ shutdownTaskInput = shutdownTaskInput.toBuilder().taskOutcome(TaskOutcome.SUCCESSFUL).build();
+
+ verify(taskExecutionListener, times(1)).afterTaskExecution(initialTaskInput);
+ verify(taskExecutionListener, times(2)).afterTaskExecution(processTaskInput);
+ verify(taskExecutionListener, times(1)).afterTaskExecution(shutdownTaskInput);
+ verifyNoMoreInteractions(taskExecutionListener);
}
+
+
@Test
public void testDataArrivesAfterProcessing2() throws Exception {
@@ -262,7 +295,7 @@ public class ShardConsumerTest {
TestPublisher cache = new TestPublisher();
ShardConsumer consumer = new ShardConsumer(cache, executorService, shardInfo, logWarningForTaskAfterMillis,
- shardConsumerArgument, initialState, Function.identity(), 1);
+ shardConsumerArgument, initialState, Function.identity(), 1, taskExecutionListener);
boolean initComplete = false;
while (!initComplete) {
@@ -302,6 +335,18 @@ public class ShardConsumerTest {
verify(processingTask, times(3)).call();
verify(processingState).shutdownTransition(eq(ShutdownReason.LEASE_LOST));
verify(shutdownState).shutdownTransition(eq(ShutdownReason.LEASE_LOST));
+ verify(taskExecutionListener, times(1)).beforeTaskExecution(initialTaskInput);
+ verify(taskExecutionListener, times(3)).beforeTaskExecution(processTaskInput);
+ verify(taskExecutionListener, times(1)).beforeTaskExecution(shutdownTaskInput);
+
+ initialTaskInput = initialTaskInput.toBuilder().taskOutcome(TaskOutcome.SUCCESSFUL).build();
+ processTaskInput = processTaskInput.toBuilder().taskOutcome(TaskOutcome.SUCCESSFUL).build();
+ shutdownTaskInput = shutdownTaskInput.toBuilder().taskOutcome(TaskOutcome.SUCCESSFUL).build();
+
+ verify(taskExecutionListener, times(1)).afterTaskExecution(initialTaskInput);
+ verify(taskExecutionListener, times(3)).afterTaskExecution(processTaskInput);
+ verify(taskExecutionListener, times(1)).afterTaskExecution(shutdownTaskInput);
+ verifyNoMoreInteractions(taskExecutionListener);
}
@SuppressWarnings("unchecked")
@@ -309,7 +354,7 @@ public class ShardConsumerTest {
@Ignore
public final void testInitializationStateUponFailure() throws Exception {
ShardConsumer consumer = new ShardConsumer(recordsPublisher, executorService, shardInfo,
- logWarningForTaskAfterMillis, shardConsumerArgument, initialState, Function.identity(), 1);
+ logWarningForTaskAfterMillis, shardConsumerArgument, initialState, Function.identity(), 1, taskExecutionListener);
when(initialState.createTask(eq(shardConsumerArgument), eq(consumer), any())).thenReturn(initializeTask);
when(initializeTask.call()).thenReturn(new TaskResult(new Exception("Bad")));
@@ -342,7 +387,7 @@ public class ShardConsumerTest {
ExecutorService failingService = mock(ExecutorService.class);
ShardConsumer consumer = new ShardConsumer(recordsPublisher, failingService, shardInfo,
- logWarningForTaskAfterMillis, shardConsumerArgument, initialState, t -> t, 1);
+ logWarningForTaskAfterMillis, shardConsumerArgument, initialState, t -> t, 1, taskExecutionListener);
doThrow(new RejectedExecutionException()).when(failingService).execute(any());
@@ -350,15 +395,16 @@ public class ShardConsumerTest {
do {
initComplete = consumer.initializeComplete().get();
} while (!initComplete);
-
+ verifyZeroInteractions(taskExecutionListener);
}
@Test
public void testErrorThrowableInInitialization() throws Exception {
ShardConsumer consumer = new ShardConsumer(recordsPublisher, executorService, shardInfo,
- logWarningForTaskAfterMillis, shardConsumerArgument, initialState, t -> t, 1);
+ logWarningForTaskAfterMillis, shardConsumerArgument, initialState, t -> t, 1, taskExecutionListener);
when(initialState.createTask(any(), any(), any())).thenReturn(initializeTask);
+ when(initialState.taskType()).thenReturn(TaskType.INITIALIZE);
when(initializeTask.call()).thenAnswer(i -> {
throw new Error("Error");
});
@@ -368,6 +414,8 @@ public class ShardConsumerTest {
} catch (ExecutionException ee) {
assertThat(ee.getCause(), instanceOf(Error.class));
}
+ verify(taskExecutionListener, times(1)).beforeTaskExecution(initialTaskInput);
+ verifyNoMoreInteractions(taskExecutionListener);
}
@Test
@@ -377,7 +425,7 @@ public class ShardConsumerTest {
TestPublisher cache = new TestPublisher();
ShardConsumer consumer = new ShardConsumer(cache, executorService, shardInfo, logWarningForTaskAfterMillis,
- shardConsumerArgument, initialState, t -> t, 1);
+ shardConsumerArgument, initialState, t -> t, 1, taskExecutionListener);
mockSuccessfulInitialize(null);
@@ -386,6 +434,7 @@ public class ShardConsumerTest {
when(processingState.shutdownTransition(eq(ShutdownReason.REQUESTED))).thenReturn(shutdownRequestedState);
when(shutdownRequestedState.requiresDataAvailability()).thenReturn(false);
when(shutdownRequestedState.createTask(any(), any(), any())).thenReturn(shutdownRequestedTask);
+ when(shutdownRequestedState.taskType()).thenReturn(TaskType.SHUTDOWN_NOTIFICATION);
when(shutdownRequestedTask.call()).thenReturn(new TaskResult(null));
when(shutdownRequestedState.shutdownTransition(eq(ShutdownReason.REQUESTED)))
@@ -396,6 +445,7 @@ public class ShardConsumerTest {
when(shutdownRequestedAwaitState.shutdownTransition(eq(ShutdownReason.REQUESTED)))
.thenReturn(shutdownRequestedState);
when(shutdownRequestedAwaitState.shutdownTransition(eq(ShutdownReason.LEASE_LOST))).thenReturn(shutdownState);
+ when(shutdownRequestedAwaitState.taskType()).thenReturn(TaskType.SHUTDOWN_COMPLETE);
mockSuccessfulShutdown(null);
@@ -433,7 +483,24 @@ public class ShardConsumerTest {
verify(shutdownRequestedState).shutdownTransition(eq(ShutdownReason.REQUESTED));
verify(shutdownRequestedAwaitState).createTask(any(), any(), any());
verify(shutdownRequestedAwaitState).shutdownTransition(eq(ShutdownReason.LEASE_LOST));
+ verify(taskExecutionListener, times(1)).beforeTaskExecution(initialTaskInput);
+ verify(taskExecutionListener, times(2)).beforeTaskExecution(processTaskInput);
+ verify(taskExecutionListener, times(1)).beforeTaskExecution(shutdownRequestedTaskInput);
+ verify(taskExecutionListener, times(1)).beforeTaskExecution(shutdownRequestedAwaitTaskInput);
+ verify(taskExecutionListener, times(1)).beforeTaskExecution(shutdownTaskInput);
+ initialTaskInput = initialTaskInput.toBuilder().taskOutcome(TaskOutcome.SUCCESSFUL).build();
+ processTaskInput = processTaskInput.toBuilder().taskOutcome(TaskOutcome.SUCCESSFUL).build();
+ shutdownRequestedTaskInput = shutdownRequestedTaskInput.toBuilder().taskOutcome(TaskOutcome.SUCCESSFUL).build();
+ shutdownTaskInput = shutdownTaskInput.toBuilder().taskOutcome(TaskOutcome.SUCCESSFUL).build();
+ // No task is created/run for this shutdownRequestedAwaitState, so there's no task outcome.
+
+ verify(taskExecutionListener, times(1)).afterTaskExecution(initialTaskInput);
+ verify(taskExecutionListener, times(2)).afterTaskExecution(processTaskInput);
+ verify(taskExecutionListener, times(1)).afterTaskExecution(shutdownRequestedTaskInput);
+ verify(taskExecutionListener, times(1)).afterTaskExecution(shutdownRequestedAwaitTaskInput);
+ verify(taskExecutionListener, times(1)).afterTaskExecution(shutdownTaskInput);
+ verifyNoMoreInteractions(taskExecutionListener);
}
@Test
@@ -441,7 +508,7 @@ public class ShardConsumerTest {
TestPublisher cache = new TestPublisher();
ShardConsumer consumer = new ShardConsumer(cache, executorService, shardInfo, Optional.of(1L),
- shardConsumerArgument, initialState, Function.identity(), 1);
+ shardConsumerArgument, initialState, Function.identity(), 1, taskExecutionListener);
mockSuccessfulInitialize(null);
mockSuccessfulProcessing(null);
@@ -473,6 +540,13 @@ public class ShardConsumerTest {
assertThat(healthCheckOutcome, equalTo(expectedException));
verify(cache.subscription, times(2)).request(anyLong());
+ verify(taskExecutionListener, times(1)).beforeTaskExecution(initialTaskInput);
+ verify(taskExecutionListener, times(1)).beforeTaskExecution(processTaskInput);
+
+ initialTaskInput = initialTaskInput.toBuilder().taskOutcome(TaskOutcome.SUCCESSFUL).build();
+
+ verify(taskExecutionListener, times(1)).afterTaskExecution(initialTaskInput);
+ verifyNoMoreInteractions(taskExecutionListener);
}
@Test
@@ -481,7 +555,7 @@ public class ShardConsumerTest {
TestPublisher cache = new TestPublisher();
ShardConsumer consumer = new ShardConsumer(cache, executorService, shardInfo, Optional.of(1L),
- shardConsumerArgument, initialState, Function.identity(), 1);
+ shardConsumerArgument, initialState, Function.identity(), 1, taskExecutionListener);
CyclicBarrier taskArriveBarrier = new CyclicBarrier(2);
CyclicBarrier taskDepartBarrier = new CyclicBarrier(2);
@@ -551,6 +625,19 @@ public class ShardConsumerTest {
assertThat(consumer.taskRunningTime(), nullValue());
consumer.healthCheck();
+
+ verify(taskExecutionListener, times(1)).beforeTaskExecution(initialTaskInput);
+ verify(taskExecutionListener, times(2)).beforeTaskExecution(processTaskInput);
+ verify(taskExecutionListener, times(1)).beforeTaskExecution(shutdownTaskInput);
+
+ initialTaskInput = initialTaskInput.toBuilder().taskOutcome(TaskOutcome.SUCCESSFUL).build();
+ processTaskInput = processTaskInput.toBuilder().taskOutcome(TaskOutcome.SUCCESSFUL).build();
+ shutdownTaskInput = shutdownTaskInput.toBuilder().taskOutcome(TaskOutcome.SUCCESSFUL).build();
+
+ verify(taskExecutionListener, times(1)).afterTaskExecution(initialTaskInput);
+ verify(taskExecutionListener, times(2)).afterTaskExecution(processTaskInput);
+ verify(taskExecutionListener, times(1)).afterTaskExecution(shutdownTaskInput);
+ verifyNoMoreInteractions(taskExecutionListener);
}
private void mockSuccessfulShutdown(CyclicBarrier taskCallBarrier) {
@@ -559,6 +646,7 @@ public class ShardConsumerTest {
private void mockSuccessfulShutdown(CyclicBarrier taskArriveBarrier, CyclicBarrier taskDepartBarrier) {
when(shutdownState.createTask(eq(shardConsumerArgument), any(), any())).thenReturn(shutdownTask);
+ when(shutdownState.taskType()).thenReturn(TaskType.SHUTDOWN);
when(shutdownTask.taskType()).thenReturn(TaskType.SHUTDOWN);
when(shutdownTask.call()).thenAnswer(i -> {
awaitBarrier(taskArriveBarrier);
@@ -578,6 +666,7 @@ public class ShardConsumerTest {
private void mockSuccessfulProcessing(CyclicBarrier taskCallBarrier, CyclicBarrier taskInterlockBarrier) {
when(processingState.createTask(eq(shardConsumerArgument), any(), any())).thenReturn(processingTask);
when(processingState.requiresDataAvailability()).thenReturn(true);
+ when(processingState.taskType()).thenReturn(TaskType.PROCESS);
when(processingTask.taskType()).thenReturn(TaskType.PROCESS);
when(processingTask.call()).thenAnswer(i -> {
awaitBarrier(taskCallBarrier);
@@ -597,6 +686,7 @@ public class ShardConsumerTest {
private void mockSuccessfulInitialize(CyclicBarrier taskCallBarrier, CyclicBarrier taskInterlockBarrier) {
when(initialState.createTask(eq(shardConsumerArgument), any(), any())).thenReturn(initializeTask);
+ when(initialState.taskType()).thenReturn(TaskType.INITIALIZE);
when(initializeTask.taskType()).thenReturn(TaskType.INITIALIZE);
when(initializeTask.call()).thenAnswer(i -> {
awaitBarrier(taskCallBarrier);