Adding functionality to override the vanilla KCL tasks (#1440)
This commit is contained in:
parent
68a7a9bf53
commit
8deebe4bda
8 changed files with 278 additions and 93 deletions
|
|
@ -87,6 +87,8 @@ import software.amazon.kinesis.leases.dynamodb.DynamoDBMultiStreamLeaseSerialize
|
|||
import software.amazon.kinesis.leases.exceptions.DependencyException;
|
||||
import software.amazon.kinesis.leases.exceptions.InvalidStateException;
|
||||
import software.amazon.kinesis.leases.exceptions.ProvisionedThroughputException;
|
||||
import software.amazon.kinesis.lifecycle.ConsumerTaskFactory;
|
||||
import software.amazon.kinesis.lifecycle.KinesisConsumerTaskFactory;
|
||||
import software.amazon.kinesis.lifecycle.LifecycleConfig;
|
||||
import software.amazon.kinesis.lifecycle.ShardConsumer;
|
||||
import software.amazon.kinesis.lifecycle.ShardConsumerArgument;
|
||||
|
|
@ -188,6 +190,7 @@ public class Scheduler implements Runnable {
|
|||
private final SchemaRegistryDecoder schemaRegistryDecoder;
|
||||
|
||||
private final DeletedStreamListProvider deletedStreamListProvider;
|
||||
private final ConsumerTaskFactory taskFactory;
|
||||
|
||||
@Getter(AccessLevel.NONE)
|
||||
private final MigrationStateMachine migrationStateMachine;
|
||||
|
|
@ -264,6 +267,33 @@ public class Scheduler implements Runnable {
|
|||
@NonNull final ProcessorConfig processorConfig,
|
||||
@NonNull final RetrievalConfig retrievalConfig,
|
||||
@NonNull final DiagnosticEventFactory diagnosticEventFactory) {
|
||||
this(
|
||||
checkpointConfig,
|
||||
coordinatorConfig,
|
||||
leaseManagementConfig,
|
||||
lifecycleConfig,
|
||||
metricsConfig,
|
||||
processorConfig,
|
||||
retrievalConfig,
|
||||
diagnosticEventFactory,
|
||||
new KinesisConsumerTaskFactory());
|
||||
}
|
||||
|
||||
/**
|
||||
* Customers do not currently have the ability to customize the DiagnosticEventFactory, but this visibility
|
||||
* is desired for testing. This constructor is only used for testing to provide a mock DiagnosticEventFactory.
|
||||
*/
|
||||
@VisibleForTesting
|
||||
protected Scheduler(
|
||||
@NonNull final CheckpointConfig checkpointConfig,
|
||||
@NonNull final CoordinatorConfig coordinatorConfig,
|
||||
@NonNull final LeaseManagementConfig leaseManagementConfig,
|
||||
@NonNull final LifecycleConfig lifecycleConfig,
|
||||
@NonNull final MetricsConfig metricsConfig,
|
||||
@NonNull final ProcessorConfig processorConfig,
|
||||
@NonNull final RetrievalConfig retrievalConfig,
|
||||
@NonNull final DiagnosticEventFactory diagnosticEventFactory,
|
||||
@NonNull final ConsumerTaskFactory taskFactory) {
|
||||
this.checkpointConfig = checkpointConfig;
|
||||
this.coordinatorConfig = coordinatorConfig;
|
||||
this.leaseManagementConfig = leaseManagementConfig;
|
||||
|
|
@ -371,6 +401,7 @@ public class Scheduler implements Runnable {
|
|||
this.schemaRegistryDecoder = this.retrievalConfig.glueSchemaRegistryDeserializer() == null
|
||||
? null
|
||||
: new SchemaRegistryDecoder(this.retrievalConfig.glueSchemaRegistryDeserializer());
|
||||
this.taskFactory = taskFactory;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -34,13 +34,19 @@ interface ConsumerState {
|
|||
* the consumer to use build the task, or execute state.
|
||||
* @param input
|
||||
* the process input received, this may be null if it's a control message
|
||||
* @param taskFactory
|
||||
* a factory for creating tasks
|
||||
* @return a valid task for this state or null if there is no task required.
|
||||
*/
|
||||
ConsumerTask createTask(ShardConsumerArgument consumerArgument, ShardConsumer consumer, ProcessRecordsInput input);
|
||||
ConsumerTask createTask(
|
||||
ShardConsumerArgument consumerArgument,
|
||||
ShardConsumer consumer,
|
||||
ProcessRecordsInput input,
|
||||
ConsumerTaskFactory taskFactory);
|
||||
|
||||
/**
|
||||
* Provides the next state of the consumer upon success of the task return by
|
||||
* {@link ConsumerState#createTask(ShardConsumerArgument, ShardConsumer, ProcessRecordsInput)}.
|
||||
* {@link ConsumerState#createTask(ShardConsumerArgument, ShardConsumer, ProcessRecordsInput, ConsumerTaskFactory)}.
|
||||
*
|
||||
* @return the next state that the consumer should transition to, this may be the same object as the current
|
||||
* state.
|
||||
|
|
|
|||
|
|
@ -17,7 +17,6 @@ package software.amazon.kinesis.lifecycle;
|
|||
import lombok.Getter;
|
||||
import lombok.experimental.Accessors;
|
||||
import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput;
|
||||
import software.amazon.kinesis.retrieval.ThrottlingReporter;
|
||||
|
||||
/**
|
||||
* Top level container for all the possible states a {@link ShardConsumer} can be in. The logic for creation of tasks,
|
||||
|
|
@ -121,11 +120,11 @@ class ConsumerStates {
|
|||
|
||||
@Override
|
||||
public ConsumerTask createTask(
|
||||
ShardConsumerArgument consumerArgument, ShardConsumer consumer, ProcessRecordsInput input) {
|
||||
return new BlockOnParentShardTask(
|
||||
consumerArgument.shardInfo(),
|
||||
consumerArgument.leaseCoordinator().leaseRefresher(),
|
||||
consumerArgument.parentShardPollIntervalMillis());
|
||||
ShardConsumerArgument consumerArgument,
|
||||
ShardConsumer consumer,
|
||||
ProcessRecordsInput input,
|
||||
ConsumerTaskFactory taskFactory) {
|
||||
return taskFactory.createBlockOnParentTask(consumerArgument);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
@ -187,16 +186,11 @@ class ConsumerStates {
|
|||
|
||||
@Override
|
||||
public ConsumerTask createTask(
|
||||
ShardConsumerArgument argument, ShardConsumer consumer, ProcessRecordsInput input) {
|
||||
return new InitializeTask(
|
||||
argument.shardInfo(),
|
||||
argument.shardRecordProcessor(),
|
||||
argument.checkpoint(),
|
||||
argument.recordProcessorCheckpointer(),
|
||||
argument.initialPositionInStream(),
|
||||
argument.recordsPublisher(),
|
||||
argument.taskBackoffTimeMillis(),
|
||||
argument.metricsFactory());
|
||||
ShardConsumerArgument argument,
|
||||
ShardConsumer consumer,
|
||||
ProcessRecordsInput input,
|
||||
ConsumerTaskFactory taskFactory) {
|
||||
return taskFactory.createInitializeTask(argument);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
@ -250,24 +244,11 @@ class ConsumerStates {
|
|||
|
||||
@Override
|
||||
public ConsumerTask createTask(
|
||||
ShardConsumerArgument argument, ShardConsumer consumer, ProcessRecordsInput input) {
|
||||
ThrottlingReporter throttlingReporter =
|
||||
new ThrottlingReporter(5, argument.shardInfo().shardId());
|
||||
return new ProcessTask(
|
||||
argument.shardInfo(),
|
||||
argument.shardRecordProcessor(),
|
||||
argument.recordProcessorCheckpointer(),
|
||||
argument.taskBackoffTimeMillis(),
|
||||
argument.skipShardSyncAtWorkerInitializationIfLeasesExist(),
|
||||
argument.shardDetector(),
|
||||
throttlingReporter,
|
||||
input,
|
||||
argument.shouldCallProcessRecordsEvenForEmptyRecordList(),
|
||||
argument.idleTimeInMilliseconds(),
|
||||
argument.aggregatorUtil(),
|
||||
argument.metricsFactory(),
|
||||
argument.schemaRegistryDecoder(),
|
||||
argument.leaseCoordinator().leaseStatsRecorder());
|
||||
ShardConsumerArgument argument,
|
||||
ShardConsumer consumer,
|
||||
ProcessRecordsInput input,
|
||||
ConsumerTaskFactory taskFactory) {
|
||||
return taskFactory.createProcessTask(argument, input);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
@ -331,14 +312,12 @@ class ConsumerStates {
|
|||
|
||||
@Override
|
||||
public ConsumerTask createTask(
|
||||
ShardConsumerArgument argument, ShardConsumer consumer, ProcessRecordsInput input) {
|
||||
ShardConsumerArgument argument,
|
||||
ShardConsumer consumer,
|
||||
ProcessRecordsInput input,
|
||||
ConsumerTaskFactory taskFactory) {
|
||||
// TODO: notify shutdownrequested
|
||||
return new ShutdownNotificationTask(
|
||||
argument.shardRecordProcessor(),
|
||||
argument.recordProcessorCheckpointer(),
|
||||
consumer.shutdownNotification(),
|
||||
argument.shardInfo(),
|
||||
consumer.shardConsumerArgument().leaseCoordinator());
|
||||
return taskFactory.createShutdownNotificationTask(argument, consumer);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
@ -405,7 +384,10 @@ class ConsumerStates {
|
|||
|
||||
@Override
|
||||
public ConsumerTask createTask(
|
||||
ShardConsumerArgument argument, ShardConsumer consumer, ProcessRecordsInput input) {
|
||||
ShardConsumerArgument argument,
|
||||
ShardConsumer consumer,
|
||||
ProcessRecordsInput input,
|
||||
ConsumerTaskFactory taskFactory) {
|
||||
return null;
|
||||
}
|
||||
|
||||
|
|
@ -483,25 +465,12 @@ class ConsumerStates {
|
|||
|
||||
@Override
|
||||
public ConsumerTask createTask(
|
||||
ShardConsumerArgument argument, ShardConsumer consumer, ProcessRecordsInput input) {
|
||||
ShardConsumerArgument argument,
|
||||
ShardConsumer consumer,
|
||||
ProcessRecordsInput input,
|
||||
ConsumerTaskFactory taskFactory) {
|
||||
// TODO: set shutdown reason
|
||||
return new ShutdownTask(
|
||||
argument.shardInfo(),
|
||||
argument.shardDetector(),
|
||||
argument.shardRecordProcessor(),
|
||||
argument.recordProcessorCheckpointer(),
|
||||
consumer.shutdownReason(),
|
||||
argument.initialPositionInStream(),
|
||||
argument.cleanupLeasesOfCompletedShards(),
|
||||
argument.ignoreUnexpectedChildShards(),
|
||||
argument.leaseCoordinator(),
|
||||
argument.taskBackoffTimeMillis(),
|
||||
argument.recordsPublisher(),
|
||||
argument.hierarchicalShardSyncer(),
|
||||
argument.metricsFactory(),
|
||||
input == null ? null : input.childShards(),
|
||||
argument.streamIdentifier(),
|
||||
argument.leaseCleanupManager());
|
||||
return taskFactory.createShutdownTask(argument, consumer, input);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
@ -569,7 +538,10 @@ class ConsumerStates {
|
|||
|
||||
@Override
|
||||
public ConsumerTask createTask(
|
||||
ShardConsumerArgument argument, ShardConsumer consumer, ProcessRecordsInput input) {
|
||||
ShardConsumerArgument argument,
|
||||
ShardConsumer consumer,
|
||||
ProcessRecordsInput input,
|
||||
ConsumerTaskFactory taskFactory) {
|
||||
return null;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,47 @@
|
|||
/*
|
||||
* Copyright 2019 Amazon.com, Inc. or its affiliates.
|
||||
* Licensed under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package software.amazon.kinesis.lifecycle;
|
||||
|
||||
import software.amazon.kinesis.annotations.KinesisClientInternalApi;
|
||||
import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput;
|
||||
|
||||
@KinesisClientInternalApi
|
||||
public interface ConsumerTaskFactory {
|
||||
/**
|
||||
* Creates a shutdown task.
|
||||
*/
|
||||
ConsumerTask createShutdownTask(ShardConsumerArgument argument, ShardConsumer consumer, ProcessRecordsInput input);
|
||||
|
||||
/**
|
||||
* Creates a process task.
|
||||
*/
|
||||
ConsumerTask createProcessTask(ShardConsumerArgument argument, ProcessRecordsInput processRecordsInput);
|
||||
|
||||
/**
|
||||
* Creates an initialize task.
|
||||
*/
|
||||
ConsumerTask createInitializeTask(ShardConsumerArgument argument);
|
||||
|
||||
/**
|
||||
* Creates a block on parent task.
|
||||
*/
|
||||
ConsumerTask createBlockOnParentTask(ShardConsumerArgument argument);
|
||||
|
||||
/**
|
||||
* Creates a shutdown notification task.
|
||||
*/
|
||||
ConsumerTask createShutdownNotificationTask(ShardConsumerArgument argument, ShardConsumer consumer);
|
||||
}
|
||||
|
|
@ -0,0 +1,98 @@
|
|||
/*
|
||||
* Copyright 2019 Amazon.com, Inc. or its affiliates.
|
||||
* Licensed under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package software.amazon.kinesis.lifecycle;
|
||||
|
||||
import software.amazon.kinesis.annotations.KinesisClientInternalApi;
|
||||
import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput;
|
||||
import software.amazon.kinesis.retrieval.ThrottlingReporter;
|
||||
|
||||
@KinesisClientInternalApi
|
||||
public class KinesisConsumerTaskFactory implements ConsumerTaskFactory {
|
||||
|
||||
@Override
|
||||
public ConsumerTask createShutdownTask(
|
||||
ShardConsumerArgument argument, ShardConsumer consumer, ProcessRecordsInput input) {
|
||||
return new ShutdownTask(
|
||||
argument.shardInfo(),
|
||||
argument.shardDetector(),
|
||||
argument.shardRecordProcessor(),
|
||||
argument.recordProcessorCheckpointer(),
|
||||
consumer.shutdownReason(),
|
||||
argument.initialPositionInStream(),
|
||||
argument.cleanupLeasesOfCompletedShards(),
|
||||
argument.ignoreUnexpectedChildShards(),
|
||||
argument.leaseCoordinator(),
|
||||
argument.taskBackoffTimeMillis(),
|
||||
argument.recordsPublisher(),
|
||||
argument.hierarchicalShardSyncer(),
|
||||
argument.metricsFactory(),
|
||||
input == null ? null : input.childShards(),
|
||||
argument.streamIdentifier(),
|
||||
argument.leaseCleanupManager());
|
||||
}
|
||||
|
||||
@Override
|
||||
public ConsumerTask createProcessTask(ShardConsumerArgument argument, ProcessRecordsInput processRecordsInput) {
|
||||
ThrottlingReporter throttlingReporter =
|
||||
new ThrottlingReporter(5, argument.shardInfo().shardId());
|
||||
return new ProcessTask(
|
||||
argument.shardInfo(),
|
||||
argument.shardRecordProcessor(),
|
||||
argument.recordProcessorCheckpointer(),
|
||||
argument.taskBackoffTimeMillis(),
|
||||
argument.skipShardSyncAtWorkerInitializationIfLeasesExist(),
|
||||
argument.shardDetector(),
|
||||
throttlingReporter,
|
||||
processRecordsInput,
|
||||
argument.shouldCallProcessRecordsEvenForEmptyRecordList(),
|
||||
argument.idleTimeInMilliseconds(),
|
||||
argument.aggregatorUtil(),
|
||||
argument.metricsFactory(),
|
||||
argument.schemaRegistryDecoder(),
|
||||
argument.leaseCoordinator().leaseStatsRecorder());
|
||||
}
|
||||
|
||||
@Override
|
||||
public ConsumerTask createInitializeTask(ShardConsumerArgument argument) {
|
||||
return new InitializeTask(
|
||||
argument.shardInfo(),
|
||||
argument.shardRecordProcessor(),
|
||||
argument.checkpoint(),
|
||||
argument.recordProcessorCheckpointer(),
|
||||
argument.initialPositionInStream(),
|
||||
argument.recordsPublisher(),
|
||||
argument.taskBackoffTimeMillis(),
|
||||
argument.metricsFactory());
|
||||
}
|
||||
|
||||
@Override
|
||||
public ConsumerTask createBlockOnParentTask(ShardConsumerArgument argument) {
|
||||
return new BlockOnParentShardTask(
|
||||
argument.shardInfo(),
|
||||
argument.leaseCoordinator().leaseRefresher(),
|
||||
argument.parentShardPollIntervalMillis());
|
||||
}
|
||||
|
||||
@Override
|
||||
public ConsumerTask createShutdownNotificationTask(ShardConsumerArgument argument, ShardConsumer consumer) {
|
||||
return new ShutdownNotificationTask(
|
||||
argument.shardRecordProcessor(),
|
||||
argument.recordProcessorCheckpointer(),
|
||||
consumer.shutdownNotification(),
|
||||
argument.shardInfo(),
|
||||
consumer.shardConsumerArgument().leaseCoordinator());
|
||||
}
|
||||
}
|
||||
|
|
@ -86,6 +86,8 @@ public class ShardConsumer {
|
|||
|
||||
private ProcessRecordsInput shardEndProcessRecordsInput;
|
||||
|
||||
private final ConsumerTaskFactory taskFactory;
|
||||
|
||||
public ShardConsumer(
|
||||
RecordsPublisher recordsPublisher,
|
||||
ExecutorService executorService,
|
||||
|
|
@ -103,7 +105,8 @@ public class ShardConsumer {
|
|||
ConsumerStates.INITIAL_STATE,
|
||||
8,
|
||||
taskExecutionListener,
|
||||
readTimeoutsToIgnoreBeforeWarning);
|
||||
readTimeoutsToIgnoreBeforeWarning,
|
||||
new KinesisConsumerTaskFactory());
|
||||
}
|
||||
|
||||
//
|
||||
|
|
@ -119,6 +122,30 @@ public class ShardConsumer {
|
|||
int bufferSize,
|
||||
TaskExecutionListener taskExecutionListener,
|
||||
int readTimeoutsToIgnoreBeforeWarning) {
|
||||
this(
|
||||
recordsPublisher,
|
||||
executorService,
|
||||
shardInfo,
|
||||
logWarningForTaskAfterMillis,
|
||||
shardConsumerArgument,
|
||||
initialState,
|
||||
bufferSize,
|
||||
taskExecutionListener,
|
||||
readTimeoutsToIgnoreBeforeWarning,
|
||||
new KinesisConsumerTaskFactory());
|
||||
}
|
||||
|
||||
public ShardConsumer(
|
||||
RecordsPublisher recordsPublisher,
|
||||
ExecutorService executorService,
|
||||
ShardInfo shardInfo,
|
||||
Optional<Long> logWarningForTaskAfterMillis,
|
||||
ShardConsumerArgument shardConsumerArgument,
|
||||
ConsumerState initialState,
|
||||
int bufferSize,
|
||||
TaskExecutionListener taskExecutionListener,
|
||||
int readTimeoutsToIgnoreBeforeWarning,
|
||||
ConsumerTaskFactory taskFactory) {
|
||||
this.recordsPublisher = recordsPublisher;
|
||||
this.executorService = executorService;
|
||||
this.shardInfo = shardInfo;
|
||||
|
|
@ -134,6 +161,7 @@ public class ShardConsumer {
|
|||
if (this.shardInfo.isCompleted()) {
|
||||
markForShutdown(ShutdownReason.SHARD_END);
|
||||
}
|
||||
this.taskFactory = taskFactory;
|
||||
}
|
||||
|
||||
synchronized void handleInput(ProcessRecordsInput input, Subscription subscription) {
|
||||
|
|
@ -345,7 +373,7 @@ public class ShardConsumer {
|
|||
.taskType(currentState.taskType())
|
||||
.build();
|
||||
taskExecutionListener.beforeTaskExecution(taskExecutionListenerInput);
|
||||
ConsumerTask task = currentState.createTask(shardConsumerArgument, ShardConsumer.this, input);
|
||||
ConsumerTask task = currentState.createTask(shardConsumerArgument, ShardConsumer.this, input, taskFactory);
|
||||
if (task != null) {
|
||||
taskDispatchedAt = Instant.now();
|
||||
currentTask = task;
|
||||
|
|
|
|||
|
|
@ -127,6 +127,7 @@ public class ConsumerStatesTest {
|
|||
private long idleTimeInMillis = 1000L;
|
||||
private Optional<Long> logWarningForTaskAfterMillis = Optional.empty();
|
||||
private SchemaRegistryDecoder schemaRegistryDecoder = null;
|
||||
private final ConsumerTaskFactory taskFactory = new KinesisConsumerTaskFactory();
|
||||
|
||||
@Before
|
||||
public void setup() {
|
||||
|
|
@ -177,7 +178,7 @@ public class ConsumerStatesTest {
|
|||
ConsumerState state = ShardConsumerState.WAITING_ON_PARENT_SHARDS.consumerState();
|
||||
when(leaseCoordinator.leaseRefresher()).thenReturn(leaseRefresher);
|
||||
|
||||
ConsumerTask task = state.createTask(argument, consumer, null);
|
||||
ConsumerTask task = state.createTask(argument, consumer, null, taskFactory);
|
||||
|
||||
assertThat(task, taskWith(BlockOnParentShardTask.class, ShardInfo.class, "shardInfo", equalTo(shardInfo)));
|
||||
assertThat(
|
||||
|
|
@ -209,7 +210,7 @@ public class ConsumerStatesTest {
|
|||
@Test
|
||||
public void initializingStateTest() {
|
||||
ConsumerState state = ShardConsumerState.INITIALIZING.consumerState();
|
||||
ConsumerTask task = state.createTask(argument, consumer, null);
|
||||
ConsumerTask task = state.createTask(argument, consumer, null, taskFactory);
|
||||
|
||||
assertThat(task, initTask(ShardInfo.class, "shardInfo", equalTo(shardInfo)));
|
||||
assertThat(task, initTask(ShardRecordProcessor.class, "shardRecordProcessor", equalTo(shardRecordProcessor)));
|
||||
|
|
@ -242,7 +243,7 @@ public class ConsumerStatesTest {
|
|||
public void processingStateTestSynchronous() {
|
||||
when(leaseCoordinator.leaseStatsRecorder()).thenReturn(leaseStatsRecorder);
|
||||
ConsumerState state = ShardConsumerState.PROCESSING.consumerState();
|
||||
ConsumerTask task = state.createTask(argument, consumer, null);
|
||||
ConsumerTask task = state.createTask(argument, consumer, null, taskFactory);
|
||||
|
||||
assertThat(task, procTask(ShardInfo.class, "shardInfo", equalTo(shardInfo)));
|
||||
assertThat(task, procTask(ShardRecordProcessor.class, "shardRecordProcessor", equalTo(shardRecordProcessor)));
|
||||
|
|
@ -274,7 +275,7 @@ public class ConsumerStatesTest {
|
|||
public void processingStateTestAsynchronous() {
|
||||
when(leaseCoordinator.leaseStatsRecorder()).thenReturn(leaseStatsRecorder);
|
||||
ConsumerState state = ShardConsumerState.PROCESSING.consumerState();
|
||||
ConsumerTask task = state.createTask(argument, consumer, null);
|
||||
ConsumerTask task = state.createTask(argument, consumer, null, taskFactory);
|
||||
|
||||
assertThat(task, procTask(ShardInfo.class, "shardInfo", equalTo(shardInfo)));
|
||||
assertThat(task, procTask(ShardRecordProcessor.class, "shardRecordProcessor", equalTo(shardRecordProcessor)));
|
||||
|
|
@ -306,7 +307,7 @@ public class ConsumerStatesTest {
|
|||
public void processingStateRecordsFetcher() {
|
||||
when(leaseCoordinator.leaseStatsRecorder()).thenReturn(leaseStatsRecorder);
|
||||
ConsumerState state = ShardConsumerState.PROCESSING.consumerState();
|
||||
ConsumerTask task = state.createTask(argument, consumer, null);
|
||||
ConsumerTask task = state.createTask(argument, consumer, null, taskFactory);
|
||||
|
||||
assertThat(task, procTask(ShardInfo.class, "shardInfo", equalTo(shardInfo)));
|
||||
assertThat(task, procTask(ShardRecordProcessor.class, "shardRecordProcessor", equalTo(shardRecordProcessor)));
|
||||
|
|
@ -339,7 +340,7 @@ public class ConsumerStatesTest {
|
|||
ConsumerState state = ShardConsumerState.SHUTDOWN_REQUESTED.consumerState();
|
||||
|
||||
consumer.gracefulShutdown(shutdownNotification);
|
||||
ConsumerTask task = state.createTask(argument, consumer, null);
|
||||
ConsumerTask task = state.createTask(argument, consumer, null, taskFactory);
|
||||
|
||||
assertThat(
|
||||
task,
|
||||
|
|
@ -373,7 +374,7 @@ public class ConsumerStatesTest {
|
|||
public void shutdownRequestCompleteStateTest() {
|
||||
ConsumerState state = ConsumerStates.SHUTDOWN_REQUEST_COMPLETION_STATE;
|
||||
|
||||
assertThat(state.createTask(argument, consumer, null), nullValue());
|
||||
assertThat(state.createTask(argument, consumer, null, taskFactory), nullValue());
|
||||
|
||||
assertThat(state.successTransition(), equalTo(state));
|
||||
|
||||
|
|
@ -409,7 +410,7 @@ public class ConsumerStatesTest {
|
|||
childShards.add(leftChild);
|
||||
childShards.add(rightChild);
|
||||
when(processRecordsInput.childShards()).thenReturn(childShards);
|
||||
ConsumerTask task = state.createTask(argument, consumer, processRecordsInput);
|
||||
ConsumerTask task = state.createTask(argument, consumer, processRecordsInput, taskFactory);
|
||||
|
||||
assertThat(task, shutdownTask(ShardInfo.class, "shardInfo", equalTo(shardInfo)));
|
||||
assertThat(
|
||||
|
|
@ -443,7 +444,7 @@ public class ConsumerStatesTest {
|
|||
|
||||
ConsumerState state = ShardConsumerState.SHUTDOWN_COMPLETE.consumerState();
|
||||
|
||||
assertThat(state.createTask(argument, consumer, null), nullValue());
|
||||
assertThat(state.createTask(argument, consumer, null, taskFactory), nullValue());
|
||||
|
||||
assertThat(state.successTransition(), equalTo(state));
|
||||
for (ShutdownReason reason : ShutdownReason.values()) {
|
||||
|
|
|
|||
|
|
@ -332,7 +332,7 @@ public class ShardConsumerTest {
|
|||
|
||||
verify(cache.subscription, times(3)).request(anyLong());
|
||||
verify(cache.subscription).cancel();
|
||||
verify(processingState, times(2)).createTask(eq(shardConsumerArgument), eq(consumer), any());
|
||||
verify(processingState, times(2)).createTask(eq(shardConsumerArgument), eq(consumer), any(), any());
|
||||
verify(taskExecutionListener, times(1)).beforeTaskExecution(initialTaskInput);
|
||||
verify(taskExecutionListener, times(2)).beforeTaskExecution(processTaskInput);
|
||||
verify(taskExecutionListener, times(1)).beforeTaskExecution(shutdownTaskInput);
|
||||
|
|
@ -394,7 +394,7 @@ public class ShardConsumerTest {
|
|||
|
||||
verify(cache.subscription, times(1)).request(anyLong());
|
||||
verify(cache.subscription).cancel();
|
||||
verify(processingState, times(1)).createTask(eq(shardConsumerArgument), eq(consumer), any());
|
||||
verify(processingState, times(1)).createTask(eq(shardConsumerArgument), eq(consumer), any(), any());
|
||||
verify(taskExecutionListener, times(1)).beforeTaskExecution(initialTaskInput);
|
||||
verify(taskExecutionListener, times(1)).beforeTaskExecution(processTaskInput);
|
||||
verify(taskExecutionListener, times(1)).beforeTaskExecution(shutdownTaskInput);
|
||||
|
|
@ -437,14 +437,14 @@ public class ShardConsumerTest {
|
|||
cache.publish();
|
||||
awaitAndResetBarrier(taskCallBarrier);
|
||||
|
||||
verify(processingState).createTask(any(), any(), any());
|
||||
verify(processingState).createTask(any(), any(), any(), any());
|
||||
verify(processingTask).call();
|
||||
|
||||
cache.awaitRequest();
|
||||
|
||||
cache.publish();
|
||||
awaitAndResetBarrier(taskCallBarrier);
|
||||
verify(processingState, times(2)).createTask(any(), any(), any());
|
||||
verify(processingState, times(2)).createTask(any(), any(), any(), any());
|
||||
verify(processingTask, times(2)).call();
|
||||
|
||||
cache.awaitRequest();
|
||||
|
|
@ -460,7 +460,7 @@ public class ShardConsumerTest {
|
|||
shutdownComplete = consumer.shutdownComplete().get();
|
||||
} while (!shutdownComplete);
|
||||
|
||||
verify(processingState, times(3)).createTask(any(), any(), any());
|
||||
verify(processingState, times(3)).createTask(any(), any(), any(), any());
|
||||
verify(processingTask, times(3)).call();
|
||||
verify(processingState).shutdownTransition(eq(ShutdownReason.LEASE_LOST));
|
||||
verify(shutdownState).shutdownTransition(eq(ShutdownReason.LEASE_LOST));
|
||||
|
|
@ -487,7 +487,7 @@ public class ShardConsumerTest {
|
|||
public final void testInitializationStateUponFailure() throws Exception {
|
||||
final ShardConsumer consumer = createShardConsumer(recordsPublisher);
|
||||
|
||||
when(initialState.createTask(eq(shardConsumerArgument), eq(consumer), any()))
|
||||
when(initialState.createTask(eq(shardConsumerArgument), eq(consumer), any(), any()))
|
||||
.thenReturn(initializeTask);
|
||||
when(initializeTask.call()).thenReturn(new TaskResult(new Exception("Bad")));
|
||||
when(initializeTask.taskType()).thenReturn(TaskType.INITIALIZE);
|
||||
|
|
@ -505,7 +505,7 @@ public class ShardConsumerTest {
|
|||
awaitAndResetBarrier(taskBarrier);
|
||||
}
|
||||
|
||||
verify(initialState, times(5)).createTask(eq(shardConsumerArgument), eq(consumer), any());
|
||||
verify(initialState, times(5)).createTask(eq(shardConsumerArgument), eq(consumer), any(), any());
|
||||
verify(initialState, never()).successTransition();
|
||||
verify(initialState, never()).shutdownTransition(any());
|
||||
}
|
||||
|
|
@ -665,7 +665,7 @@ public class ShardConsumerTest {
|
|||
public void testErrorThrowableInInitialization() throws Exception {
|
||||
final ShardConsumer consumer = createShardConsumer(recordsPublisher);
|
||||
|
||||
when(initialState.createTask(any(), any(), any())).thenReturn(initializeTask);
|
||||
when(initialState.createTask(any(), any(), any(), any())).thenReturn(initializeTask);
|
||||
when(initialState.taskType()).thenReturn(TaskType.INITIALIZE);
|
||||
when(initializeTask.call()).thenAnswer(i -> {
|
||||
throw new Error("Error");
|
||||
|
|
@ -692,13 +692,13 @@ public class ShardConsumerTest {
|
|||
mockSuccessfulProcessing(taskBarrier);
|
||||
|
||||
when(processingState.shutdownTransition(eq(ShutdownReason.REQUESTED))).thenReturn(shutdownRequestedState);
|
||||
when(shutdownRequestedState.createTask(any(), any(), any())).thenReturn(shutdownRequestedTask);
|
||||
when(shutdownRequestedState.createTask(any(), any(), any(), any())).thenReturn(shutdownRequestedTask);
|
||||
when(shutdownRequestedState.taskType()).thenReturn(TaskType.SHUTDOWN_NOTIFICATION);
|
||||
when(shutdownRequestedTask.call()).thenReturn(new TaskResult(null));
|
||||
|
||||
when(shutdownRequestedState.shutdownTransition(eq(ShutdownReason.REQUESTED)))
|
||||
.thenReturn(shutdownRequestedAwaitState);
|
||||
when(shutdownRequestedAwaitState.createTask(any(), any(), any())).thenReturn(null);
|
||||
when(shutdownRequestedAwaitState.createTask(any(), any(), any(), any())).thenReturn(null);
|
||||
when(shutdownRequestedAwaitState.shutdownTransition(eq(ShutdownReason.LEASE_LOST)))
|
||||
.thenReturn(shutdownState);
|
||||
when(shutdownRequestedAwaitState.taskType()).thenReturn(TaskType.SHUTDOWN_COMPLETE);
|
||||
|
|
@ -733,11 +733,11 @@ public class ShardConsumerTest {
|
|||
shutdownComplete = consumer.shutdownComplete().get();
|
||||
assertTrue(shutdownComplete);
|
||||
|
||||
verify(processingState, times(2)).createTask(any(), any(), any());
|
||||
verify(processingState, times(2)).createTask(any(), any(), any(), any());
|
||||
verify(shutdownRequestedState, never()).shutdownTransition(eq(ShutdownReason.LEASE_LOST));
|
||||
verify(shutdownRequestedState).createTask(any(), any(), any());
|
||||
verify(shutdownRequestedState).createTask(any(), any(), any(), any());
|
||||
verify(shutdownRequestedState).shutdownTransition(eq(ShutdownReason.REQUESTED));
|
||||
verify(shutdownRequestedAwaitState).createTask(any(), any(), any());
|
||||
verify(shutdownRequestedAwaitState).createTask(any(), any(), any(), any());
|
||||
verify(shutdownRequestedAwaitState).shutdownTransition(eq(ShutdownReason.LEASE_LOST));
|
||||
verify(taskExecutionListener, times(1)).beforeTaskExecution(initialTaskInput);
|
||||
verify(taskExecutionListener, times(2)).beforeTaskExecution(processTaskInput);
|
||||
|
|
@ -948,7 +948,7 @@ public class ShardConsumerTest {
|
|||
when(mockState.state()).thenReturn(ShardConsumerState.WAITING_ON_PARENT_SHARDS);
|
||||
when(mockState.taskType()).thenReturn(TaskType.BLOCK_ON_PARENT_SHARDS);
|
||||
final ConsumerTask mockTask = mock(ConsumerTask.class);
|
||||
when(mockState.createTask(any(), any(), any())).thenReturn(mockTask);
|
||||
when(mockState.createTask(any(), any(), any(), any())).thenReturn(mockTask);
|
||||
// Simulate successful BlockedOnParent task execution
|
||||
// and successful Initialize task execution
|
||||
when(mockTask.call()).thenReturn(new TaskResult(false));
|
||||
|
|
@ -993,7 +993,7 @@ public class ShardConsumerTest {
|
|||
reset(mockState);
|
||||
when(mockState.taskType()).thenReturn(TaskType.PROCESS);
|
||||
final ConsumerTask mockProcessTask = mock(ConsumerTask.class);
|
||||
when(mockState.createTask(any(), any(), any())).thenReturn(mockProcessTask);
|
||||
when(mockState.createTask(any(), any(), any(), any())).thenReturn(mockProcessTask);
|
||||
when(mockProcessTask.call()).then(input -> {
|
||||
// first we want to wait for subscribe to be called,
|
||||
// but we cannot control the timing, so wait for 10 seconds
|
||||
|
|
@ -1045,7 +1045,8 @@ public class ShardConsumerTest {
|
|||
}
|
||||
|
||||
private void mockSuccessfulShutdown(CyclicBarrier taskArriveBarrier, CyclicBarrier taskDepartBarrier) {
|
||||
when(shutdownState.createTask(eq(shardConsumerArgument), any(), any())).thenReturn(shutdownTask);
|
||||
when(shutdownState.createTask(eq(shardConsumerArgument), any(), any(), any()))
|
||||
.thenReturn(shutdownTask);
|
||||
when(shutdownState.taskType()).thenReturn(TaskType.SHUTDOWN);
|
||||
when(shutdownTask.call()).thenAnswer(i -> {
|
||||
awaitBarrier(taskArriveBarrier);
|
||||
|
|
@ -1063,7 +1064,7 @@ public class ShardConsumerTest {
|
|||
}
|
||||
|
||||
private void mockSuccessfulProcessing(CyclicBarrier taskCallBarrier, CyclicBarrier taskInterlockBarrier) {
|
||||
when(processingState.createTask(eq(shardConsumerArgument), any(), any()))
|
||||
when(processingState.createTask(eq(shardConsumerArgument), any(), any(), any()))
|
||||
.thenReturn(processingTask);
|
||||
when(processingState.taskType()).thenReturn(TaskType.PROCESS);
|
||||
when(processingTask.taskType()).thenReturn(TaskType.PROCESS);
|
||||
|
|
@ -1088,7 +1089,8 @@ public class ShardConsumerTest {
|
|||
}
|
||||
|
||||
private void mockSuccessfulInitialize(CyclicBarrier taskCallBarrier, CyclicBarrier taskInterlockBarrier) {
|
||||
when(initialState.createTask(eq(shardConsumerArgument), any(), any())).thenReturn(initializeTask);
|
||||
when(initialState.createTask(eq(shardConsumerArgument), any(), any(), any()))
|
||||
.thenReturn(initializeTask);
|
||||
when(initialState.taskType()).thenReturn(TaskType.INITIALIZE);
|
||||
when(initializeTask.taskType()).thenReturn(TaskType.INITIALIZE);
|
||||
when(initializeTask.call()).thenAnswer(i -> {
|
||||
|
|
@ -1107,7 +1109,7 @@ public class ShardConsumerTest {
|
|||
}
|
||||
|
||||
private void mockSuccessfulUnblockOnParents() {
|
||||
when(blockedOnParentsState.createTask(eq(shardConsumerArgument), any(), any()))
|
||||
when(blockedOnParentsState.createTask(eq(shardConsumerArgument), any(), any(), any()))
|
||||
.thenReturn(blockedOnParentsTask);
|
||||
when(blockedOnParentsState.taskType()).thenReturn(TaskType.BLOCK_ON_PARENT_SHARDS);
|
||||
when(blockedOnParentsTask.call()).thenAnswer(i -> blockOnParentsTaskResult);
|
||||
|
|
|
|||
Loading…
Reference in a new issue