Adding functionality to override the vanilla KCL tasks (#1440)

This commit is contained in:
Abhi Gupta 2025-02-21 09:56:21 +05:30 committed by GitHub
parent 68a7a9bf53
commit 8deebe4bda
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 278 additions and 93 deletions

View file

@ -87,6 +87,8 @@ import software.amazon.kinesis.leases.dynamodb.DynamoDBMultiStreamLeaseSerialize
import software.amazon.kinesis.leases.exceptions.DependencyException;
import software.amazon.kinesis.leases.exceptions.InvalidStateException;
import software.amazon.kinesis.leases.exceptions.ProvisionedThroughputException;
import software.amazon.kinesis.lifecycle.ConsumerTaskFactory;
import software.amazon.kinesis.lifecycle.KinesisConsumerTaskFactory;
import software.amazon.kinesis.lifecycle.LifecycleConfig;
import software.amazon.kinesis.lifecycle.ShardConsumer;
import software.amazon.kinesis.lifecycle.ShardConsumerArgument;
@ -188,6 +190,7 @@ public class Scheduler implements Runnable {
private final SchemaRegistryDecoder schemaRegistryDecoder;
private final DeletedStreamListProvider deletedStreamListProvider;
private final ConsumerTaskFactory taskFactory;
@Getter(AccessLevel.NONE)
private final MigrationStateMachine migrationStateMachine;
@ -264,6 +267,33 @@ public class Scheduler implements Runnable {
@NonNull final ProcessorConfig processorConfig,
@NonNull final RetrievalConfig retrievalConfig,
@NonNull final DiagnosticEventFactory diagnosticEventFactory) {
this(
checkpointConfig,
coordinatorConfig,
leaseManagementConfig,
lifecycleConfig,
metricsConfig,
processorConfig,
retrievalConfig,
diagnosticEventFactory,
new KinesisConsumerTaskFactory());
}
/**
* Customers do not currently have the ability to customize the DiagnosticEventFactory, but this visibility
* is desired for testing. This constructor is only used for testing to provide a mock DiagnosticEventFactory.
*/
@VisibleForTesting
protected Scheduler(
@NonNull final CheckpointConfig checkpointConfig,
@NonNull final CoordinatorConfig coordinatorConfig,
@NonNull final LeaseManagementConfig leaseManagementConfig,
@NonNull final LifecycleConfig lifecycleConfig,
@NonNull final MetricsConfig metricsConfig,
@NonNull final ProcessorConfig processorConfig,
@NonNull final RetrievalConfig retrievalConfig,
@NonNull final DiagnosticEventFactory diagnosticEventFactory,
@NonNull final ConsumerTaskFactory taskFactory) {
this.checkpointConfig = checkpointConfig;
this.coordinatorConfig = coordinatorConfig;
this.leaseManagementConfig = leaseManagementConfig;
@ -371,6 +401,7 @@ public class Scheduler implements Runnable {
this.schemaRegistryDecoder = this.retrievalConfig.glueSchemaRegistryDeserializer() == null
? null
: new SchemaRegistryDecoder(this.retrievalConfig.glueSchemaRegistryDeserializer());
this.taskFactory = taskFactory;
}
/**

View file

@ -34,13 +34,19 @@ interface ConsumerState {
* the consumer to use build the task, or execute state.
* @param input
* the process input received, this may be null if it's a control message
* @param taskFactory
* a factory for creating tasks
* @return a valid task for this state or null if there is no task required.
*/
ConsumerTask createTask(ShardConsumerArgument consumerArgument, ShardConsumer consumer, ProcessRecordsInput input);
ConsumerTask createTask(
ShardConsumerArgument consumerArgument,
ShardConsumer consumer,
ProcessRecordsInput input,
ConsumerTaskFactory taskFactory);
/**
* Provides the next state of the consumer upon success of the task return by
* {@link ConsumerState#createTask(ShardConsumerArgument, ShardConsumer, ProcessRecordsInput)}.
* {@link ConsumerState#createTask(ShardConsumerArgument, ShardConsumer, ProcessRecordsInput, ConsumerTaskFactory)}.
*
* @return the next state that the consumer should transition to, this may be the same object as the current
* state.

View file

@ -17,7 +17,6 @@ package software.amazon.kinesis.lifecycle;
import lombok.Getter;
import lombok.experimental.Accessors;
import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput;
import software.amazon.kinesis.retrieval.ThrottlingReporter;
/**
* Top level container for all the possible states a {@link ShardConsumer} can be in. The logic for creation of tasks,
@ -121,11 +120,11 @@ class ConsumerStates {
@Override
public ConsumerTask createTask(
ShardConsumerArgument consumerArgument, ShardConsumer consumer, ProcessRecordsInput input) {
return new BlockOnParentShardTask(
consumerArgument.shardInfo(),
consumerArgument.leaseCoordinator().leaseRefresher(),
consumerArgument.parentShardPollIntervalMillis());
ShardConsumerArgument consumerArgument,
ShardConsumer consumer,
ProcessRecordsInput input,
ConsumerTaskFactory taskFactory) {
return taskFactory.createBlockOnParentTask(consumerArgument);
}
@Override
@ -187,16 +186,11 @@ class ConsumerStates {
@Override
public ConsumerTask createTask(
ShardConsumerArgument argument, ShardConsumer consumer, ProcessRecordsInput input) {
return new InitializeTask(
argument.shardInfo(),
argument.shardRecordProcessor(),
argument.checkpoint(),
argument.recordProcessorCheckpointer(),
argument.initialPositionInStream(),
argument.recordsPublisher(),
argument.taskBackoffTimeMillis(),
argument.metricsFactory());
ShardConsumerArgument argument,
ShardConsumer consumer,
ProcessRecordsInput input,
ConsumerTaskFactory taskFactory) {
return taskFactory.createInitializeTask(argument);
}
@Override
@ -250,24 +244,11 @@ class ConsumerStates {
@Override
public ConsumerTask createTask(
ShardConsumerArgument argument, ShardConsumer consumer, ProcessRecordsInput input) {
ThrottlingReporter throttlingReporter =
new ThrottlingReporter(5, argument.shardInfo().shardId());
return new ProcessTask(
argument.shardInfo(),
argument.shardRecordProcessor(),
argument.recordProcessorCheckpointer(),
argument.taskBackoffTimeMillis(),
argument.skipShardSyncAtWorkerInitializationIfLeasesExist(),
argument.shardDetector(),
throttlingReporter,
input,
argument.shouldCallProcessRecordsEvenForEmptyRecordList(),
argument.idleTimeInMilliseconds(),
argument.aggregatorUtil(),
argument.metricsFactory(),
argument.schemaRegistryDecoder(),
argument.leaseCoordinator().leaseStatsRecorder());
ShardConsumerArgument argument,
ShardConsumer consumer,
ProcessRecordsInput input,
ConsumerTaskFactory taskFactory) {
return taskFactory.createProcessTask(argument, input);
}
@Override
@ -331,14 +312,12 @@ class ConsumerStates {
@Override
public ConsumerTask createTask(
ShardConsumerArgument argument, ShardConsumer consumer, ProcessRecordsInput input) {
ShardConsumerArgument argument,
ShardConsumer consumer,
ProcessRecordsInput input,
ConsumerTaskFactory taskFactory) {
// TODO: notify shutdownrequested
return new ShutdownNotificationTask(
argument.shardRecordProcessor(),
argument.recordProcessorCheckpointer(),
consumer.shutdownNotification(),
argument.shardInfo(),
consumer.shardConsumerArgument().leaseCoordinator());
return taskFactory.createShutdownNotificationTask(argument, consumer);
}
@Override
@ -405,7 +384,10 @@ class ConsumerStates {
@Override
public ConsumerTask createTask(
ShardConsumerArgument argument, ShardConsumer consumer, ProcessRecordsInput input) {
ShardConsumerArgument argument,
ShardConsumer consumer,
ProcessRecordsInput input,
ConsumerTaskFactory taskFactory) {
return null;
}
@ -483,25 +465,12 @@ class ConsumerStates {
@Override
public ConsumerTask createTask(
ShardConsumerArgument argument, ShardConsumer consumer, ProcessRecordsInput input) {
ShardConsumerArgument argument,
ShardConsumer consumer,
ProcessRecordsInput input,
ConsumerTaskFactory taskFactory) {
// TODO: set shutdown reason
return new ShutdownTask(
argument.shardInfo(),
argument.shardDetector(),
argument.shardRecordProcessor(),
argument.recordProcessorCheckpointer(),
consumer.shutdownReason(),
argument.initialPositionInStream(),
argument.cleanupLeasesOfCompletedShards(),
argument.ignoreUnexpectedChildShards(),
argument.leaseCoordinator(),
argument.taskBackoffTimeMillis(),
argument.recordsPublisher(),
argument.hierarchicalShardSyncer(),
argument.metricsFactory(),
input == null ? null : input.childShards(),
argument.streamIdentifier(),
argument.leaseCleanupManager());
return taskFactory.createShutdownTask(argument, consumer, input);
}
@Override
@ -569,7 +538,10 @@ class ConsumerStates {
@Override
public ConsumerTask createTask(
ShardConsumerArgument argument, ShardConsumer consumer, ProcessRecordsInput input) {
ShardConsumerArgument argument,
ShardConsumer consumer,
ProcessRecordsInput input,
ConsumerTaskFactory taskFactory) {
return null;
}

View file

@ -0,0 +1,47 @@
/*
* Copyright 2019 Amazon.com, Inc. or its affiliates.
* Licensed under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package software.amazon.kinesis.lifecycle;
import software.amazon.kinesis.annotations.KinesisClientInternalApi;
import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput;
@KinesisClientInternalApi
public interface ConsumerTaskFactory {
/**
* Creates a shutdown task.
*/
ConsumerTask createShutdownTask(ShardConsumerArgument argument, ShardConsumer consumer, ProcessRecordsInput input);
/**
* Creates a process task.
*/
ConsumerTask createProcessTask(ShardConsumerArgument argument, ProcessRecordsInput processRecordsInput);
/**
* Creates an initialize task.
*/
ConsumerTask createInitializeTask(ShardConsumerArgument argument);
/**
* Creates a block on parent task.
*/
ConsumerTask createBlockOnParentTask(ShardConsumerArgument argument);
/**
* Creates a shutdown notification task.
*/
ConsumerTask createShutdownNotificationTask(ShardConsumerArgument argument, ShardConsumer consumer);
}

View file

@ -0,0 +1,98 @@
/*
* Copyright 2019 Amazon.com, Inc. or its affiliates.
* Licensed under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package software.amazon.kinesis.lifecycle;
import software.amazon.kinesis.annotations.KinesisClientInternalApi;
import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput;
import software.amazon.kinesis.retrieval.ThrottlingReporter;
@KinesisClientInternalApi
public class KinesisConsumerTaskFactory implements ConsumerTaskFactory {
@Override
public ConsumerTask createShutdownTask(
ShardConsumerArgument argument, ShardConsumer consumer, ProcessRecordsInput input) {
return new ShutdownTask(
argument.shardInfo(),
argument.shardDetector(),
argument.shardRecordProcessor(),
argument.recordProcessorCheckpointer(),
consumer.shutdownReason(),
argument.initialPositionInStream(),
argument.cleanupLeasesOfCompletedShards(),
argument.ignoreUnexpectedChildShards(),
argument.leaseCoordinator(),
argument.taskBackoffTimeMillis(),
argument.recordsPublisher(),
argument.hierarchicalShardSyncer(),
argument.metricsFactory(),
input == null ? null : input.childShards(),
argument.streamIdentifier(),
argument.leaseCleanupManager());
}
@Override
public ConsumerTask createProcessTask(ShardConsumerArgument argument, ProcessRecordsInput processRecordsInput) {
ThrottlingReporter throttlingReporter =
new ThrottlingReporter(5, argument.shardInfo().shardId());
return new ProcessTask(
argument.shardInfo(),
argument.shardRecordProcessor(),
argument.recordProcessorCheckpointer(),
argument.taskBackoffTimeMillis(),
argument.skipShardSyncAtWorkerInitializationIfLeasesExist(),
argument.shardDetector(),
throttlingReporter,
processRecordsInput,
argument.shouldCallProcessRecordsEvenForEmptyRecordList(),
argument.idleTimeInMilliseconds(),
argument.aggregatorUtil(),
argument.metricsFactory(),
argument.schemaRegistryDecoder(),
argument.leaseCoordinator().leaseStatsRecorder());
}
@Override
public ConsumerTask createInitializeTask(ShardConsumerArgument argument) {
return new InitializeTask(
argument.shardInfo(),
argument.shardRecordProcessor(),
argument.checkpoint(),
argument.recordProcessorCheckpointer(),
argument.initialPositionInStream(),
argument.recordsPublisher(),
argument.taskBackoffTimeMillis(),
argument.metricsFactory());
}
@Override
public ConsumerTask createBlockOnParentTask(ShardConsumerArgument argument) {
return new BlockOnParentShardTask(
argument.shardInfo(),
argument.leaseCoordinator().leaseRefresher(),
argument.parentShardPollIntervalMillis());
}
@Override
public ConsumerTask createShutdownNotificationTask(ShardConsumerArgument argument, ShardConsumer consumer) {
return new ShutdownNotificationTask(
argument.shardRecordProcessor(),
argument.recordProcessorCheckpointer(),
consumer.shutdownNotification(),
argument.shardInfo(),
consumer.shardConsumerArgument().leaseCoordinator());
}
}

View file

@ -86,6 +86,8 @@ public class ShardConsumer {
private ProcessRecordsInput shardEndProcessRecordsInput;
private final ConsumerTaskFactory taskFactory;
public ShardConsumer(
RecordsPublisher recordsPublisher,
ExecutorService executorService,
@ -103,7 +105,8 @@ public class ShardConsumer {
ConsumerStates.INITIAL_STATE,
8,
taskExecutionListener,
readTimeoutsToIgnoreBeforeWarning);
readTimeoutsToIgnoreBeforeWarning,
new KinesisConsumerTaskFactory());
}
//
@ -119,6 +122,30 @@ public class ShardConsumer {
int bufferSize,
TaskExecutionListener taskExecutionListener,
int readTimeoutsToIgnoreBeforeWarning) {
this(
recordsPublisher,
executorService,
shardInfo,
logWarningForTaskAfterMillis,
shardConsumerArgument,
initialState,
bufferSize,
taskExecutionListener,
readTimeoutsToIgnoreBeforeWarning,
new KinesisConsumerTaskFactory());
}
public ShardConsumer(
RecordsPublisher recordsPublisher,
ExecutorService executorService,
ShardInfo shardInfo,
Optional<Long> logWarningForTaskAfterMillis,
ShardConsumerArgument shardConsumerArgument,
ConsumerState initialState,
int bufferSize,
TaskExecutionListener taskExecutionListener,
int readTimeoutsToIgnoreBeforeWarning,
ConsumerTaskFactory taskFactory) {
this.recordsPublisher = recordsPublisher;
this.executorService = executorService;
this.shardInfo = shardInfo;
@ -134,6 +161,7 @@ public class ShardConsumer {
if (this.shardInfo.isCompleted()) {
markForShutdown(ShutdownReason.SHARD_END);
}
this.taskFactory = taskFactory;
}
synchronized void handleInput(ProcessRecordsInput input, Subscription subscription) {
@ -345,7 +373,7 @@ public class ShardConsumer {
.taskType(currentState.taskType())
.build();
taskExecutionListener.beforeTaskExecution(taskExecutionListenerInput);
ConsumerTask task = currentState.createTask(shardConsumerArgument, ShardConsumer.this, input);
ConsumerTask task = currentState.createTask(shardConsumerArgument, ShardConsumer.this, input, taskFactory);
if (task != null) {
taskDispatchedAt = Instant.now();
currentTask = task;

View file

@ -127,6 +127,7 @@ public class ConsumerStatesTest {
private long idleTimeInMillis = 1000L;
private Optional<Long> logWarningForTaskAfterMillis = Optional.empty();
private SchemaRegistryDecoder schemaRegistryDecoder = null;
private final ConsumerTaskFactory taskFactory = new KinesisConsumerTaskFactory();
@Before
public void setup() {
@ -177,7 +178,7 @@ public class ConsumerStatesTest {
ConsumerState state = ShardConsumerState.WAITING_ON_PARENT_SHARDS.consumerState();
when(leaseCoordinator.leaseRefresher()).thenReturn(leaseRefresher);
ConsumerTask task = state.createTask(argument, consumer, null);
ConsumerTask task = state.createTask(argument, consumer, null, taskFactory);
assertThat(task, taskWith(BlockOnParentShardTask.class, ShardInfo.class, "shardInfo", equalTo(shardInfo)));
assertThat(
@ -209,7 +210,7 @@ public class ConsumerStatesTest {
@Test
public void initializingStateTest() {
ConsumerState state = ShardConsumerState.INITIALIZING.consumerState();
ConsumerTask task = state.createTask(argument, consumer, null);
ConsumerTask task = state.createTask(argument, consumer, null, taskFactory);
assertThat(task, initTask(ShardInfo.class, "shardInfo", equalTo(shardInfo)));
assertThat(task, initTask(ShardRecordProcessor.class, "shardRecordProcessor", equalTo(shardRecordProcessor)));
@ -242,7 +243,7 @@ public class ConsumerStatesTest {
public void processingStateTestSynchronous() {
when(leaseCoordinator.leaseStatsRecorder()).thenReturn(leaseStatsRecorder);
ConsumerState state = ShardConsumerState.PROCESSING.consumerState();
ConsumerTask task = state.createTask(argument, consumer, null);
ConsumerTask task = state.createTask(argument, consumer, null, taskFactory);
assertThat(task, procTask(ShardInfo.class, "shardInfo", equalTo(shardInfo)));
assertThat(task, procTask(ShardRecordProcessor.class, "shardRecordProcessor", equalTo(shardRecordProcessor)));
@ -274,7 +275,7 @@ public class ConsumerStatesTest {
public void processingStateTestAsynchronous() {
when(leaseCoordinator.leaseStatsRecorder()).thenReturn(leaseStatsRecorder);
ConsumerState state = ShardConsumerState.PROCESSING.consumerState();
ConsumerTask task = state.createTask(argument, consumer, null);
ConsumerTask task = state.createTask(argument, consumer, null, taskFactory);
assertThat(task, procTask(ShardInfo.class, "shardInfo", equalTo(shardInfo)));
assertThat(task, procTask(ShardRecordProcessor.class, "shardRecordProcessor", equalTo(shardRecordProcessor)));
@ -306,7 +307,7 @@ public class ConsumerStatesTest {
public void processingStateRecordsFetcher() {
when(leaseCoordinator.leaseStatsRecorder()).thenReturn(leaseStatsRecorder);
ConsumerState state = ShardConsumerState.PROCESSING.consumerState();
ConsumerTask task = state.createTask(argument, consumer, null);
ConsumerTask task = state.createTask(argument, consumer, null, taskFactory);
assertThat(task, procTask(ShardInfo.class, "shardInfo", equalTo(shardInfo)));
assertThat(task, procTask(ShardRecordProcessor.class, "shardRecordProcessor", equalTo(shardRecordProcessor)));
@ -339,7 +340,7 @@ public class ConsumerStatesTest {
ConsumerState state = ShardConsumerState.SHUTDOWN_REQUESTED.consumerState();
consumer.gracefulShutdown(shutdownNotification);
ConsumerTask task = state.createTask(argument, consumer, null);
ConsumerTask task = state.createTask(argument, consumer, null, taskFactory);
assertThat(
task,
@ -373,7 +374,7 @@ public class ConsumerStatesTest {
public void shutdownRequestCompleteStateTest() {
ConsumerState state = ConsumerStates.SHUTDOWN_REQUEST_COMPLETION_STATE;
assertThat(state.createTask(argument, consumer, null), nullValue());
assertThat(state.createTask(argument, consumer, null, taskFactory), nullValue());
assertThat(state.successTransition(), equalTo(state));
@ -409,7 +410,7 @@ public class ConsumerStatesTest {
childShards.add(leftChild);
childShards.add(rightChild);
when(processRecordsInput.childShards()).thenReturn(childShards);
ConsumerTask task = state.createTask(argument, consumer, processRecordsInput);
ConsumerTask task = state.createTask(argument, consumer, processRecordsInput, taskFactory);
assertThat(task, shutdownTask(ShardInfo.class, "shardInfo", equalTo(shardInfo)));
assertThat(
@ -443,7 +444,7 @@ public class ConsumerStatesTest {
ConsumerState state = ShardConsumerState.SHUTDOWN_COMPLETE.consumerState();
assertThat(state.createTask(argument, consumer, null), nullValue());
assertThat(state.createTask(argument, consumer, null, taskFactory), nullValue());
assertThat(state.successTransition(), equalTo(state));
for (ShutdownReason reason : ShutdownReason.values()) {

View file

@ -332,7 +332,7 @@ public class ShardConsumerTest {
verify(cache.subscription, times(3)).request(anyLong());
verify(cache.subscription).cancel();
verify(processingState, times(2)).createTask(eq(shardConsumerArgument), eq(consumer), any());
verify(processingState, times(2)).createTask(eq(shardConsumerArgument), eq(consumer), any(), any());
verify(taskExecutionListener, times(1)).beforeTaskExecution(initialTaskInput);
verify(taskExecutionListener, times(2)).beforeTaskExecution(processTaskInput);
verify(taskExecutionListener, times(1)).beforeTaskExecution(shutdownTaskInput);
@ -394,7 +394,7 @@ public class ShardConsumerTest {
verify(cache.subscription, times(1)).request(anyLong());
verify(cache.subscription).cancel();
verify(processingState, times(1)).createTask(eq(shardConsumerArgument), eq(consumer), any());
verify(processingState, times(1)).createTask(eq(shardConsumerArgument), eq(consumer), any(), any());
verify(taskExecutionListener, times(1)).beforeTaskExecution(initialTaskInput);
verify(taskExecutionListener, times(1)).beforeTaskExecution(processTaskInput);
verify(taskExecutionListener, times(1)).beforeTaskExecution(shutdownTaskInput);
@ -437,14 +437,14 @@ public class ShardConsumerTest {
cache.publish();
awaitAndResetBarrier(taskCallBarrier);
verify(processingState).createTask(any(), any(), any());
verify(processingState).createTask(any(), any(), any(), any());
verify(processingTask).call();
cache.awaitRequest();
cache.publish();
awaitAndResetBarrier(taskCallBarrier);
verify(processingState, times(2)).createTask(any(), any(), any());
verify(processingState, times(2)).createTask(any(), any(), any(), any());
verify(processingTask, times(2)).call();
cache.awaitRequest();
@ -460,7 +460,7 @@ public class ShardConsumerTest {
shutdownComplete = consumer.shutdownComplete().get();
} while (!shutdownComplete);
verify(processingState, times(3)).createTask(any(), any(), any());
verify(processingState, times(3)).createTask(any(), any(), any(), any());
verify(processingTask, times(3)).call();
verify(processingState).shutdownTransition(eq(ShutdownReason.LEASE_LOST));
verify(shutdownState).shutdownTransition(eq(ShutdownReason.LEASE_LOST));
@ -487,7 +487,7 @@ public class ShardConsumerTest {
public final void testInitializationStateUponFailure() throws Exception {
final ShardConsumer consumer = createShardConsumer(recordsPublisher);
when(initialState.createTask(eq(shardConsumerArgument), eq(consumer), any()))
when(initialState.createTask(eq(shardConsumerArgument), eq(consumer), any(), any()))
.thenReturn(initializeTask);
when(initializeTask.call()).thenReturn(new TaskResult(new Exception("Bad")));
when(initializeTask.taskType()).thenReturn(TaskType.INITIALIZE);
@ -505,7 +505,7 @@ public class ShardConsumerTest {
awaitAndResetBarrier(taskBarrier);
}
verify(initialState, times(5)).createTask(eq(shardConsumerArgument), eq(consumer), any());
verify(initialState, times(5)).createTask(eq(shardConsumerArgument), eq(consumer), any(), any());
verify(initialState, never()).successTransition();
verify(initialState, never()).shutdownTransition(any());
}
@ -665,7 +665,7 @@ public class ShardConsumerTest {
public void testErrorThrowableInInitialization() throws Exception {
final ShardConsumer consumer = createShardConsumer(recordsPublisher);
when(initialState.createTask(any(), any(), any())).thenReturn(initializeTask);
when(initialState.createTask(any(), any(), any(), any())).thenReturn(initializeTask);
when(initialState.taskType()).thenReturn(TaskType.INITIALIZE);
when(initializeTask.call()).thenAnswer(i -> {
throw new Error("Error");
@ -692,13 +692,13 @@ public class ShardConsumerTest {
mockSuccessfulProcessing(taskBarrier);
when(processingState.shutdownTransition(eq(ShutdownReason.REQUESTED))).thenReturn(shutdownRequestedState);
when(shutdownRequestedState.createTask(any(), any(), any())).thenReturn(shutdownRequestedTask);
when(shutdownRequestedState.createTask(any(), any(), any(), any())).thenReturn(shutdownRequestedTask);
when(shutdownRequestedState.taskType()).thenReturn(TaskType.SHUTDOWN_NOTIFICATION);
when(shutdownRequestedTask.call()).thenReturn(new TaskResult(null));
when(shutdownRequestedState.shutdownTransition(eq(ShutdownReason.REQUESTED)))
.thenReturn(shutdownRequestedAwaitState);
when(shutdownRequestedAwaitState.createTask(any(), any(), any())).thenReturn(null);
when(shutdownRequestedAwaitState.createTask(any(), any(), any(), any())).thenReturn(null);
when(shutdownRequestedAwaitState.shutdownTransition(eq(ShutdownReason.LEASE_LOST)))
.thenReturn(shutdownState);
when(shutdownRequestedAwaitState.taskType()).thenReturn(TaskType.SHUTDOWN_COMPLETE);
@ -733,11 +733,11 @@ public class ShardConsumerTest {
shutdownComplete = consumer.shutdownComplete().get();
assertTrue(shutdownComplete);
verify(processingState, times(2)).createTask(any(), any(), any());
verify(processingState, times(2)).createTask(any(), any(), any(), any());
verify(shutdownRequestedState, never()).shutdownTransition(eq(ShutdownReason.LEASE_LOST));
verify(shutdownRequestedState).createTask(any(), any(), any());
verify(shutdownRequestedState).createTask(any(), any(), any(), any());
verify(shutdownRequestedState).shutdownTransition(eq(ShutdownReason.REQUESTED));
verify(shutdownRequestedAwaitState).createTask(any(), any(), any());
verify(shutdownRequestedAwaitState).createTask(any(), any(), any(), any());
verify(shutdownRequestedAwaitState).shutdownTransition(eq(ShutdownReason.LEASE_LOST));
verify(taskExecutionListener, times(1)).beforeTaskExecution(initialTaskInput);
verify(taskExecutionListener, times(2)).beforeTaskExecution(processTaskInput);
@ -948,7 +948,7 @@ public class ShardConsumerTest {
when(mockState.state()).thenReturn(ShardConsumerState.WAITING_ON_PARENT_SHARDS);
when(mockState.taskType()).thenReturn(TaskType.BLOCK_ON_PARENT_SHARDS);
final ConsumerTask mockTask = mock(ConsumerTask.class);
when(mockState.createTask(any(), any(), any())).thenReturn(mockTask);
when(mockState.createTask(any(), any(), any(), any())).thenReturn(mockTask);
// Simulate successful BlockedOnParent task execution
// and successful Initialize task execution
when(mockTask.call()).thenReturn(new TaskResult(false));
@ -993,7 +993,7 @@ public class ShardConsumerTest {
reset(mockState);
when(mockState.taskType()).thenReturn(TaskType.PROCESS);
final ConsumerTask mockProcessTask = mock(ConsumerTask.class);
when(mockState.createTask(any(), any(), any())).thenReturn(mockProcessTask);
when(mockState.createTask(any(), any(), any(), any())).thenReturn(mockProcessTask);
when(mockProcessTask.call()).then(input -> {
// first we want to wait for subscribe to be called,
// but we cannot control the timing, so wait for 10 seconds
@ -1045,7 +1045,8 @@ public class ShardConsumerTest {
}
private void mockSuccessfulShutdown(CyclicBarrier taskArriveBarrier, CyclicBarrier taskDepartBarrier) {
when(shutdownState.createTask(eq(shardConsumerArgument), any(), any())).thenReturn(shutdownTask);
when(shutdownState.createTask(eq(shardConsumerArgument), any(), any(), any()))
.thenReturn(shutdownTask);
when(shutdownState.taskType()).thenReturn(TaskType.SHUTDOWN);
when(shutdownTask.call()).thenAnswer(i -> {
awaitBarrier(taskArriveBarrier);
@ -1063,7 +1064,7 @@ public class ShardConsumerTest {
}
private void mockSuccessfulProcessing(CyclicBarrier taskCallBarrier, CyclicBarrier taskInterlockBarrier) {
when(processingState.createTask(eq(shardConsumerArgument), any(), any()))
when(processingState.createTask(eq(shardConsumerArgument), any(), any(), any()))
.thenReturn(processingTask);
when(processingState.taskType()).thenReturn(TaskType.PROCESS);
when(processingTask.taskType()).thenReturn(TaskType.PROCESS);
@ -1088,7 +1089,8 @@ public class ShardConsumerTest {
}
private void mockSuccessfulInitialize(CyclicBarrier taskCallBarrier, CyclicBarrier taskInterlockBarrier) {
when(initialState.createTask(eq(shardConsumerArgument), any(), any())).thenReturn(initializeTask);
when(initialState.createTask(eq(shardConsumerArgument), any(), any(), any()))
.thenReturn(initializeTask);
when(initialState.taskType()).thenReturn(TaskType.INITIALIZE);
when(initializeTask.taskType()).thenReturn(TaskType.INITIALIZE);
when(initializeTask.call()).thenAnswer(i -> {
@ -1107,7 +1109,7 @@ public class ShardConsumerTest {
}
private void mockSuccessfulUnblockOnParents() {
when(blockedOnParentsState.createTask(eq(shardConsumerArgument), any(), any()))
when(blockedOnParentsState.createTask(eq(shardConsumerArgument), any(), any(), any()))
.thenReturn(blockedOnParentsTask);
when(blockedOnParentsState.taskType()).thenReturn(TaskType.BLOCK_ON_PARENT_SHARDS);
when(blockedOnParentsTask.call()).thenAnswer(i -> blockOnParentsTaskResult);