Add a listener to capture task execution in shardConsumer (#417)

* Add a listener to capture when tasks are executed in the ShardConsumer
This commit is contained in:
akhani18 2018-10-10 13:01:41 -07:00 committed by Justin Pfifer
parent 14c68296f0
commit 2609e1ce46
9 changed files with 268 additions and 18 deletions

View file

@ -582,7 +582,8 @@ public class Scheduler implements Runnable {
aggregatorUtil,
hierarchicalShardSyncer,
metricsFactory);
return new ShardConsumer(cache, executorService, shardInfo, lifecycleConfig.logWarningForTaskAfterMillis(), argument);
return new ShardConsumer(cache, executorService, shardInfo, lifecycleConfig.logWarningForTaskAfterMillis(),
argument, lifecycleConfig.taskExecutionListener());
}
/**

View file

@ -46,4 +46,10 @@ public class LifecycleConfig {
*/
private AggregatorUtil aggregatorUtil = new AggregatorUtil();
/**
* TaskExecutionListener to be used to handle events during task execution lifecycle for a shard.
*
* <p>Default value: {@link NoOpTaskExecutionListener}</p>
*/
private TaskExecutionListener taskExecutionListener = new NoOpTaskExecutionListener();
}

View file

@ -0,0 +1,31 @@
/*
* Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Amazon Software License (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/asl/
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/
package software.amazon.kinesis.lifecycle;
import software.amazon.kinesis.lifecycle.events.TaskExecutionListenerInput;
/**
* NoOp implementation of {@link TaskExecutionListener} interface that takes no action on task execution.
*/
public class NoOpTaskExecutionListener implements TaskExecutionListener {
@Override
public void beforeTaskExecution(TaskExecutionListenerInput input) {
}
@Override
public void afterTaskExecution(TaskExecutionListenerInput input) {
}
}

View file

