Merge branch 'awslabs:master' into master

This commit is contained in:
vincentvilo-aws 2025-03-11 11:08:31 -07:00 committed by GitHub
commit 31cba5f140
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 268 additions and 99 deletions

View file

@ -55,6 +55,14 @@ To make it easier for developers to write record processors in other languages,
## Using the KCL
The recommended way to use the KCL for Java is to consume it from Maven.
## 🚨Important: Do not use AWS SDK for Java versions 2.27.19 to 2.27.23 with KCL 3.x
When using KCL 3.x with AWS SDK for Java versions 2.27.19 through 2.27.23, you may encounter the following DynamoDB exception:
```
software.amazon.awssdk.services.dynamodb.model.DynamoDbException: The document path provided in the update expression is invalid for update (Service: DynamoDb, Status Code: 400, Request ID: xxx)
```
This error occurs due to [a known issue](https://github.com/aws/aws-sdk-java-v2/issues/5584) in the AWS SDK for Java that affects the DynamoDB metadata table managed by KCL 3.x. The issue was introduced in version 2.27.19 and impacts all versions up to 2.27.23. The issue has been resolved in the AWS SDK for Java version 2.27.24. For optimal performance and stability, we recommend upgrading to version 2.28.0 or later.
### Version 3.x
``` xml
<dependency>
@ -70,7 +78,7 @@ The recommended way to use the KCL for Java is to consume it from Maven.
<dependency>
<groupId>software.amazon.kinesis</groupId>
<artifactId>amazon-kinesis-client</artifactId>
<version>2.6.0</version>
<version>2.6.1</version>
</dependency>
```
@ -127,4 +135,4 @@ By participating through these channels, you play a vital role in shaping the fu
[migration-guide]: https://docs.aws.amazon.com/streams/latest/dev/kcl-migration-from-previous-versions
[kcl-sample]: https://docs.aws.amazon.com/streams/latest/dev/kcl-example-code
[kcl-aws-doc]: https://docs.aws.amazon.com/streams/latest/dev/kcl.html
[giving-feedback]: https://github.com/awslabs/amazon-kinesis-client?tab=readme-ov-file#giving-feedback
[giving-feedback]: https://github.com/awslabs/amazon-kinesis-client?tab=readme-ov-file#giving-feedback

View file

@ -171,7 +171,7 @@
<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-handler</artifactId>
<version>4.1.108.Final</version>
<version>4.1.118.Final</version>
</dependency>
<dependency>
<groupId>com.google.code.findbugs</groupId>
@ -181,7 +181,7 @@
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<version>2.10.1</version>
<version>2.12.7.1</version>
</dependency>
<dependency>
<groupId>org.reactivestreams</groupId>

View file

@ -87,6 +87,7 @@ import software.amazon.kinesis.leases.dynamodb.DynamoDBMultiStreamLeaseSerialize
import software.amazon.kinesis.leases.exceptions.DependencyException;
import software.amazon.kinesis.leases.exceptions.InvalidStateException;
import software.amazon.kinesis.leases.exceptions.ProvisionedThroughputException;
import software.amazon.kinesis.lifecycle.ConsumerTaskFactory;
import software.amazon.kinesis.lifecycle.LifecycleConfig;
import software.amazon.kinesis.lifecycle.ShardConsumer;
import software.amazon.kinesis.lifecycle.ShardConsumerArgument;
@ -188,6 +189,7 @@ public class Scheduler implements Runnable {
private final SchemaRegistryDecoder schemaRegistryDecoder;
private final DeletedStreamListProvider deletedStreamListProvider;
private final ConsumerTaskFactory taskFactory;
@Getter(AccessLevel.NONE)
private final MigrationStateMachine migrationStateMachine;
@ -371,6 +373,7 @@ public class Scheduler implements Runnable {
this.schemaRegistryDecoder = this.retrievalConfig.glueSchemaRegistryDeserializer() == null
? null
: new SchemaRegistryDecoder(this.retrievalConfig.glueSchemaRegistryDeserializer());
this.taskFactory = leaseManagementConfig().consumerTaskFactory();
}
/**

View file

@ -537,8 +537,8 @@ public final class LeaseAssignmentManager {
.filter(workerMetrics -> !workerMetrics.isValidWorkerMetric())
.map(WorkerMetricStats::getWorkerId)
.collect(Collectors.toList());
log.warn("List of workerIds with invalid entries : {}", listOfWorkerIdOfInvalidWorkerMetricsEntry);
if (!listOfWorkerIdOfInvalidWorkerMetricsEntry.isEmpty()) {
log.warn("List of workerIds with invalid entries : {}", listOfWorkerIdOfInvalidWorkerMetricsEntry);
metricsScope.addData(
"NumWorkersWithInvalidEntry",
listOfWorkerIdOfInvalidWorkerMetricsEntry.size(),
@ -567,8 +567,8 @@ public final class LeaseAssignmentManager {
final Map.Entry<List<Lease>, List<String>> leaseListResponse = leaseListFuture.join();
this.leaseList = leaseListResponse.getKey();
log.warn("Leases that failed deserialization : {}", leaseListResponse.getValue());
if (!leaseListResponse.getValue().isEmpty()) {
log.warn("Leases that failed deserialization : {}", leaseListResponse.getValue());
MetricsUtil.addCount(
metricsScope,
"LeaseDeserializationFailureCount",

View file

@ -45,6 +45,8 @@ import software.amazon.kinesis.common.StreamConfig;
import software.amazon.kinesis.leases.dynamodb.DynamoDBLeaseManagementFactory;
import software.amazon.kinesis.leases.dynamodb.DynamoDBLeaseSerializer;
import software.amazon.kinesis.leases.dynamodb.TableCreatorCallback;
import software.amazon.kinesis.lifecycle.ConsumerTaskFactory;
import software.amazon.kinesis.lifecycle.KinesisConsumerTaskFactory;
import software.amazon.kinesis.metrics.MetricsFactory;
import software.amazon.kinesis.metrics.NullMetricsFactory;
import software.amazon.kinesis.worker.metric.WorkerMetric;
@ -215,6 +217,8 @@ public class LeaseManagementConfig {
private BillingMode billingMode = BillingMode.PAY_PER_REQUEST;
private ConsumerTaskFactory consumerTaskFactory = new KinesisConsumerTaskFactory();
private WorkerUtilizationAwareAssignmentConfig workerUtilizationAwareAssignmentConfig =
new WorkerUtilizationAwareAssignmentConfig();

View file

@ -34,13 +34,19 @@ interface ConsumerState {
* the consumer to use build the task, or execute state.
* @param input
* the process input received, this may be null if it's a control message
* @param taskFactory
* a factory for creating tasks
* @return a valid task for this state or null if there is no task required.
*/
ConsumerTask createTask(ShardConsumerArgument consumerArgument, ShardConsumer consumer, ProcessRecordsInput input);
ConsumerTask createTask(
ShardConsumerArgument consumerArgument,
ShardConsumer consumer,
ProcessRecordsInput input,
ConsumerTaskFactory taskFactory);
/**
* Provides the next state of the consumer upon success of the task return by
* {@link ConsumerState#createTask(ShardConsumerArgument, ShardConsumer, ProcessRecordsInput)}.
* {@link ConsumerState#createTask(ShardConsumerArgument, ShardConsumer, ProcessRecordsInput, ConsumerTaskFactory)}.
*
* @return the next state that the consumer should transition to, this may be the same object as the current
* state.

View file

@ -17,7 +17,6 @@ package software.amazon.kinesis.lifecycle;
import lombok.Getter;
import lombok.experimental.Accessors;
import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput;
import software.amazon.kinesis.retrieval.ThrottlingReporter;
/**
* Top level container for all the possible states a {@link ShardConsumer} can be in. The logic for creation of tasks,
@ -121,11 +120,11 @@ class ConsumerStates {
@Override
public ConsumerTask createTask(
ShardConsumerArgument consumerArgument, ShardConsumer consumer, ProcessRecordsInput input) {
return new BlockOnParentShardTask(
consumerArgument.shardInfo(),
consumerArgument.leaseCoordinator().leaseRefresher(),
consumerArgument.parentShardPollIntervalMillis());
ShardConsumerArgument consumerArgument,
ShardConsumer consumer,
ProcessRecordsInput input,
ConsumerTaskFactory taskFactory) {
return taskFactory.createBlockOnParentTask(consumerArgument);
}
@Override
@ -187,16 +186,11 @@ class ConsumerStates {
@Override
public ConsumerTask createTask(
ShardConsumerArgument argument, ShardConsumer consumer, ProcessRecordsInput input) {
return new InitializeTask(
argument.shardInfo(),
argument.shardRecordProcessor(),
argument.checkpoint(),
argument.recordProcessorCheckpointer(),
argument.initialPositionInStream(),
argument.recordsPublisher(),
argument.taskBackoffTimeMillis(),
argument.metricsFactory());
ShardConsumerArgument argument,
ShardConsumer consumer,
ProcessRecordsInput input,
ConsumerTaskFactory taskFactory) {
return taskFactory.createInitializeTask(argument);
}
@Override
@ -250,24 +244,11 @@ class ConsumerStates {
@Override
public ConsumerTask createTask(
ShardConsumerArgument argument, ShardConsumer consumer, ProcessRecordsInput input) {
ThrottlingReporter throttlingReporter =
new ThrottlingReporter(5, argument.shardInfo().shardId());
return new ProcessTask(
argument.shardInfo(),
argument.shardRecordProcessor(),
argument.recordProcessorCheckpointer(),
argument.taskBackoffTimeMillis(),
argument.skipShardSyncAtWorkerInitializationIfLeasesExist(),
argument.shardDetector(),
throttlingReporter,
input,
argument.shouldCallProcessRecordsEvenForEmptyRecordList(),
argument.idleTimeInMilliseconds(),
argument.aggregatorUtil(),
argument.metricsFactory(),
argument.schemaRegistryDecoder(),
argument.leaseCoordinator().leaseStatsRecorder());
ShardConsumerArgument argument,
ShardConsumer consumer,
ProcessRecordsInput input,
ConsumerTaskFactory taskFactory) {
return taskFactory.createProcessTask(argument, input);
}
@Override
@ -331,14 +312,12 @@ class ConsumerStates {
@Override
public ConsumerTask createTask(
ShardConsumerArgument argument, ShardConsumer consumer, ProcessRecordsInput input) {
ShardConsumerArgument argument,
ShardConsumer consumer,
ProcessRecordsInput input,
ConsumerTaskFactory taskFactory) {
// TODO: notify shutdownrequested
return new ShutdownNotificationTask(
argument.shardRecordProcessor(),
argument.recordProcessorCheckpointer(),
consumer.shutdownNotification(),
argument.shardInfo(),
consumer.shardConsumerArgument().leaseCoordinator());
return taskFactory.createShutdownNotificationTask(argument, consumer);
}
@Override
@ -405,7 +384,10 @@ class ConsumerStates {
@Override
public ConsumerTask createTask(
ShardConsumerArgument argument, ShardConsumer consumer, ProcessRecordsInput input) {
ShardConsumerArgument argument,
ShardConsumer consumer,
ProcessRecordsInput input,
ConsumerTaskFactory taskFactory) {
return null;
}
@ -483,25 +465,12 @@ class ConsumerStates {
@Override
public ConsumerTask createTask(
ShardConsumerArgument argument, ShardConsumer consumer, ProcessRecordsInput input) {
ShardConsumerArgument argument,
ShardConsumer consumer,
ProcessRecordsInput input,
ConsumerTaskFactory taskFactory) {
// TODO: set shutdown reason
return new ShutdownTask(
argument.shardInfo(),
argument.shardDetector(),
argument.shardRecordProcessor(),
argument.recordProcessorCheckpointer(),
consumer.shutdownReason(),
argument.initialPositionInStream(),
argument.cleanupLeasesOfCompletedShards(),
argument.ignoreUnexpectedChildShards(),
argument.leaseCoordinator(),
argument.taskBackoffTimeMillis(),
argument.recordsPublisher(),
argument.hierarchicalShardSyncer(),
argument.metricsFactory(),
input == null ? null : input.childShards(),
argument.streamIdentifier(),
argument.leaseCleanupManager());
return taskFactory.createShutdownTask(argument, consumer, input);
}
@Override
@ -569,7 +538,10 @@ class ConsumerStates {
@Override
public ConsumerTask createTask(
ShardConsumerArgument argument, ShardConsumer consumer, ProcessRecordsInput input) {
ShardConsumerArgument argument,
ShardConsumer consumer,
ProcessRecordsInput input,
ConsumerTaskFactory taskFactory) {
return null;
}

View file

@ -0,0 +1,47 @@
/*
* Copyright 2019 Amazon.com, Inc. or its affiliates.
* Licensed under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package software.amazon.kinesis.lifecycle;
import software.amazon.kinesis.annotations.KinesisClientInternalApi;
import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput;
@KinesisClientInternalApi
public interface ConsumerTaskFactory {
/**
* Creates a shutdown task.
*/
ConsumerTask createShutdownTask(ShardConsumerArgument argument, ShardConsumer consumer, ProcessRecordsInput input);
/**
* Creates a process task.
*/
ConsumerTask createProcessTask(ShardConsumerArgument argument, ProcessRecordsInput processRecordsInput);
/**
* Creates an initialize task.
*/
ConsumerTask createInitializeTask(ShardConsumerArgument argument);
/**
* Creates a block on parent task.
*/
ConsumerTask createBlockOnParentTask(ShardConsumerArgument argument);
/**
* Creates a shutdown notification task.
*/
ConsumerTask createShutdownNotificationTask(ShardConsumerArgument argument, ShardConsumer consumer);
}

View file

@ -0,0 +1,98 @@
/*
* Copyright 2019 Amazon.com, Inc. or its affiliates.
* Licensed under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package software.amazon.kinesis.lifecycle;
import software.amazon.kinesis.annotations.KinesisClientInternalApi;
import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput;
import software.amazon.kinesis.retrieval.ThrottlingReporter;
@KinesisClientInternalApi
public class KinesisConsumerTaskFactory implements ConsumerTaskFactory {
@Override
public ConsumerTask createShutdownTask(
ShardConsumerArgument argument, ShardConsumer consumer, ProcessRecordsInput input) {
return new ShutdownTask(
argument.shardInfo(),
argument.shardDetector(),
argument.shardRecordProcessor(),
argument.recordProcessorCheckpointer(),
consumer.shutdownReason(),
argument.initialPositionInStream(),
argument.cleanupLeasesOfCompletedShards(),
argument.ignoreUnexpectedChildShards(),
argument.leaseCoordinator(),
argument.taskBackoffTimeMillis(),
argument.recordsPublisher(),
argument.hierarchicalShardSyncer(),
argument.metricsFactory(),
input == null ? null : input.childShards(),
argument.streamIdentifier(),
argument.leaseCleanupManager());
}
@Override
public ConsumerTask createProcessTask(ShardConsumerArgument argument, ProcessRecordsInput processRecordsInput) {
ThrottlingReporter throttlingReporter =
new ThrottlingReporter(5, argument.shardInfo().shardId());
return new ProcessTask(
argument.shardInfo(),
argument.shardRecordProcessor(),
argument.recordProcessorCheckpointer(),
argument.taskBackoffTimeMillis(),
argument.skipShardSyncAtWorkerInitializationIfLeasesExist(),
argument.shardDetector(),
throttlingReporter,
processRecordsInput,
argument.shouldCallProcessRecordsEvenForEmptyRecordList(),
argument.idleTimeInMilliseconds(),
argument.aggregatorUtil(),
argument.metricsFactory(),
argument.schemaRegistryDecoder(),
argument.leaseCoordinator().leaseStatsRecorder());
}
@Override
public ConsumerTask createInitializeTask(ShardConsumerArgument argument) {
return new InitializeTask(
argument.shardInfo(),
argument.shardRecordProcessor(),
argument.checkpoint(),
argument.recordProcessorCheckpointer(),
argument.initialPositionInStream(),
argument.recordsPublisher(),
argument.taskBackoffTimeMillis(),
argument.metricsFactory());
}
@Override
public ConsumerTask createBlockOnParentTask(ShardConsumerArgument argument) {
return new BlockOnParentShardTask(
argument.shardInfo(),
argument.leaseCoordinator().leaseRefresher(),
argument.parentShardPollIntervalMillis());
}
@Override
public ConsumerTask createShutdownNotificationTask(ShardConsumerArgument argument, ShardConsumer consumer) {
return new ShutdownNotificationTask(
argument.shardRecordProcessor(),
argument.recordProcessorCheckpointer(),
consumer.shutdownNotification(),
argument.shardInfo(),
consumer.shardConsumerArgument().leaseCoordinator());
}
}

View file

@ -86,6 +86,8 @@ public class ShardConsumer {
private ProcessRecordsInput shardEndProcessRecordsInput;
private final ConsumerTaskFactory taskFactory;
public ShardConsumer(
RecordsPublisher recordsPublisher,
ExecutorService executorService,
@ -103,7 +105,8 @@ public class ShardConsumer {
ConsumerStates.INITIAL_STATE,
8,
taskExecutionListener,
readTimeoutsToIgnoreBeforeWarning);
readTimeoutsToIgnoreBeforeWarning,
new KinesisConsumerTaskFactory());
}
//
@ -119,6 +122,30 @@ public class ShardConsumer {
int bufferSize,
TaskExecutionListener taskExecutionListener,
int readTimeoutsToIgnoreBeforeWarning) {
this(
recordsPublisher,
executorService,
shardInfo,
logWarningForTaskAfterMillis,
shardConsumerArgument,
initialState,
bufferSize,
taskExecutionListener,
readTimeoutsToIgnoreBeforeWarning,
new KinesisConsumerTaskFactory());
}
public ShardConsumer(
RecordsPublisher recordsPublisher,
ExecutorService executorService,
ShardInfo shardInfo,
Optional<Long> logWarningForTaskAfterMillis,
ShardConsumerArgument shardConsumerArgument,
ConsumerState initialState,
int bufferSize,
TaskExecutionListener taskExecutionListener,
int readTimeoutsToIgnoreBeforeWarning,
ConsumerTaskFactory taskFactory) {
this.recordsPublisher = recordsPublisher;
this.executorService = executorService;
this.shardInfo = shardInfo;
@ -134,6 +161,7 @@ public class ShardConsumer {
if (this.shardInfo.isCompleted()) {
markForShutdown(ShutdownReason.SHARD_END);
}
this.taskFactory = taskFactory;
}
synchronized void handleInput(ProcessRecordsInput input, Subscription subscription) {
@ -345,7 +373,7 @@ public class ShardConsumer {
.taskType(currentState.taskType())
.build();
taskExecutionListener.beforeTaskExecution(taskExecutionListenerInput);
ConsumerTask task = currentState.createTask(shardConsumerArgument, ShardConsumer.this, input);
ConsumerTask task = currentState.createTask(shardConsumerArgument, ShardConsumer.this, input, taskFactory);
if (task != null) {
taskDispatchedAt = Instant.now();
currentTask = task;

View file

@ -127,6 +127,7 @@ public class ConsumerStatesTest {
private long idleTimeInMillis = 1000L;
private Optional<Long> logWarningForTaskAfterMillis = Optional.empty();
private SchemaRegistryDecoder schemaRegistryDecoder = null;
private final ConsumerTaskFactory taskFactory = new KinesisConsumerTaskFactory();
@Before
public void setup() {
@ -177,7 +178,7 @@ public class ConsumerStatesTest {
ConsumerState state = ShardConsumerState.WAITING_ON_PARENT_SHARDS.consumerState();
when(leaseCoordinator.leaseRefresher()).thenReturn(leaseRefresher);
ConsumerTask task = state.createTask(argument, consumer, null);
ConsumerTask task = state.createTask(argument, consumer, null, taskFactory);
assertThat(task, taskWith(BlockOnParentShardTask.class, ShardInfo.class, "shardInfo", equalTo(shardInfo)));
assertThat(
@ -209,7 +210,7 @@ public class ConsumerStatesTest {
@Test
public void initializingStateTest() {
ConsumerState state = ShardConsumerState.INITIALIZING.consumerState();
ConsumerTask task = state.createTask(argument, consumer, null);
ConsumerTask task = state.createTask(argument, consumer, null, taskFactory);
assertThat(task, initTask(ShardInfo.class, "shardInfo", equalTo(shardInfo)));
assertThat(task, initTask(ShardRecordProcessor.class, "shardRecordProcessor", equalTo(shardRecordProcessor)));
@ -242,7 +243,7 @@ public class ConsumerStatesTest {
public void processingStateTestSynchronous() {
when(leaseCoordinator.leaseStatsRecorder()).thenReturn(leaseStatsRecorder);
ConsumerState state = ShardConsumerState.PROCESSING.consumerState();
ConsumerTask task = state.createTask(argument, consumer, null);
ConsumerTask task = state.createTask(argument, consumer, null, taskFactory);
assertThat(task, procTask(ShardInfo.class, "shardInfo", equalTo(shardInfo)));
assertThat(task, procTask(ShardRecordProcessor.class, "shardRecordProcessor", equalTo(shardRecordProcessor)));
@ -274,7 +275,7 @@ public class ConsumerStatesTest {
public void processingStateTestAsynchronous() {
when(leaseCoordinator.leaseStatsRecorder()).thenReturn(leaseStatsRecorder);
ConsumerState state = ShardConsumerState.PROCESSING.consumerState();
ConsumerTask task = state.createTask(argument, consumer, null);
ConsumerTask task = state.createTask(argument, consumer, null, taskFactory);
assertThat(task, procTask(ShardInfo.class, "shardInfo", equalTo(shardInfo)));
assertThat(task, procTask(ShardRecordProcessor.class, "shardRecordProcessor", equalTo(shardRecordProcessor)));
@ -306,7 +307,7 @@ public class ConsumerStatesTest {
public void processingStateRecordsFetcher() {
when(leaseCoordinator.leaseStatsRecorder()).thenReturn(leaseStatsRecorder);
ConsumerState state = ShardConsumerState.PROCESSING.consumerState();
ConsumerTask task = state.createTask(argument, consumer, null);
ConsumerTask task = state.createTask(argument, consumer, null, taskFactory);
assertThat(task, procTask(ShardInfo.class, "shardInfo", equalTo(shardInfo)));
assertThat(task, procTask(ShardRecordProcessor.class, "shardRecordProcessor", equalTo(shardRecordProcessor)));
@ -339,7 +340,7 @@ public class ConsumerStatesTest {
ConsumerState state = ShardConsumerState.SHUTDOWN_REQUESTED.consumerState();
consumer.gracefulShutdown(shutdownNotification);
ConsumerTask task = state.createTask(argument, consumer, null);
ConsumerTask task = state.createTask(argument, consumer, null, taskFactory);
assertThat(
task,
@ -373,7 +374,7 @@ public class ConsumerStatesTest {
public void shutdownRequestCompleteStateTest() {
ConsumerState state = ConsumerStates.SHUTDOWN_REQUEST_COMPLETION_STATE;
assertThat(state.createTask(argument, consumer, null), nullValue());
assertThat(state.createTask(argument, consumer, null, taskFactory), nullValue());
assertThat(state.successTransition(), equalTo(state));
@ -409,7 +410,7 @@ public class ConsumerStatesTest {
childShards.add(leftChild);
childShards.add(rightChild);
when(processRecordsInput.childShards()).thenReturn(childShards);
ConsumerTask task = state.createTask(argument, consumer, processRecordsInput);
ConsumerTask task = state.createTask(argument, consumer, processRecordsInput, taskFactory);
assertThat(task, shutdownTask(ShardInfo.class, "shardInfo", equalTo(shardInfo)));
assertThat(
@ -443,7 +444,7 @@ public class ConsumerStatesTest {
ConsumerState state = ShardConsumerState.SHUTDOWN_COMPLETE.consumerState();
assertThat(state.createTask(argument, consumer, null), nullValue());
assertThat(state.createTask(argument, consumer, null, taskFactory), nullValue());
assertThat(state.successTransition(), equalTo(state));
for (ShutdownReason reason : ShutdownReason.values()) {

View file

@ -332,7 +332,7 @@ public class ShardConsumerTest {
verify(cache.subscription, times(3)).request(anyLong());
verify(cache.subscription).cancel();
verify(processingState, times(2)).createTask(eq(shardConsumerArgument), eq(consumer), any());
verify(processingState, times(2)).createTask(eq(shardConsumerArgument), eq(consumer), any(), any());
verify(taskExecutionListener, times(1)).beforeTaskExecution(initialTaskInput);
verify(taskExecutionListener, times(2)).beforeTaskExecution(processTaskInput);
verify(taskExecutionListener, times(1)).beforeTaskExecution(shutdownTaskInput);
@ -394,7 +394,7 @@ public class ShardConsumerTest {
verify(cache.subscription, times(1)).request(anyLong());
verify(cache.subscription).cancel();
verify(processingState, times(1)).createTask(eq(shardConsumerArgument), eq(consumer), any());
verify(processingState, times(1)).createTask(eq(shardConsumerArgument), eq(consumer), any(), any());
verify(taskExecutionListener, times(1)).beforeTaskExecution(initialTaskInput);
verify(taskExecutionListener, times(1)).beforeTaskExecution(processTaskInput);
verify(taskExecutionListener, times(1)).beforeTaskExecution(shutdownTaskInput);
@ -437,14 +437,14 @@ public class ShardConsumerTest {
cache.publish();
awaitAndResetBarrier(taskCallBarrier);
verify(processingState).createTask(any(), any(), any());
verify(processingState).createTask(any(), any(), any(), any());
verify(processingTask).call();
cache.awaitRequest();
cache.publish();
awaitAndResetBarrier(taskCallBarrier);
verify(processingState, times(2)).createTask(any(), any(), any());
verify(processingState, times(2)).createTask(any(), any(), any(), any());
verify(processingTask, times(2)).call();
cache.awaitRequest();
@ -460,7 +460,7 @@ public class ShardConsumerTest {
shutdownComplete = consumer.shutdownComplete().get();
} while (!shutdownComplete);
verify(processingState, times(3)).createTask(any(), any(), any());
verify(processingState, times(3)).createTask(any(), any(), any(), any());
verify(processingTask, times(3)).call();
verify(processingState).shutdownTransition(eq(ShutdownReason.LEASE_LOST));
verify(shutdownState).shutdownTransition(eq(ShutdownReason.LEASE_LOST));
@ -487,7 +487,7 @@ public class ShardConsumerTest {
public final void testInitializationStateUponFailure() throws Exception {
final ShardConsumer consumer = createShardConsumer(recordsPublisher);
when(initialState.createTask(eq(shardConsumerArgument), eq(consumer), any()))
when(initialState.createTask(eq(shardConsumerArgument), eq(consumer), any(), any()))
.thenReturn(initializeTask);
when(initializeTask.call()).thenReturn(new TaskResult(new Exception("Bad")));
when(initializeTask.taskType()).thenReturn(TaskType.INITIALIZE);
@ -505,7 +505,7 @@ public class ShardConsumerTest {
awaitAndResetBarrier(taskBarrier);
}
verify(initialState, times(5)).createTask(eq(shardConsumerArgument), eq(consumer), any());
verify(initialState, times(5)).createTask(eq(shardConsumerArgument), eq(consumer), any(), any());
verify(initialState, never()).successTransition();
verify(initialState, never()).shutdownTransition(any());
}
@ -665,7 +665,7 @@ public class ShardConsumerTest {
public void testErrorThrowableInInitialization() throws Exception {
final ShardConsumer consumer = createShardConsumer(recordsPublisher);
when(initialState.createTask(any(), any(), any())).thenReturn(initializeTask);
when(initialState.createTask(any(), any(), any(), any())).thenReturn(initializeTask);
when(initialState.taskType()).thenReturn(TaskType.INITIALIZE);
when(initializeTask.call()).thenAnswer(i -> {
throw new Error("Error");
@ -692,13 +692,13 @@ public class ShardConsumerTest {
mockSuccessfulProcessing(taskBarrier);
when(processingState.shutdownTransition(eq(ShutdownReason.REQUESTED))).thenReturn(shutdownRequestedState);
when(shutdownRequestedState.createTask(any(), any(), any())).thenReturn(shutdownRequestedTask);
when(shutdownRequestedState.createTask(any(), any(), any(), any())).thenReturn(shutdownRequestedTask);
when(shutdownRequestedState.taskType()).thenReturn(TaskType.SHUTDOWN_NOTIFICATION);
when(shutdownRequestedTask.call()).thenReturn(new TaskResult(null));
when(shutdownRequestedState.shutdownTransition(eq(ShutdownReason.REQUESTED)))
.thenReturn(shutdownRequestedAwaitState);
when(shutdownRequestedAwaitState.createTask(any(), any(), any())).thenReturn(null);
when(shutdownRequestedAwaitState.createTask(any(), any(), any(), any())).thenReturn(null);
when(shutdownRequestedAwaitState.shutdownTransition(eq(ShutdownReason.LEASE_LOST)))
.thenReturn(shutdownState);
when(shutdownRequestedAwaitState.taskType()).thenReturn(TaskType.SHUTDOWN_COMPLETE);
@ -733,11 +733,11 @@ public class ShardConsumerTest {
shutdownComplete = consumer.shutdownComplete().get();
assertTrue(shutdownComplete);
verify(processingState, times(2)).createTask(any(), any(), any());
verify(processingState, times(2)).createTask(any(), any(), any(), any());
verify(shutdownRequestedState, never()).shutdownTransition(eq(ShutdownReason.LEASE_LOST));
verify(shutdownRequestedState).createTask(any(), any(), any());
verify(shutdownRequestedState).createTask(any(), any(), any(), any());
verify(shutdownRequestedState).shutdownTransition(eq(ShutdownReason.REQUESTED));
verify(shutdownRequestedAwaitState).createTask(any(), any(), any());
verify(shutdownRequestedAwaitState).createTask(any(), any(), any(), any());
verify(shutdownRequestedAwaitState).shutdownTransition(eq(ShutdownReason.LEASE_LOST));
verify(taskExecutionListener, times(1)).beforeTaskExecution(initialTaskInput);
verify(taskExecutionListener, times(2)).beforeTaskExecution(processTaskInput);
@ -948,7 +948,7 @@ public class ShardConsumerTest {
when(mockState.state()).thenReturn(ShardConsumerState.WAITING_ON_PARENT_SHARDS);
when(mockState.taskType()).thenReturn(TaskType.BLOCK_ON_PARENT_SHARDS);
final ConsumerTask mockTask = mock(ConsumerTask.class);
when(mockState.createTask(any(), any(), any())).thenReturn(mockTask);
when(mockState.createTask(any(), any(), any(), any())).thenReturn(mockTask);
// Simulate successful BlockedOnParent task execution
// and successful Initialize task execution
when(mockTask.call()).thenReturn(new TaskResult(false));
@ -993,7 +993,7 @@ public class ShardConsumerTest {
reset(mockState);
when(mockState.taskType()).thenReturn(TaskType.PROCESS);
final ConsumerTask mockProcessTask = mock(ConsumerTask.class);
when(mockState.createTask(any(), any(), any())).thenReturn(mockProcessTask);
when(mockState.createTask(any(), any(), any(), any())).thenReturn(mockProcessTask);
when(mockProcessTask.call()).then(input -> {
// first we want to wait for subscribe to be called,
// but we cannot control the timing, so wait for 10 seconds
@ -1045,7 +1045,8 @@ public class ShardConsumerTest {
}
private void mockSuccessfulShutdown(CyclicBarrier taskArriveBarrier, CyclicBarrier taskDepartBarrier) {
when(shutdownState.createTask(eq(shardConsumerArgument), any(), any())).thenReturn(shutdownTask);
when(shutdownState.createTask(eq(shardConsumerArgument), any(), any(), any()))
.thenReturn(shutdownTask);
when(shutdownState.taskType()).thenReturn(TaskType.SHUTDOWN);
when(shutdownTask.call()).thenAnswer(i -> {
awaitBarrier(taskArriveBarrier);
@ -1063,7 +1064,7 @@ public class ShardConsumerTest {
}
private void mockSuccessfulProcessing(CyclicBarrier taskCallBarrier, CyclicBarrier taskInterlockBarrier) {
when(processingState.createTask(eq(shardConsumerArgument), any(), any()))
when(processingState.createTask(eq(shardConsumerArgument), any(), any(), any()))
.thenReturn(processingTask);
when(processingState.taskType()).thenReturn(TaskType.PROCESS);
when(processingTask.taskType()).thenReturn(TaskType.PROCESS);
@ -1088,7 +1089,8 @@ public class ShardConsumerTest {
}
private void mockSuccessfulInitialize(CyclicBarrier taskCallBarrier, CyclicBarrier taskInterlockBarrier) {
when(initialState.createTask(eq(shardConsumerArgument), any(), any())).thenReturn(initializeTask);
when(initialState.createTask(eq(shardConsumerArgument), any(), any(), any()))
.thenReturn(initializeTask);
when(initialState.taskType()).thenReturn(TaskType.INITIALIZE);
when(initializeTask.taskType()).thenReturn(TaskType.INITIALIZE);
when(initializeTask.call()).thenAnswer(i -> {
@ -1107,7 +1109,7 @@ public class ShardConsumerTest {
}
private void mockSuccessfulUnblockOnParents() {
when(blockedOnParentsState.createTask(eq(shardConsumerArgument), any(), any()))
when(blockedOnParentsState.createTask(eq(shardConsumerArgument), any(), any(), any()))
.thenReturn(blockedOnParentsTask);
when(blockedOnParentsState.taskType()).thenReturn(TaskType.BLOCK_ON_PARENT_SHARDS);
when(blockedOnParentsTask.call()).thenAnswer(i -> blockOnParentsTaskResult);