From 17b59abada85f90b4e08e596a6efab3d27853609 Mon Sep 17 00:00:00 2001 From: gguptp Date: Thu, 20 Feb 2025 15:44:10 +0530 Subject: [PATCH] Adding functionality to override the vanilla KCL tasks --- .../amazon/kinesis/coordinator/Scheduler.java | 31 ++++++ .../kinesis/lifecycle/ConsumerState.java | 10 +- .../kinesis/lifecycle/ConsumerStates.java | 94 +++++++----------- .../lifecycle/ConsumerTaskFactory.java | 47 +++++++++ .../lifecycle/KinesisConsumerTaskFactory.java | 98 +++++++++++++++++++ .../kinesis/lifecycle/ShardConsumer.java | 32 +++++- .../kinesis/lifecycle/ConsumerStatesTest.java | 19 ++-- .../kinesis/lifecycle/ShardConsumerTest.java | 40 ++++---- 8 files changed, 278 insertions(+), 93 deletions(-) create mode 100644 amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ConsumerTaskFactory.java create mode 100644 amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/KinesisConsumerTaskFactory.java diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/coordinator/Scheduler.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/coordinator/Scheduler.java index 0adb69f9..2382b4e1 100644 --- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/coordinator/Scheduler.java +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/coordinator/Scheduler.java @@ -87,6 +87,8 @@ import software.amazon.kinesis.leases.dynamodb.DynamoDBMultiStreamLeaseSerialize import software.amazon.kinesis.leases.exceptions.DependencyException; import software.amazon.kinesis.leases.exceptions.InvalidStateException; import software.amazon.kinesis.leases.exceptions.ProvisionedThroughputException; +import software.amazon.kinesis.lifecycle.ConsumerTaskFactory; +import software.amazon.kinesis.lifecycle.KinesisConsumerTaskFactory; import software.amazon.kinesis.lifecycle.LifecycleConfig; import software.amazon.kinesis.lifecycle.ShardConsumer; import software.amazon.kinesis.lifecycle.ShardConsumerArgument; @@ -188,6 +190,7 @@ public class Scheduler implements Runnable { private final SchemaRegistryDecoder schemaRegistryDecoder; private final DeletedStreamListProvider deletedStreamListProvider; + private final ConsumerTaskFactory taskFactory; @Getter(AccessLevel.NONE) private final MigrationStateMachine migrationStateMachine; @@ -264,6 +267,33 @@ public class Scheduler implements Runnable { @NonNull final ProcessorConfig processorConfig, @NonNull final RetrievalConfig retrievalConfig, @NonNull final DiagnosticEventFactory diagnosticEventFactory) { + this( + checkpointConfig, + coordinatorConfig, + leaseManagementConfig, + lifecycleConfig, + metricsConfig, + processorConfig, + retrievalConfig, + diagnosticEventFactory, + new KinesisConsumerTaskFactory()); + } + + /** + * Customers do not currently have the ability to customize the DiagnosticEventFactory, but this visibility + * is desired for testing. This constructor is only used for testing to provide a mock DiagnosticEventFactory. + */ + @VisibleForTesting + protected Scheduler( + @NonNull final CheckpointConfig checkpointConfig, + @NonNull final CoordinatorConfig coordinatorConfig, + @NonNull final LeaseManagementConfig leaseManagementConfig, + @NonNull final LifecycleConfig lifecycleConfig, + @NonNull final MetricsConfig metricsConfig, + @NonNull final ProcessorConfig processorConfig, + @NonNull final RetrievalConfig retrievalConfig, + @NonNull final DiagnosticEventFactory diagnosticEventFactory, + @NonNull final ConsumerTaskFactory taskFactory) { this.checkpointConfig = checkpointConfig; this.coordinatorConfig = coordinatorConfig; this.leaseManagementConfig = leaseManagementConfig; @@ -371,6 +401,7 @@ public class Scheduler implements Runnable { this.schemaRegistryDecoder = this.retrievalConfig.glueSchemaRegistryDeserializer() == null ? null : new SchemaRegistryDecoder(this.retrievalConfig.glueSchemaRegistryDeserializer()); + this.taskFactory = taskFactory; } /** diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ConsumerState.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ConsumerState.java index 3aa03b11..fb3f1407 100644 --- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ConsumerState.java +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ConsumerState.java @@ -34,13 +34,19 @@ interface ConsumerState { * the consumer to use build the task, or execute state. * @param input * the process input received, this may be null if it's a control message + * @param taskFactory + * a factory for creating tasks * @return a valid task for this state or null if there is no task required. */ - ConsumerTask createTask(ShardConsumerArgument consumerArgument, ShardConsumer consumer, ProcessRecordsInput input); + ConsumerTask createTask( + ShardConsumerArgument consumerArgument, + ShardConsumer consumer, + ProcessRecordsInput input, + ConsumerTaskFactory taskFactory); /** * Provides the next state of the consumer upon success of the task return by - * {@link ConsumerState#createTask(ShardConsumerArgument, ShardConsumer, ProcessRecordsInput)}. + * {@link ConsumerState#createTask(ShardConsumerArgument, ShardConsumer, ProcessRecordsInput, ConsumerTaskFactory)}. * * @return the next state that the consumer should transition to, this may be the same object as the current * state. diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ConsumerStates.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ConsumerStates.java index eb1a8f48..441705d2 100644 --- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ConsumerStates.java +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ConsumerStates.java @@ -17,7 +17,6 @@ package software.amazon.kinesis.lifecycle; import lombok.Getter; import lombok.experimental.Accessors; import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput; -import software.amazon.kinesis.retrieval.ThrottlingReporter; /** * Top level container for all the possible states a {@link ShardConsumer} can be in. The logic for creation of tasks, @@ -121,11 +120,11 @@ class ConsumerStates { @Override public ConsumerTask createTask( - ShardConsumerArgument consumerArgument, ShardConsumer consumer, ProcessRecordsInput input) { - return new BlockOnParentShardTask( - consumerArgument.shardInfo(), - consumerArgument.leaseCoordinator().leaseRefresher(), - consumerArgument.parentShardPollIntervalMillis()); + ShardConsumerArgument consumerArgument, + ShardConsumer consumer, + ProcessRecordsInput input, + ConsumerTaskFactory taskFactory) { + return taskFactory.createBlockOnParentTask(consumerArgument); } @Override @@ -187,16 +186,11 @@ class ConsumerStates { @Override public ConsumerTask createTask( - ShardConsumerArgument argument, ShardConsumer consumer, ProcessRecordsInput input) { - return new InitializeTask( - argument.shardInfo(), - argument.shardRecordProcessor(), - argument.checkpoint(), - argument.recordProcessorCheckpointer(), - argument.initialPositionInStream(), - argument.recordsPublisher(), - argument.taskBackoffTimeMillis(), - argument.metricsFactory()); + ShardConsumerArgument argument, + ShardConsumer consumer, + ProcessRecordsInput input, + ConsumerTaskFactory taskFactory) { + return taskFactory.createInitializeTask(argument); } @Override @@ -250,24 +244,11 @@ class ConsumerStates { @Override public ConsumerTask createTask( - ShardConsumerArgument argument, ShardConsumer consumer, ProcessRecordsInput input) { - ThrottlingReporter throttlingReporter = - new ThrottlingReporter(5, argument.shardInfo().shardId()); - return new ProcessTask( - argument.shardInfo(), - argument.shardRecordProcessor(), - argument.recordProcessorCheckpointer(), - argument.taskBackoffTimeMillis(), - argument.skipShardSyncAtWorkerInitializationIfLeasesExist(), - argument.shardDetector(), - throttlingReporter, - input, - argument.shouldCallProcessRecordsEvenForEmptyRecordList(), - argument.idleTimeInMilliseconds(), - argument.aggregatorUtil(), - argument.metricsFactory(), - argument.schemaRegistryDecoder(), - argument.leaseCoordinator().leaseStatsRecorder()); + ShardConsumerArgument argument, + ShardConsumer consumer, + ProcessRecordsInput input, + ConsumerTaskFactory taskFactory) { + return taskFactory.createProcessTask(argument, input); } @Override @@ -331,14 +312,12 @@ class ConsumerStates { @Override public ConsumerTask createTask( - ShardConsumerArgument argument, ShardConsumer consumer, ProcessRecordsInput input) { + ShardConsumerArgument argument, + ShardConsumer consumer, + ProcessRecordsInput input, + ConsumerTaskFactory taskFactory) { // TODO: notify shutdownrequested - return new ShutdownNotificationTask( - argument.shardRecordProcessor(), - argument.recordProcessorCheckpointer(), - consumer.shutdownNotification(), - argument.shardInfo(), - consumer.shardConsumerArgument().leaseCoordinator()); + return taskFactory.createShutdownNotificationTask(argument, consumer); } @Override @@ -405,7 +384,10 @@ class ConsumerStates { @Override public ConsumerTask createTask( - ShardConsumerArgument argument, ShardConsumer consumer, ProcessRecordsInput input) { + ShardConsumerArgument argument, + ShardConsumer consumer, + ProcessRecordsInput input, + ConsumerTaskFactory taskFactory) { return null; } @@ -483,25 +465,12 @@ class ConsumerStates { @Override public ConsumerTask createTask( - ShardConsumerArgument argument, ShardConsumer consumer, ProcessRecordsInput input) { + ShardConsumerArgument argument, + ShardConsumer consumer, + ProcessRecordsInput input, + ConsumerTaskFactory taskFactory) { // TODO: set shutdown reason - return new ShutdownTask( - argument.shardInfo(), - argument.shardDetector(), - argument.shardRecordProcessor(), - argument.recordProcessorCheckpointer(), - consumer.shutdownReason(), - argument.initialPositionInStream(), - argument.cleanupLeasesOfCompletedShards(), - argument.ignoreUnexpectedChildShards(), - argument.leaseCoordinator(), - argument.taskBackoffTimeMillis(), - argument.recordsPublisher(), - argument.hierarchicalShardSyncer(), - argument.metricsFactory(), - input == null ? null : input.childShards(), - argument.streamIdentifier(), - argument.leaseCleanupManager()); + return taskFactory.createShutdownTask(argument, consumer, input); } @Override @@ -569,7 +538,10 @@ class ConsumerStates { @Override public ConsumerTask createTask( - ShardConsumerArgument argument, ShardConsumer consumer, ProcessRecordsInput input) { + ShardConsumerArgument argument, + ShardConsumer consumer, + ProcessRecordsInput input, + ConsumerTaskFactory taskFactory) { return null; } diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ConsumerTaskFactory.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ConsumerTaskFactory.java new file mode 100644 index 00000000..5f65241e --- /dev/null +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ConsumerTaskFactory.java @@ -0,0 +1,47 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. + * Licensed under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.kinesis.lifecycle; + +import software.amazon.kinesis.annotations.KinesisClientInternalApi; +import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput; + +@KinesisClientInternalApi +public interface ConsumerTaskFactory { + /** + * Creates a shutdown task. + */ + ConsumerTask createShutdownTask(ShardConsumerArgument argument, ShardConsumer consumer, ProcessRecordsInput input); + + /** + * Creates a process task. + */ + ConsumerTask createProcessTask(ShardConsumerArgument argument, ProcessRecordsInput processRecordsInput); + + /** + * Creates an initialize task. + */ + ConsumerTask createInitializeTask(ShardConsumerArgument argument); + + /** + * Creates a block on parent task. + */ + ConsumerTask createBlockOnParentTask(ShardConsumerArgument argument); + + /** + * Creates a shutdown notification task. + */ + ConsumerTask createShutdownNotificationTask(ShardConsumerArgument argument, ShardConsumer consumer); +} diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/KinesisConsumerTaskFactory.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/KinesisConsumerTaskFactory.java new file mode 100644 index 00000000..0c871e12 --- /dev/null +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/KinesisConsumerTaskFactory.java @@ -0,0 +1,98 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. + * Licensed under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.kinesis.lifecycle; + +import software.amazon.kinesis.annotations.KinesisClientInternalApi; +import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput; +import software.amazon.kinesis.retrieval.ThrottlingReporter; + +@KinesisClientInternalApi +public class KinesisConsumerTaskFactory implements ConsumerTaskFactory { + + @Override + public ConsumerTask createShutdownTask( + ShardConsumerArgument argument, ShardConsumer consumer, ProcessRecordsInput input) { + return new ShutdownTask( + argument.shardInfo(), + argument.shardDetector(), + argument.shardRecordProcessor(), + argument.recordProcessorCheckpointer(), + consumer.shutdownReason(), + argument.initialPositionInStream(), + argument.cleanupLeasesOfCompletedShards(), + argument.ignoreUnexpectedChildShards(), + argument.leaseCoordinator(), + argument.taskBackoffTimeMillis(), + argument.recordsPublisher(), + argument.hierarchicalShardSyncer(), + argument.metricsFactory(), + input == null ? null : input.childShards(), + argument.streamIdentifier(), + argument.leaseCleanupManager()); + } + + @Override + public ConsumerTask createProcessTask(ShardConsumerArgument argument, ProcessRecordsInput processRecordsInput) { + ThrottlingReporter throttlingReporter = + new ThrottlingReporter(5, argument.shardInfo().shardId()); + return new ProcessTask( + argument.shardInfo(), + argument.shardRecordProcessor(), + argument.recordProcessorCheckpointer(), + argument.taskBackoffTimeMillis(), + argument.skipShardSyncAtWorkerInitializationIfLeasesExist(), + argument.shardDetector(), + throttlingReporter, + processRecordsInput, + argument.shouldCallProcessRecordsEvenForEmptyRecordList(), + argument.idleTimeInMilliseconds(), + argument.aggregatorUtil(), + argument.metricsFactory(), + argument.schemaRegistryDecoder(), + argument.leaseCoordinator().leaseStatsRecorder()); + } + + @Override + public ConsumerTask createInitializeTask(ShardConsumerArgument argument) { + return new InitializeTask( + argument.shardInfo(), + argument.shardRecordProcessor(), + argument.checkpoint(), + argument.recordProcessorCheckpointer(), + argument.initialPositionInStream(), + argument.recordsPublisher(), + argument.taskBackoffTimeMillis(), + argument.metricsFactory()); + } + + @Override + public ConsumerTask createBlockOnParentTask(ShardConsumerArgument argument) { + return new BlockOnParentShardTask( + argument.shardInfo(), + argument.leaseCoordinator().leaseRefresher(), + argument.parentShardPollIntervalMillis()); + } + + @Override + public ConsumerTask createShutdownNotificationTask(ShardConsumerArgument argument, ShardConsumer consumer) { + return new ShutdownNotificationTask( + argument.shardRecordProcessor(), + argument.recordProcessorCheckpointer(), + consumer.shutdownNotification(), + argument.shardInfo(), + consumer.shardConsumerArgument().leaseCoordinator()); + } +} diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ShardConsumer.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ShardConsumer.java index 2e519ee1..a23732f7 100644 --- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ShardConsumer.java +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ShardConsumer.java @@ -86,6 +86,8 @@ public class ShardConsumer { private ProcessRecordsInput shardEndProcessRecordsInput; + private final ConsumerTaskFactory taskFactory; + public ShardConsumer( RecordsPublisher recordsPublisher, ExecutorService executorService, @@ -103,7 +105,8 @@ public class ShardConsumer { ConsumerStates.INITIAL_STATE, 8, taskExecutionListener, - readTimeoutsToIgnoreBeforeWarning); + readTimeoutsToIgnoreBeforeWarning, + new KinesisConsumerTaskFactory()); } // @@ -119,6 +122,30 @@ public class ShardConsumer { int bufferSize, TaskExecutionListener taskExecutionListener, int readTimeoutsToIgnoreBeforeWarning) { + this( + recordsPublisher, + executorService, + shardInfo, + logWarningForTaskAfterMillis, + shardConsumerArgument, + initialState, + bufferSize, + taskExecutionListener, + readTimeoutsToIgnoreBeforeWarning, + new KinesisConsumerTaskFactory()); + } + + public ShardConsumer( + RecordsPublisher recordsPublisher, + ExecutorService executorService, + ShardInfo shardInfo, + Optional logWarningForTaskAfterMillis, + ShardConsumerArgument shardConsumerArgument, + ConsumerState initialState, + int bufferSize, + TaskExecutionListener taskExecutionListener, + int readTimeoutsToIgnoreBeforeWarning, + ConsumerTaskFactory taskFactory) { this.recordsPublisher = recordsPublisher; this.executorService = executorService; this.shardInfo = shardInfo; @@ -134,6 +161,7 @@ public class ShardConsumer { if (this.shardInfo.isCompleted()) { markForShutdown(ShutdownReason.SHARD_END); } + this.taskFactory = taskFactory; } synchronized void handleInput(ProcessRecordsInput input, Subscription subscription) { @@ -345,7 +373,7 @@ public class ShardConsumer { .taskType(currentState.taskType()) .build(); taskExecutionListener.beforeTaskExecution(taskExecutionListenerInput); - ConsumerTask task = currentState.createTask(shardConsumerArgument, ShardConsumer.this, input); + ConsumerTask task = currentState.createTask(shardConsumerArgument, ShardConsumer.this, input, taskFactory); if (task != null) { taskDispatchedAt = Instant.now(); currentTask = task; diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ConsumerStatesTest.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ConsumerStatesTest.java index 9491a97f..cc41c479 100644 --- a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ConsumerStatesTest.java +++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ConsumerStatesTest.java @@ -127,6 +127,7 @@ public class ConsumerStatesTest { private long idleTimeInMillis = 1000L; private Optional logWarningForTaskAfterMillis = Optional.empty(); private SchemaRegistryDecoder schemaRegistryDecoder = null; + private final ConsumerTaskFactory taskFactory = new KinesisConsumerTaskFactory(); @Before public void setup() { @@ -177,7 +178,7 @@ public class ConsumerStatesTest { ConsumerState state = ShardConsumerState.WAITING_ON_PARENT_SHARDS.consumerState(); when(leaseCoordinator.leaseRefresher()).thenReturn(leaseRefresher); - ConsumerTask task = state.createTask(argument, consumer, null); + ConsumerTask task = state.createTask(argument, consumer, null, taskFactory); assertThat(task, taskWith(BlockOnParentShardTask.class, ShardInfo.class, "shardInfo", equalTo(shardInfo))); assertThat( @@ -209,7 +210,7 @@ public class ConsumerStatesTest { @Test public void initializingStateTest() { ConsumerState state = ShardConsumerState.INITIALIZING.consumerState(); - ConsumerTask task = state.createTask(argument, consumer, null); + ConsumerTask task = state.createTask(argument, consumer, null, taskFactory); assertThat(task, initTask(ShardInfo.class, "shardInfo", equalTo(shardInfo))); assertThat(task, initTask(ShardRecordProcessor.class, "shardRecordProcessor", equalTo(shardRecordProcessor))); @@ -242,7 +243,7 @@ public class ConsumerStatesTest { public void processingStateTestSynchronous() { when(leaseCoordinator.leaseStatsRecorder()).thenReturn(leaseStatsRecorder); ConsumerState state = ShardConsumerState.PROCESSING.consumerState(); - ConsumerTask task = state.createTask(argument, consumer, null); + ConsumerTask task = state.createTask(argument, consumer, null, taskFactory); assertThat(task, procTask(ShardInfo.class, "shardInfo", equalTo(shardInfo))); assertThat(task, procTask(ShardRecordProcessor.class, "shardRecordProcessor", equalTo(shardRecordProcessor))); @@ -274,7 +275,7 @@ public class ConsumerStatesTest { public void processingStateTestAsynchronous() { when(leaseCoordinator.leaseStatsRecorder()).thenReturn(leaseStatsRecorder); ConsumerState state = ShardConsumerState.PROCESSING.consumerState(); - ConsumerTask task = state.createTask(argument, consumer, null); + ConsumerTask task = state.createTask(argument, consumer, null, taskFactory); assertThat(task, procTask(ShardInfo.class, "shardInfo", equalTo(shardInfo))); assertThat(task, procTask(ShardRecordProcessor.class, "shardRecordProcessor", equalTo(shardRecordProcessor))); @@ -306,7 +307,7 @@ public class ConsumerStatesTest { public void processingStateRecordsFetcher() { when(leaseCoordinator.leaseStatsRecorder()).thenReturn(leaseStatsRecorder); ConsumerState state = ShardConsumerState.PROCESSING.consumerState(); - ConsumerTask task = state.createTask(argument, consumer, null); + ConsumerTask task = state.createTask(argument, consumer, null, taskFactory); assertThat(task, procTask(ShardInfo.class, "shardInfo", equalTo(shardInfo))); assertThat(task, procTask(ShardRecordProcessor.class, "shardRecordProcessor", equalTo(shardRecordProcessor))); @@ -339,7 +340,7 @@ public class ConsumerStatesTest { ConsumerState state = ShardConsumerState.SHUTDOWN_REQUESTED.consumerState(); consumer.gracefulShutdown(shutdownNotification); - ConsumerTask task = state.createTask(argument, consumer, null); + ConsumerTask task = state.createTask(argument, consumer, null, taskFactory); assertThat( task, @@ -373,7 +374,7 @@ public class ConsumerStatesTest { public void shutdownRequestCompleteStateTest() { ConsumerState state = ConsumerStates.SHUTDOWN_REQUEST_COMPLETION_STATE; - assertThat(state.createTask(argument, consumer, null), nullValue()); + assertThat(state.createTask(argument, consumer, null, taskFactory), nullValue()); assertThat(state.successTransition(), equalTo(state)); @@ -409,7 +410,7 @@ public class ConsumerStatesTest { childShards.add(leftChild); childShards.add(rightChild); when(processRecordsInput.childShards()).thenReturn(childShards); - ConsumerTask task = state.createTask(argument, consumer, processRecordsInput); + ConsumerTask task = state.createTask(argument, consumer, processRecordsInput, taskFactory); assertThat(task, shutdownTask(ShardInfo.class, "shardInfo", equalTo(shardInfo))); assertThat( @@ -443,7 +444,7 @@ public class ConsumerStatesTest { ConsumerState state = ShardConsumerState.SHUTDOWN_COMPLETE.consumerState(); - assertThat(state.createTask(argument, consumer, null), nullValue()); + assertThat(state.createTask(argument, consumer, null, taskFactory), nullValue()); assertThat(state.successTransition(), equalTo(state)); for (ShutdownReason reason : ShutdownReason.values()) { diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ShardConsumerTest.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ShardConsumerTest.java index 83b27ba7..6390831f 100644 --- a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ShardConsumerTest.java +++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ShardConsumerTest.java @@ -332,7 +332,7 @@ public class ShardConsumerTest { verify(cache.subscription, times(3)).request(anyLong()); verify(cache.subscription).cancel(); - verify(processingState, times(2)).createTask(eq(shardConsumerArgument), eq(consumer), any()); + verify(processingState, times(2)).createTask(eq(shardConsumerArgument), eq(consumer), any(), any()); verify(taskExecutionListener, times(1)).beforeTaskExecution(initialTaskInput); verify(taskExecutionListener, times(2)).beforeTaskExecution(processTaskInput); verify(taskExecutionListener, times(1)).beforeTaskExecution(shutdownTaskInput); @@ -394,7 +394,7 @@ public class ShardConsumerTest { verify(cache.subscription, times(1)).request(anyLong()); verify(cache.subscription).cancel(); - verify(processingState, times(1)).createTask(eq(shardConsumerArgument), eq(consumer), any()); + verify(processingState, times(1)).createTask(eq(shardConsumerArgument), eq(consumer), any(), any()); verify(taskExecutionListener, times(1)).beforeTaskExecution(initialTaskInput); verify(taskExecutionListener, times(1)).beforeTaskExecution(processTaskInput); verify(taskExecutionListener, times(1)).beforeTaskExecution(shutdownTaskInput); @@ -437,14 +437,14 @@ public class ShardConsumerTest { cache.publish(); awaitAndResetBarrier(taskCallBarrier); - verify(processingState).createTask(any(), any(), any()); + verify(processingState).createTask(any(), any(), any(), any()); verify(processingTask).call(); cache.awaitRequest(); cache.publish(); awaitAndResetBarrier(taskCallBarrier); - verify(processingState, times(2)).createTask(any(), any(), any()); + verify(processingState, times(2)).createTask(any(), any(), any(), any()); verify(processingTask, times(2)).call(); cache.awaitRequest(); @@ -460,7 +460,7 @@ public class ShardConsumerTest { shutdownComplete = consumer.shutdownComplete().get(); } while (!shutdownComplete); - verify(processingState, times(3)).createTask(any(), any(), any()); + verify(processingState, times(3)).createTask(any(), any(), any(), any()); verify(processingTask, times(3)).call(); verify(processingState).shutdownTransition(eq(ShutdownReason.LEASE_LOST)); verify(shutdownState).shutdownTransition(eq(ShutdownReason.LEASE_LOST)); @@ -487,7 +487,7 @@ public class ShardConsumerTest { public final void testInitializationStateUponFailure() throws Exception { final ShardConsumer consumer = createShardConsumer(recordsPublisher); - when(initialState.createTask(eq(shardConsumerArgument), eq(consumer), any())) + when(initialState.createTask(eq(shardConsumerArgument), eq(consumer), any(), any())) .thenReturn(initializeTask); when(initializeTask.call()).thenReturn(new TaskResult(new Exception("Bad"))); when(initializeTask.taskType()).thenReturn(TaskType.INITIALIZE); @@ -505,7 +505,7 @@ public class ShardConsumerTest { awaitAndResetBarrier(taskBarrier); } - verify(initialState, times(5)).createTask(eq(shardConsumerArgument), eq(consumer), any()); + verify(initialState, times(5)).createTask(eq(shardConsumerArgument), eq(consumer), any(), any()); verify(initialState, never()).successTransition(); verify(initialState, never()).shutdownTransition(any()); } @@ -665,7 +665,7 @@ public class ShardConsumerTest { public void testErrorThrowableInInitialization() throws Exception { final ShardConsumer consumer = createShardConsumer(recordsPublisher); - when(initialState.createTask(any(), any(), any())).thenReturn(initializeTask); + when(initialState.createTask(any(), any(), any(), any())).thenReturn(initializeTask); when(initialState.taskType()).thenReturn(TaskType.INITIALIZE); when(initializeTask.call()).thenAnswer(i -> { throw new Error("Error"); @@ -692,13 +692,13 @@ public class ShardConsumerTest { mockSuccessfulProcessing(taskBarrier); when(processingState.shutdownTransition(eq(ShutdownReason.REQUESTED))).thenReturn(shutdownRequestedState); - when(shutdownRequestedState.createTask(any(), any(), any())).thenReturn(shutdownRequestedTask); + when(shutdownRequestedState.createTask(any(), any(), any(), any())).thenReturn(shutdownRequestedTask); when(shutdownRequestedState.taskType()).thenReturn(TaskType.SHUTDOWN_NOTIFICATION); when(shutdownRequestedTask.call()).thenReturn(new TaskResult(null)); when(shutdownRequestedState.shutdownTransition(eq(ShutdownReason.REQUESTED))) .thenReturn(shutdownRequestedAwaitState); - when(shutdownRequestedAwaitState.createTask(any(), any(), any())).thenReturn(null); + when(shutdownRequestedAwaitState.createTask(any(), any(), any(), any())).thenReturn(null); when(shutdownRequestedAwaitState.shutdownTransition(eq(ShutdownReason.LEASE_LOST))) .thenReturn(shutdownState); when(shutdownRequestedAwaitState.taskType()).thenReturn(TaskType.SHUTDOWN_COMPLETE); @@ -733,11 +733,11 @@ public class ShardConsumerTest { shutdownComplete = consumer.shutdownComplete().get(); assertTrue(shutdownComplete); - verify(processingState, times(2)).createTask(any(), any(), any()); + verify(processingState, times(2)).createTask(any(), any(), any(), any()); verify(shutdownRequestedState, never()).shutdownTransition(eq(ShutdownReason.LEASE_LOST)); - verify(shutdownRequestedState).createTask(any(), any(), any()); + verify(shutdownRequestedState).createTask(any(), any(), any(), any()); verify(shutdownRequestedState).shutdownTransition(eq(ShutdownReason.REQUESTED)); - verify(shutdownRequestedAwaitState).createTask(any(), any(), any()); + verify(shutdownRequestedAwaitState).createTask(any(), any(), any(), any()); verify(shutdownRequestedAwaitState).shutdownTransition(eq(ShutdownReason.LEASE_LOST)); verify(taskExecutionListener, times(1)).beforeTaskExecution(initialTaskInput); verify(taskExecutionListener, times(2)).beforeTaskExecution(processTaskInput); @@ -948,7 +948,7 @@ public class ShardConsumerTest { when(mockState.state()).thenReturn(ShardConsumerState.WAITING_ON_PARENT_SHARDS); when(mockState.taskType()).thenReturn(TaskType.BLOCK_ON_PARENT_SHARDS); final ConsumerTask mockTask = mock(ConsumerTask.class); - when(mockState.createTask(any(), any(), any())).thenReturn(mockTask); + when(mockState.createTask(any(), any(), any(), any())).thenReturn(mockTask); // Simulate successful BlockedOnParent task execution // and successful Initialize task execution when(mockTask.call()).thenReturn(new TaskResult(false)); @@ -993,7 +993,7 @@ public class ShardConsumerTest { reset(mockState); when(mockState.taskType()).thenReturn(TaskType.PROCESS); final ConsumerTask mockProcessTask = mock(ConsumerTask.class); - when(mockState.createTask(any(), any(), any())).thenReturn(mockProcessTask); + when(mockState.createTask(any(), any(), any(), any())).thenReturn(mockProcessTask); when(mockProcessTask.call()).then(input -> { // first we want to wait for subscribe to be called, // but we cannot control the timing, so wait for 10 seconds @@ -1045,7 +1045,8 @@ public class ShardConsumerTest { } private void mockSuccessfulShutdown(CyclicBarrier taskArriveBarrier, CyclicBarrier taskDepartBarrier) { - when(shutdownState.createTask(eq(shardConsumerArgument), any(), any())).thenReturn(shutdownTask); + when(shutdownState.createTask(eq(shardConsumerArgument), any(), any(), any())) + .thenReturn(shutdownTask); when(shutdownState.taskType()).thenReturn(TaskType.SHUTDOWN); when(shutdownTask.call()).thenAnswer(i -> { awaitBarrier(taskArriveBarrier); @@ -1063,7 +1064,7 @@ public class ShardConsumerTest { } private void mockSuccessfulProcessing(CyclicBarrier taskCallBarrier, CyclicBarrier taskInterlockBarrier) { - when(processingState.createTask(eq(shardConsumerArgument), any(), any())) + when(processingState.createTask(eq(shardConsumerArgument), any(), any(), any())) .thenReturn(processingTask); when(processingState.taskType()).thenReturn(TaskType.PROCESS); when(processingTask.taskType()).thenReturn(TaskType.PROCESS); @@ -1088,7 +1089,8 @@ public class ShardConsumerTest { } private void mockSuccessfulInitialize(CyclicBarrier taskCallBarrier, CyclicBarrier taskInterlockBarrier) { - when(initialState.createTask(eq(shardConsumerArgument), any(), any())).thenReturn(initializeTask); + when(initialState.createTask(eq(shardConsumerArgument), any(), any(), any())) + .thenReturn(initializeTask); when(initialState.taskType()).thenReturn(TaskType.INITIALIZE); when(initializeTask.taskType()).thenReturn(TaskType.INITIALIZE); when(initializeTask.call()).thenAnswer(i -> { @@ -1107,7 +1109,7 @@ public class ShardConsumerTest { } private void mockSuccessfulUnblockOnParents() { - when(blockedOnParentsState.createTask(eq(shardConsumerArgument), any(), any())) + when(blockedOnParentsState.createTask(eq(shardConsumerArgument), any(), any(), any())) .thenReturn(blockedOnParentsTask); when(blockedOnParentsState.taskType()).thenReturn(TaskType.BLOCK_ON_PARENT_SHARDS); when(blockedOnParentsTask.call()).thenAnswer(i -> blockOnParentsTaskResult);