@ -40,6 +40,7 @@ import software.amazon.kinesis.annotations.KinesisClientInternalApi;
import software.amazon.kinesis.exceptions.internal.BlockedOnParentShardException;
import software.amazon.kinesis.leases.ShardInfo;
import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput;
import software.amazon.kinesis.lifecycle.events.TaskExecutionListenerInput;
import software.amazon.kinesis.metrics.MetricsCollectingTaskDecorator;
import software.amazon.kinesis.metrics.MetricsFactory;
import software.amazon.kinesis.retrieval.RecordsPublisher;
@ -66,6 +67,7 @@ public class ShardConsumer {
private final Optional<Long> logWarningForTaskAfterMillis;
private final Function<ConsumerTask, ConsumerTask> taskMetricsDecorator;
private final int bufferSize;
private final TaskExecutionListener taskExecutionListener;
private ConsumerTask currentTask;
private TaskOutcome taskOutcome;
@ -95,10 +97,11 @@ public class ShardConsumer {
private final InternalSubscriber subscriber;
public ShardConsumer(RecordsPublisher recordsPublisher, ExecutorService executorService, ShardInfo shardInfo,
Optional<Long> logWarningForTaskAfterMillis, ShardConsumerArgument shardConsumerArgument) {
Optional<Long> logWarningForTaskAfterMillis, ShardConsumerArgument shardConsumerArgument,
TaskExecutionListener taskExecutionListener) {
this(recordsPublisher, executorService, shardInfo, logWarningForTaskAfterMillis, shardConsumerArgument,
ConsumerStates.INITIAL_STATE,
ShardConsumer.metricsWrappingFunction(shardConsumerArgument.metricsFactory()), 8);
ShardConsumer.metricsWrappingFunction(shardConsumerArgument.metricsFactory()), 8, taskExecutionListener);
}
//
@ -106,12 +109,14 @@ public class ShardConsumer {
//
public ShardConsumer(RecordsPublisher recordsPublisher, ExecutorService executorService, ShardInfo shardInfo,
Optional<Long> logWarningForTaskAfterMillis, ShardConsumerArgument shardConsumerArgument,
ConsumerState initialState, Function<ConsumerTask, ConsumerTask> taskMetricsDecorator, int bufferSize) {
ConsumerState initialState, Function<ConsumerTask, ConsumerTask> taskMetricsDecorator,
int bufferSize, TaskExecutionListener taskExecutionListener) {
this.recordsPublisher = recordsPublisher;
this.executorService = executorService;
this.shardInfo = shardInfo;
this.shardConsumerArgument = shardConsumerArgument;
this.logWarningForTaskAfterMillis = logWarningForTaskAfterMillis;
this.taskExecutionListener = taskExecutionListener;
this.currentState = initialState;
this.taskMetricsDecorator = taskMetricsDecorator;
scheduler = Schedulers.from(executorService);
@ -379,6 +384,11 @@ public class ShardConsumer {
}
private synchronized void executeTask(ProcessRecordsInput input) {
TaskExecutionListenerInput taskExecutionListenerInput = TaskExecutionListenerInput.builder()
.shardInfo(shardInfo)
.taskType(currentState.taskType())
.build();
taskExecutionListener.beforeTaskExecution(taskExecutionListenerInput);
ConsumerTask task = currentState.createTask(shardConsumerArgument, ShardConsumer.this, input);
if (task != null) {
taskDispatchedAt = Instant.now();
@ -391,7 +401,9 @@ public class ShardConsumer {
taskIsRunning = false;
}
taskOutcome = resultToOutcome(result);
taskExecutionListenerInput = taskExecutionListenerInput.toBuilder().taskOutcome(taskOutcome).build();
}
taskExecutionListener.afterTaskExecution(taskExecutionListenerInput);
}
private TaskOutcome resultToOutcome(TaskResult result) {
@ -435,10 +447,6 @@ public class ShardConsumer {
return nextState;
}
private enum TaskOutcome {
SUCCESSFUL, END_OF_SHARD, FAILURE
}
private void logTaskException(TaskResult taskResult) {
if (log.isDebugEnabled()) {
Exception taskException = taskResult.getException();

View file

@ -0,0 +1,31 @@
/*
* Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Amazon Software License (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/asl/
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/
package software.amazon.kinesis.lifecycle;
import software.amazon.kinesis.lifecycle.events.TaskExecutionListenerInput;
/**
* A listener for callbacks on task execution lifecycle for for a shard.
*
* Note: Recommended not to have a blocking implementation since these methods are
* called around the ShardRecordProcessor. A blocking call would result in slowing
* down the ShardConsumer.
*/
public interface TaskExecutionListener {
void beforeTaskExecution(TaskExecutionListenerInput input);
void afterTaskExecution(TaskExecutionListenerInput input);
}

View file

@ -0,0 +1,33 @@
/*
* Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Amazon Software License (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/asl/
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/
package software.amazon.kinesis.lifecycle;
/**
* Enumerates types of outcome of tasks executed as part of processing a shard.
*/
public enum TaskOutcome {
/**
* Denotes a successful task outcome.
*/
SUCCESSFUL,
/**
* Denotes that the last record from the shard has been read/consumed.
*/
END_OF_SHARD,
/**
* Denotes a failure or exception during processing of the shard.
*/
FAILURE
}

View file

@ -0,0 +1,48 @@
/*
* Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Amazon Software License (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/asl/
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/
package software.amazon.kinesis.lifecycle.events;
import lombok.Builder;
import lombok.Data;
import lombok.experimental.Accessors;
import software.amazon.kinesis.leases.ShardInfo;
import software.amazon.kinesis.lifecycle.TaskOutcome;
import software.amazon.kinesis.lifecycle.TaskType;
import software.amazon.kinesis.lifecycle.TaskExecutionListener;
/**
* Container for the parameters to the TaskExecutionListener's
* {@link TaskExecutionListener#beforeTaskExecution(TaskExecutionListenerInput)} method.
* {@link TaskExecutionListener#afterTaskExecution(TaskExecutionListenerInput)} method.
*/
@Data
@Builder(toBuilder = true)
@Accessors(fluent = true)
public class TaskExecutionListenerInput {
/**
* Detailed information about the shard whose progress is monitored by TaskExecutionListener.
*/
private final ShardInfo shardInfo;
/**
* The type of task being executed for the shard.
*
* This corresponds to the state the shard is in.
*/
private final TaskType taskType;
/**
* Outcome of the task execution for the shard.
*/
private final TaskOutcome taskOutcome;
}

View file

@ -92,6 +92,8 @@ public class ConsumerStatesTest {
private MetricsFactory metricsFactory;
@Mock
private ProcessRecordsInput processRecordsInput;
@Mock
private TaskExecutionListener taskExecutionListener;
private long parentShardPollIntervalMillis = 0xCAFE;
private boolean cleanupLeasesOfCompletedShards = true;
@ -115,7 +117,7 @@ public class ConsumerStatesTest {
cleanupLeasesOfCompletedShards, ignoreUnexpectedChildShards, shardDetector, new AggregatorUtil(),
hierarchicalShardSyncer, metricsFactory);
consumer = spy(
new ShardConsumer(recordsPublisher, executorService, shardInfo, logWarningForTaskAfterMillis, argument));
new ShardConsumer(recordsPublisher, executorService, shardInfo, logWarningForTaskAfterMillis, argument, taskExecutionListener));
when(shardInfo.shardId()).thenReturn("shardId-000000000000");
when(recordProcessorCheckpointer.checkpointer()).thenReturn(checkpointer);

View file

@ -30,6 +30,8 @@ import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when;
import java.time.Instant;
@ -66,6 +68,7 @@ import lombok.extern.slf4j.Slf4j;
import software.amazon.kinesis.common.InitialPositionInStreamExtended;
import software.amazon.kinesis.leases.ShardInfo;
import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput;
import software.amazon.kinesis.lifecycle.events.TaskExecutionListenerInput;
import software.amazon.kinesis.retrieval.RecordsPublisher;
import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber;
@ -79,6 +82,11 @@ public class ShardConsumerTest {
private final String shardId = "shardId-0-0";
private final String concurrencyToken = "TestToken";
private ShardInfo shardInfo;
private TaskExecutionListenerInput initialTaskInput;
private TaskExecutionListenerInput processTaskInput;
private TaskExecutionListenerInput shutdownTaskInput;
private TaskExecutionListenerInput shutdownRequestedTaskInput;
private TaskExecutionListenerInput shutdownRequestedAwaitTaskInput;
private ExecutorService executorService;
@Mock
@ -111,6 +119,8 @@ public class ShardConsumerTest {
private ConsumerTask shutdownRequestedTask;
@Mock
private ConsumerState shutdownRequestedAwaitState;
@Mock
private TaskExecutionListener taskExecutionListener;
private ProcessRecordsInput processRecordsInput;
@ -128,6 +138,16 @@ public class ShardConsumerTest {
processRecordsInput = ProcessRecordsInput.builder().isAtShardEnd(false).cacheEntryTime(Instant.now())
.millisBehindLatest(1000L).records(Collections.emptyList()).build();
initialTaskInput = TaskExecutionListenerInput.builder().shardInfo(shardInfo)
.taskType(TaskType.INITIALIZE).build();
processTaskInput = TaskExecutionListenerInput.builder().shardInfo(shardInfo)
.taskType(TaskType.PROCESS).build();
shutdownRequestedTaskInput = TaskExecutionListenerInput.builder().shardInfo(shardInfo)
.taskType(TaskType.SHUTDOWN_NOTIFICATION).build();
shutdownRequestedAwaitTaskInput = TaskExecutionListenerInput.builder().shardInfo(shardInfo)
.taskType(TaskType.SHUTDOWN_COMPLETE).build();
shutdownTaskInput = TaskExecutionListenerInput.builder().shardInfo(shardInfo)
.taskType(TaskType.SHUTDOWN).build();
}
@After
@ -219,7 +239,7 @@ public class ShardConsumerTest {
TestPublisher cache = new TestPublisher();
ShardConsumer consumer = new ShardConsumer(cache, executorService, shardInfo, logWarningForTaskAfterMillis,
shardConsumerArgument, initialState, Function.identity(), 1);
shardConsumerArgument, initialState, Function.identity(), 1, taskExecutionListener);
boolean initComplete = false;
while (!initComplete) {
@ -246,9 +266,22 @@ public class ShardConsumerTest {
verify(cache.subscription, times(3)).request(anyLong());
verify(cache.subscription).cancel();
verify(processingState, times(2)).createTask(eq(shardConsumerArgument), eq(consumer), any());
verify(taskExecutionListener, times(1)).beforeTaskExecution(initialTaskInput);
verify(taskExecutionListener, times(2)).beforeTaskExecution(processTaskInput);
verify(taskExecutionListener, times(1)).beforeTaskExecution(shutdownTaskInput);
initialTaskInput = initialTaskInput.toBuilder().taskOutcome(TaskOutcome.SUCCESSFUL).build();
processTaskInput = processTaskInput.toBuilder().taskOutcome(TaskOutcome.SUCCESSFUL).build();
shutdownTaskInput = shutdownTaskInput.toBuilder().taskOutcome(TaskOutcome.SUCCESSFUL).build();
verify(taskExecutionListener, times(1)).afterTaskExecution(initialTaskInput);
verify(taskExecutionListener, times(2)).afterTaskExecution(processTaskInput);
verify(taskExecutionListener, times(1)).afterTaskExecution(shutdownTaskInput);
verifyNoMoreInteractions(taskExecutionListener);
}
@Test
public void testDataArrivesAfterProcessing2() throws Exception {
@ -262,7 +295,7 @@ public class ShardConsumerTest {
TestPublisher cache = new TestPublisher();
ShardConsumer consumer = new ShardConsumer(cache, executorService, shardInfo, logWarningForTaskAfterMillis,
shardConsumerArgument, initialState, Function.identity(), 1);
shardConsumerArgument, initialState, Function.identity(), 1, taskExecutionListener);
boolean initComplete = false;
while (!initComplete) {
@ -302,6 +335,18 @@ public class ShardConsumerTest {
verify(processingTask, times(3)).call();
verify(processingState).shutdownTransition(eq(ShutdownReason.LEASE_LOST));
verify(shutdownState).shutdownTransition(eq(ShutdownReason.LEASE_LOST));
verify(taskExecutionListener, times(1)).beforeTaskExecution(initialTaskInput);
verify(taskExecutionListener, times(3)).beforeTaskExecution(processTaskInput);
verify(taskExecutionListener, times(1)).beforeTaskExecution(shutdownTaskInput);
initialTaskInput = initialTaskInput.toBuilder().taskOutcome(TaskOutcome.SUCCESSFUL).build();
processTaskInput = processTaskInput.toBuilder().taskOutcome(TaskOutcome.SUCCESSFUL).build();
shutdownTaskInput = shutdownTaskInput.toBuilder().taskOutcome(TaskOutcome.SUCCESSFUL).build();
verify(taskExecutionListener, times(1)).afterTaskExecution(initialTaskInput);
verify(taskExecutionListener, times(3)).afterTaskExecution(processTaskInput);
verify(taskExecutionListener, times(1)).afterTaskExecution(shutdownTaskInput);
verifyNoMoreInteractions(taskExecutionListener);
}
@SuppressWarnings("unchecked")
@ -309,7 +354,7 @@ public class ShardConsumerTest {
@Ignore
public final void testInitializationStateUponFailure() throws Exception {
ShardConsumer consumer = new ShardConsumer(recordsPublisher, executorService, shardInfo,
logWarningForTaskAfterMillis, shardConsumerArgument, initialState, Function.identity(), 1);
logWarningForTaskAfterMillis, shardConsumerArgument, initialState, Function.identity(), 1, taskExecutionListener);
when(initialState.createTask(eq(shardConsumerArgument), eq(consumer), any())).thenReturn(initializeTask);
when(initializeTask.call()).thenReturn(new TaskResult(new Exception("Bad")));
@ -342,7 +387,7 @@ public class ShardConsumerTest {
ExecutorService failingService = mock(ExecutorService.class);
ShardConsumer consumer = new ShardConsumer(recordsPublisher, failingService, shardInfo,
logWarningForTaskAfterMillis, shardConsumerArgument, initialState, t -> t, 1);
logWarningForTaskAfterMillis, shardConsumerArgument, initialState, t -> t, 1, taskExecutionListener);
doThrow(new RejectedExecutionException()).when(failingService).execute(any());
@ -350,15 +395,16 @@ public class ShardConsumerTest {
do {
initComplete = consumer.initializeComplete().get();
} while (!initComplete);
verifyZeroInteractions(taskExecutionListener);
}
@Test
public void testErrorThrowableInInitialization() throws Exception {
ShardConsumer consumer = new ShardConsumer(recordsPublisher, executorService, shardInfo,
logWarningForTaskAfterMillis, shardConsumerArgument, initialState, t -> t, 1);
logWarningForTaskAfterMillis, shardConsumerArgument, initialState, t -> t, 1, taskExecutionListener);
when(initialState.createTask(any(), any(), any())).thenReturn(initializeTask);
when(initialState.taskType()).thenReturn(TaskType.INITIALIZE);
when(initializeTask.call()).thenAnswer(i -> {
throw new Error("Error");
});
@ -368,6 +414,8 @@ public class ShardConsumerTest {
} catch (ExecutionException ee) {
assertThat(ee.getCause(), instanceOf(Error.class));
}
verify(taskExecutionListener, times(1)).beforeTaskExecution(initialTaskInput);
verifyNoMoreInteractions(taskExecutionListener);
}
@Test
@ -377,7 +425,7 @@ public class ShardConsumerTest {
TestPublisher cache = new TestPublisher();
ShardConsumer consumer = new ShardConsumer(cache, executorService, shardInfo, logWarningForTaskAfterMillis,
shardConsumerArgument, initialState, t -> t, 1);
shardConsumerArgument, initialState, t -> t, 1, taskExecutionListener);
mockSuccessfulInitialize(null);
@ -386,6 +434,7 @@ public class ShardConsumerTest {
when(processingState.shutdownTransition(eq(ShutdownReason.REQUESTED))).thenReturn(shutdownRequestedState);
when(shutdownRequestedState.requiresDataAvailability()).thenReturn(false);
when(shutdownRequestedState.createTask(any(), any(), any())).thenReturn(shutdownRequestedTask);
when(shutdownRequestedState.taskType()).thenReturn(TaskType.SHUTDOWN_NOTIFICATION);
when(shutdownRequestedTask.call()).thenReturn(new TaskResult(null));
when(shutdownRequestedState.shutdownTransition(eq(ShutdownReason.REQUESTED)))
@ -396,6 +445,7 @@ public class ShardConsumerTest {
when(shutdownRequestedAwaitState.shutdownTransition(eq(ShutdownReason.REQUESTED)))
.thenReturn(shutdownRequestedState);
when(shutdownRequestedAwaitState.shutdownTransition(eq(ShutdownReason.LEASE_LOST))).thenReturn(shutdownState);
when(shutdownRequestedAwaitState.taskType()).thenReturn(TaskType.SHUTDOWN_COMPLETE);
mockSuccessfulShutdown(null);
@ -433,7 +483,24 @@ public class ShardConsumerTest {
verify(shutdownRequestedState).shutdownTransition(eq(ShutdownReason.REQUESTED));
verify(shutdownRequestedAwaitState).createTask(any(), any(), any());
verify(shutdownRequestedAwaitState).shutdownTransition(eq(ShutdownReason.LEASE_LOST));
verify(taskExecutionListener, times(1)).beforeTaskExecution(initialTaskInput);
verify(taskExecutionListener, times(2)).beforeTaskExecution(processTaskInput);
verify(taskExecutionListener, times(1)).beforeTaskExecution(shutdownRequestedTaskInput);
verify(taskExecutionListener, times(1)).beforeTaskExecution(shutdownRequestedAwaitTaskInput);
verify(taskExecutionListener, times(1)).beforeTaskExecution(shutdownTaskInput);
initialTaskInput = initialTaskInput.toBuilder().taskOutcome(TaskOutcome.SUCCESSFUL).build();
processTaskInput = processTaskInput.toBuilder().taskOutcome(TaskOutcome.SUCCESSFUL).build();
shutdownRequestedTaskInput = shutdownRequestedTaskInput.toBuilder().taskOutcome(TaskOutcome.SUCCESSFUL).build();
shutdownTaskInput = shutdownTaskInput.toBuilder().taskOutcome(TaskOutcome.SUCCESSFUL).build();
// No task is created/run for this shutdownRequestedAwaitState, so there's no task outcome.
verify(taskExecutionListener, times(1)).afterTaskExecution(initialTaskInput);
verify(taskExecutionListener, times(2)).afterTaskExecution(processTaskInput);
verify(taskExecutionListener, times(1)).afterTaskExecution(shutdownRequestedTaskInput);
verify(taskExecutionListener, times(1)).afterTaskExecution(shutdownRequestedAwaitTaskInput);
verify(taskExecutionListener, times(1)).afterTaskExecution(shutdownTaskInput);
verifyNoMoreInteractions(taskExecutionListener);
}
@Test
@ -441,7 +508,7 @@ public class ShardConsumerTest {
TestPublisher cache = new TestPublisher();
ShardConsumer consumer = new ShardConsumer(cache, executorService, shardInfo, Optional.of(1L),
shardConsumerArgument, initialState, Function.identity(), 1);
shardConsumerArgument, initialState, Function.identity(), 1, taskExecutionListener);
mockSuccessfulInitialize(null);
mockSuccessfulProcessing(null);
@ -473,6 +540,13 @@ public class ShardConsumerTest {
assertThat(healthCheckOutcome, equalTo(expectedException));
verify(cache.subscription, times(2)).request(anyLong());
verify(taskExecutionListener, times(1)).beforeTaskExecution(initialTaskInput);
verify(taskExecutionListener, times(1)).beforeTaskExecution(processTaskInput);
initialTaskInput = initialTaskInput.toBuilder().taskOutcome(TaskOutcome.SUCCESSFUL).build();
verify(taskExecutionListener, times(1)).afterTaskExecution(initialTaskInput);
verifyNoMoreInteractions(taskExecutionListener);
}
@Test
@ -481,7 +555,7 @@ public class ShardConsumerTest {
TestPublisher cache = new TestPublisher();
ShardConsumer consumer = new ShardConsumer(cache, executorService, shardInfo, Optional.of(1L),
shardConsumerArgument, initialState, Function.identity(), 1);
shardConsumerArgument, initialState, Function.identity(), 1, taskExecutionListener);
CyclicBarrier taskArriveBarrier = new CyclicBarrier(2);
CyclicBarrier taskDepartBarrier = new CyclicBarrier(2);
@ -551,6 +625,19 @@ public class ShardConsumerTest {
assertThat(consumer.taskRunningTime(), nullValue());
consumer.healthCheck();
verify(taskExecutionListener, times(1)).beforeTaskExecution(initialTaskInput);
verify(taskExecutionListener, times(2)).beforeTaskExecution(processTaskInput);
verify(taskExecutionListener, times(1)).beforeTaskExecution(shutdownTaskInput);
initialTaskInput = initialTaskInput.toBuilder().taskOutcome(TaskOutcome.SUCCESSFUL).build();
processTaskInput = processTaskInput.toBuilder().taskOutcome(TaskOutcome.SUCCESSFUL).build();
shutdownTaskInput = shutdownTaskInput.toBuilder().taskOutcome(TaskOutcome.SUCCESSFUL).build();
verify(taskExecutionListener, times(1)).afterTaskExecution(initialTaskInput);
verify(taskExecutionListener, times(2)).afterTaskExecution(processTaskInput);
verify(taskExecutionListener, times(1)).afterTaskExecution(shutdownTaskInput);
verifyNoMoreInteractions(taskExecutionListener);
}
private void mockSuccessfulShutdown(CyclicBarrier taskCallBarrier) {
@ -559,6 +646,7 @@ public class ShardConsumerTest {
private void mockSuccessfulShutdown(CyclicBarrier taskArriveBarrier, CyclicBarrier taskDepartBarrier) {
when(shutdownState.createTask(eq(shardConsumerArgument), any(), any())).thenReturn(shutdownTask);
when(shutdownState.taskType()).thenReturn(TaskType.SHUTDOWN);
when(shutdownTask.taskType()).thenReturn(TaskType.SHUTDOWN);
when(shutdownTask.call()).thenAnswer(i -> {
awaitBarrier(taskArriveBarrier);
@ -578,6 +666,7 @@ public class ShardConsumerTest {
private void mockSuccessfulProcessing(CyclicBarrier taskCallBarrier, CyclicBarrier taskInterlockBarrier) {
when(processingState.createTask(eq(shardConsumerArgument), any(), any())).thenReturn(processingTask);
when(processingState.requiresDataAvailability()).thenReturn(true);
when(processingState.taskType()).thenReturn(TaskType.PROCESS);
when(processingTask.taskType()).thenReturn(TaskType.PROCESS);
when(processingTask.call()).thenAnswer(i -> {
awaitBarrier(taskCallBarrier);
@ -597,6 +686,7 @@ public class ShardConsumerTest {
private void mockSuccessfulInitialize(CyclicBarrier taskCallBarrier, CyclicBarrier taskInterlockBarrier) {
when(initialState.createTask(eq(shardConsumerArgument), any(), any())).thenReturn(initializeTask);
when(initialState.taskType()).thenReturn(TaskType.INITIALIZE);
when(initializeTask.taskType()).thenReturn(TaskType.INITIALIZE);
when(initializeTask.call()).thenAnswer(i -> {
awaitBarrier(taskCallBarrier);