Merge branch 'awslabs:master' into master

This commit is contained in:
vincentvilo-aws 2025-03-11 11:08:31 -07:00 committed by GitHub
commit 31cba5f140
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 268 additions and 99 deletions

View file

@ -55,6 +55,14 @@ To make it easier for developers to write record processors in other languages,
## Using the KCL ## Using the KCL
The recommended way to use the KCL for Java is to consume it from Maven. The recommended way to use the KCL for Java is to consume it from Maven.
## 🚨Important: Do not use AWS SDK for Java versions 2.27.19 to 2.27.23 with KCL 3.x
When using KCL 3.x with AWS SDK for Java versions 2.27.19 through 2.27.23, you may encounter the following DynamoDB exception:
```
software.amazon.awssdk.services.dynamodb.model.DynamoDbException: The document path provided in the update expression is invalid for update (Service: DynamoDb, Status Code: 400, Request ID: xxx)
```
This error occurs due to [a known issue](https://github.com/aws/aws-sdk-java-v2/issues/5584) in the AWS SDK for Java that affects the DynamoDB metadata table managed by KCL 3.x. The issue was introduced in version 2.27.19 and impacts all versions up to 2.27.23. The issue has been resolved in the AWS SDK for Java version 2.27.24. For optimal performance and stability, we recommend upgrading to version 2.28.0 or later.
### Version 3.x ### Version 3.x
``` xml ``` xml
<dependency> <dependency>
@ -70,7 +78,7 @@ The recommended way to use the KCL for Java is to consume it from Maven.
<dependency> <dependency>
<groupId>software.amazon.kinesis</groupId> <groupId>software.amazon.kinesis</groupId>
<artifactId>amazon-kinesis-client</artifactId> <artifactId>amazon-kinesis-client</artifactId>
<version>2.6.0</version> <version>2.6.1</version>
</dependency> </dependency>
``` ```
@ -127,4 +135,4 @@ By participating through these channels, you play a vital role in shaping the fu
[migration-guide]: https://docs.aws.amazon.com/streams/latest/dev/kcl-migration-from-previous-versions [migration-guide]: https://docs.aws.amazon.com/streams/latest/dev/kcl-migration-from-previous-versions
[kcl-sample]: https://docs.aws.amazon.com/streams/latest/dev/kcl-example-code [kcl-sample]: https://docs.aws.amazon.com/streams/latest/dev/kcl-example-code
[kcl-aws-doc]: https://docs.aws.amazon.com/streams/latest/dev/kcl.html [kcl-aws-doc]: https://docs.aws.amazon.com/streams/latest/dev/kcl.html
[giving-feedback]: https://github.com/awslabs/amazon-kinesis-client?tab=readme-ov-file#giving-feedback [giving-feedback]: https://github.com/awslabs/amazon-kinesis-client?tab=readme-ov-file#giving-feedback

View file

@ -171,7 +171,7 @@
<dependency> <dependency>
<groupId>io.netty</groupId> <groupId>io.netty</groupId>
<artifactId>netty-handler</artifactId> <artifactId>netty-handler</artifactId>
<version>4.1.108.Final</version> <version>4.1.118.Final</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>com.google.code.findbugs</groupId> <groupId>com.google.code.findbugs</groupId>
@ -181,7 +181,7 @@
<dependency> <dependency>
<groupId>com.fasterxml.jackson.core</groupId> <groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId> <artifactId>jackson-databind</artifactId>
<version>2.10.1</version> <version>2.12.7.1</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.reactivestreams</groupId> <groupId>org.reactivestreams</groupId>

View file

@ -87,6 +87,7 @@ import software.amazon.kinesis.leases.dynamodb.DynamoDBMultiStreamLeaseSerialize
import software.amazon.kinesis.leases.exceptions.DependencyException; import software.amazon.kinesis.leases.exceptions.DependencyException;
import software.amazon.kinesis.leases.exceptions.InvalidStateException; import software.amazon.kinesis.leases.exceptions.InvalidStateException;
import software.amazon.kinesis.leases.exceptions.ProvisionedThroughputException; import software.amazon.kinesis.leases.exceptions.ProvisionedThroughputException;
import software.amazon.kinesis.lifecycle.ConsumerTaskFactory;
import software.amazon.kinesis.lifecycle.LifecycleConfig; import software.amazon.kinesis.lifecycle.LifecycleConfig;
import software.amazon.kinesis.lifecycle.ShardConsumer; import software.amazon.kinesis.lifecycle.ShardConsumer;
import software.amazon.kinesis.lifecycle.ShardConsumerArgument; import software.amazon.kinesis.lifecycle.ShardConsumerArgument;
@ -188,6 +189,7 @@ public class Scheduler implements Runnable {
private final SchemaRegistryDecoder schemaRegistryDecoder; private final SchemaRegistryDecoder schemaRegistryDecoder;
private final DeletedStreamListProvider deletedStreamListProvider; private final DeletedStreamListProvider deletedStreamListProvider;
private final ConsumerTaskFactory taskFactory;
@Getter(AccessLevel.NONE) @Getter(AccessLevel.NONE)
private final MigrationStateMachine migrationStateMachine; private final MigrationStateMachine migrationStateMachine;
@ -371,6 +373,7 @@ public class Scheduler implements Runnable {
this.schemaRegistryDecoder = this.retrievalConfig.glueSchemaRegistryDeserializer() == null this.schemaRegistryDecoder = this.retrievalConfig.glueSchemaRegistryDeserializer() == null
? null ? null
: new SchemaRegistryDecoder(this.retrievalConfig.glueSchemaRegistryDeserializer()); : new SchemaRegistryDecoder(this.retrievalConfig.glueSchemaRegistryDeserializer());
this.taskFactory = leaseManagementConfig().consumerTaskFactory();
} }
/** /**

View file

@ -537,8 +537,8 @@ public final class LeaseAssignmentManager {
.filter(workerMetrics -> !workerMetrics.isValidWorkerMetric()) .filter(workerMetrics -> !workerMetrics.isValidWorkerMetric())
.map(WorkerMetricStats::getWorkerId) .map(WorkerMetricStats::getWorkerId)
.collect(Collectors.toList()); .collect(Collectors.toList());
log.warn("List of workerIds with invalid entries : {}", listOfWorkerIdOfInvalidWorkerMetricsEntry);
if (!listOfWorkerIdOfInvalidWorkerMetricsEntry.isEmpty()) { if (!listOfWorkerIdOfInvalidWorkerMetricsEntry.isEmpty()) {
log.warn("List of workerIds with invalid entries : {}", listOfWorkerIdOfInvalidWorkerMetricsEntry);
metricsScope.addData( metricsScope.addData(
"NumWorkersWithInvalidEntry", "NumWorkersWithInvalidEntry",
listOfWorkerIdOfInvalidWorkerMetricsEntry.size(), listOfWorkerIdOfInvalidWorkerMetricsEntry.size(),
@ -567,8 +567,8 @@ public final class LeaseAssignmentManager {
final Map.Entry<List<Lease>, List<String>> leaseListResponse = leaseListFuture.join(); final Map.Entry<List<Lease>, List<String>> leaseListResponse = leaseListFuture.join();
this.leaseList = leaseListResponse.getKey(); this.leaseList = leaseListResponse.getKey();
log.warn("Leases that failed deserialization : {}", leaseListResponse.getValue());
if (!leaseListResponse.getValue().isEmpty()) { if (!leaseListResponse.getValue().isEmpty()) {
log.warn("Leases that failed deserialization : {}", leaseListResponse.getValue());
MetricsUtil.addCount( MetricsUtil.addCount(
metricsScope, metricsScope,
"LeaseDeserializationFailureCount", "LeaseDeserializationFailureCount",

View file

@ -45,6 +45,8 @@ import software.amazon.kinesis.common.StreamConfig;
import software.amazon.kinesis.leases.dynamodb.DynamoDBLeaseManagementFactory; import software.amazon.kinesis.leases.dynamodb.DynamoDBLeaseManagementFactory;
import software.amazon.kinesis.leases.dynamodb.DynamoDBLeaseSerializer; import software.amazon.kinesis.leases.dynamodb.DynamoDBLeaseSerializer;
import software.amazon.kinesis.leases.dynamodb.TableCreatorCallback; import software.amazon.kinesis.leases.dynamodb.TableCreatorCallback;
import software.amazon.kinesis.lifecycle.ConsumerTaskFactory;
import software.amazon.kinesis.lifecycle.KinesisConsumerTaskFactory;
import software.amazon.kinesis.metrics.MetricsFactory; import software.amazon.kinesis.metrics.MetricsFactory;
import software.amazon.kinesis.metrics.NullMetricsFactory; import software.amazon.kinesis.metrics.NullMetricsFactory;
import software.amazon.kinesis.worker.metric.WorkerMetric; import software.amazon.kinesis.worker.metric.WorkerMetric;
@ -215,6 +217,8 @@ public class LeaseManagementConfig {
private BillingMode billingMode = BillingMode.PAY_PER_REQUEST; private BillingMode billingMode = BillingMode.PAY_PER_REQUEST;
private ConsumerTaskFactory consumerTaskFactory = new KinesisConsumerTaskFactory();
private WorkerUtilizationAwareAssignmentConfig workerUtilizationAwareAssignmentConfig = private WorkerUtilizationAwareAssignmentConfig workerUtilizationAwareAssignmentConfig =
new WorkerUtilizationAwareAssignmentConfig(); new WorkerUtilizationAwareAssignmentConfig();

View file

@ -34,13 +34,19 @@ interface ConsumerState {
* the consumer to use build the task, or execute state. * the consumer to use build the task, or execute state.
* @param input * @param input
* the process input received, this may be null if it's a control message * the process input received, this may be null if it's a control message
* @param taskFactory
* a factory for creating tasks
* @return a valid task for this state or null if there is no task required. * @return a valid task for this state or null if there is no task required.
*/ */
ConsumerTask createTask(ShardConsumerArgument consumerArgument, ShardConsumer consumer, ProcessRecordsInput input); ConsumerTask createTask(
ShardConsumerArgument consumerArgument,
ShardConsumer consumer,
ProcessRecordsInput input,
ConsumerTaskFactory taskFactory);
/** /**
* Provides the next state of the consumer upon success of the task return by * Provides the next state of the consumer upon success of the task return by
* {@link ConsumerState#createTask(ShardConsumerArgument, ShardConsumer, ProcessRecordsInput)}. * {@link ConsumerState#createTask(ShardConsumerArgument, ShardConsumer, ProcessRecordsInput, ConsumerTaskFactory)}.
* *
* @return the next state that the consumer should transition to, this may be the same object as the current * @return the next state that the consumer should transition to, this may be the same object as the current
* state. * state.

View file

@ -17,7 +17,6 @@ package software.amazon.kinesis.lifecycle;
import lombok.Getter; import lombok.Getter;
import lombok.experimental.Accessors; import lombok.experimental.Accessors;
import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput; import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput;
import software.amazon.kinesis.retrieval.ThrottlingReporter;
/** /**
* Top level container for all the possible states a {@link ShardConsumer} can be in. The logic for creation of tasks, * Top level container for all the possible states a {@link ShardConsumer} can be in. The logic for creation of tasks,
@ -121,11 +120,11 @@ class ConsumerStates {
@Override @Override
public ConsumerTask createTask( public ConsumerTask createTask(
ShardConsumerArgument consumerArgument, ShardConsumer consumer, ProcessRecordsInput input) { ShardConsumerArgument consumerArgument,
return new BlockOnParentShardTask( ShardConsumer consumer,
consumerArgument.shardInfo(), ProcessRecordsInput input,
consumerArgument.leaseCoordinator().leaseRefresher(), ConsumerTaskFactory taskFactory) {
consumerArgument.parentShardPollIntervalMillis()); return taskFactory.createBlockOnParentTask(consumerArgument);
} }
@Override @Override
@ -187,16 +186,11 @@ class ConsumerStates {
@Override @Override
public ConsumerTask createTask( public ConsumerTask createTask(
ShardConsumerArgument argument, ShardConsumer consumer, ProcessRecordsInput input) { ShardConsumerArgument argument,
return new InitializeTask( ShardConsumer consumer,
argument.shardInfo(), ProcessRecordsInput input,
argument.shardRecordProcessor(), ConsumerTaskFactory taskFactory) {
argument.checkpoint(), return taskFactory.createInitializeTask(argument);
argument.recordProcessorCheckpointer(),
argument.initialPositionInStream(),
argument.recordsPublisher(),
argument.taskBackoffTimeMillis(),
argument.metricsFactory());
} }
@Override @Override
@ -250,24 +244,11 @@ class ConsumerStates {
@Override @Override
public ConsumerTask createTask( public ConsumerTask createTask(
ShardConsumerArgument argument, ShardConsumer consumer, ProcessRecordsInput input) { ShardConsumerArgument argument,
ThrottlingReporter throttlingReporter = ShardConsumer consumer,
new ThrottlingReporter(5, argument.shardInfo().shardId()); ProcessRecordsInput input,
return new ProcessTask( ConsumerTaskFactory taskFactory) {
argument.shardInfo(), return taskFactory.createProcessTask(argument, input);
argument.shardRecordProcessor(),
argument.recordProcessorCheckpointer(),
argument.taskBackoffTimeMillis(),
argument.skipShardSyncAtWorkerInitializationIfLeasesExist(),
argument.shardDetector(),
throttlingReporter,
input,
argument.shouldCallProcessRecordsEvenForEmptyRecordList(),
argument.idleTimeInMilliseconds(),
argument.aggregatorUtil(),
argument.metricsFactory(),
argument.schemaRegistryDecoder(),
argument.leaseCoordinator().leaseStatsRecorder());
} }
@Override @Override
@ -331,14 +312,12 @@ class ConsumerStates {
@Override @Override
public ConsumerTask createTask( public ConsumerTask createTask(
ShardConsumerArgument argument, ShardConsumer consumer, ProcessRecordsInput input) { ShardConsumerArgument argument,
ShardConsumer consumer,
ProcessRecordsInput input,
ConsumerTaskFactory taskFactory) {
// TODO: notify shutdownrequested // TODO: notify shutdownrequested
return new ShutdownNotificationTask( return taskFactory.createShutdownNotificationTask(argument, consumer);
argument.shardRecordProcessor(),
argument.recordProcessorCheckpointer(),
consumer.shutdownNotification(),
argument.shardInfo(),
consumer.shardConsumerArgument().leaseCoordinator());
} }
@Override @Override
@ -405,7 +384,10 @@ class ConsumerStates {
@Override @Override
public ConsumerTask createTask( public ConsumerTask createTask(
ShardConsumerArgument argument, ShardConsumer consumer, ProcessRecordsInput input) { ShardConsumerArgument argument,
ShardConsumer consumer,
ProcessRecordsInput input,
ConsumerTaskFactory taskFactory) {
return null; return null;
} }
@ -483,25 +465,12 @@ class ConsumerStates {
@Override @Override
public ConsumerTask createTask( public ConsumerTask createTask(
ShardConsumerArgument argument, ShardConsumer consumer, ProcessRecordsInput input) { ShardConsumerArgument argument,
ShardConsumer consumer,
ProcessRecordsInput input,
ConsumerTaskFactory taskFactory) {
// TODO: set shutdown reason // TODO: set shutdown reason
return new ShutdownTask( return taskFactory.createShutdownTask(argument, consumer, input);
argument.shardInfo(),
argument.shardDetector(),
argument.shardRecordProcessor(),
argument.recordProcessorCheckpointer(),
consumer.shutdownReason(),
argument.initialPositionInStream(),
argument.cleanupLeasesOfCompletedShards(),
argument.ignoreUnexpectedChildShards(),
argument.leaseCoordinator(),
argument.taskBackoffTimeMillis(),
argument.recordsPublisher(),
argument.hierarchicalShardSyncer(),
argument.metricsFactory(),
input == null ? null : input.childShards(),
argument.streamIdentifier(),
argument.leaseCleanupManager());
} }
@Override @Override
@ -569,7 +538,10 @@ class ConsumerStates {
@Override @Override
public ConsumerTask createTask( public ConsumerTask createTask(
ShardConsumerArgument argument, ShardConsumer consumer, ProcessRecordsInput input) { ShardConsumerArgument argument,
ShardConsumer consumer,
ProcessRecordsInput input,
ConsumerTaskFactory taskFactory) {
return null; return null;
} }

View file

@ -0,0 +1,47 @@
/*
* Copyright 2019 Amazon.com, Inc. or its affiliates.
* Licensed under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package software.amazon.kinesis.lifecycle;
import software.amazon.kinesis.annotations.KinesisClientInternalApi;
import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput;
@KinesisClientInternalApi
public interface ConsumerTaskFactory {
/**
* Creates a shutdown task.
*/
ConsumerTask createShutdownTask(ShardConsumerArgument argument, ShardConsumer consumer, ProcessRecordsInput input);
/**
* Creates a process task.
*/
ConsumerTask createProcessTask(ShardConsumerArgument argument, ProcessRecordsInput processRecordsInput);
/**
* Creates an initialize task.
*/
ConsumerTask createInitializeTask(ShardConsumerArgument argument);
/**
* Creates a block on parent task.
*/
ConsumerTask createBlockOnParentTask(ShardConsumerArgument argument);
/**
* Creates a shutdown notification task.
*/
ConsumerTask createShutdownNotificationTask(ShardConsumerArgument argument, ShardConsumer consumer);
}

View file

@ -0,0 +1,98 @@
/*
* Copyright 2019 Amazon.com, Inc. or its affiliates.
* Licensed under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package software.amazon.kinesis.lifecycle;
import software.amazon.kinesis.annotations.KinesisClientInternalApi;
import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput;
import software.amazon.kinesis.retrieval.ThrottlingReporter;
@KinesisClientInternalApi
public class KinesisConsumerTaskFactory implements ConsumerTaskFactory {
@Override
public ConsumerTask createShutdownTask(
ShardConsumerArgument argument, ShardConsumer consumer, ProcessRecordsInput input) {
return new ShutdownTask(
argument.shardInfo(),
argument.shardDetector(),
argument.shardRecordProcessor(),
argument.recordProcessorCheckpointer(),
consumer.shutdownReason(),
argument.initialPositionInStream(),
argument.cleanupLeasesOfCompletedShards(),
argument.ignoreUnexpectedChildShards(),
argument.leaseCoordinator(),
argument.taskBackoffTimeMillis(),
argument.recordsPublisher(),
argument.hierarchicalShardSyncer(),
argument.metricsFactory(),
input == null ? null : input.childShards(),
argument.streamIdentifier(),
argument.leaseCleanupManager());
}
@Override
public ConsumerTask createProcessTask(ShardConsumerArgument argument, ProcessRecordsInput processRecordsInput) {
ThrottlingReporter throttlingReporter =
new ThrottlingReporter(5, argument.shardInfo().shardId());
return new ProcessTask(
argument.shardInfo(),
argument.shardRecordProcessor(),
argument.recordProcessorCheckpointer(),
argument.taskBackoffTimeMillis(),
argument.skipShardSyncAtWorkerInitializationIfLeasesExist(),
argument.shardDetector(),
throttlingReporter,
processRecordsInput,
argument.shouldCallProcessRecordsEvenForEmptyRecordList(),
argument.idleTimeInMilliseconds(),
argument.aggregatorUtil(),
argument.metricsFactory(),
argument.schemaRegistryDecoder(),
argument.leaseCoordinator().leaseStatsRecorder());
}
@Override
public ConsumerTask createInitializeTask(ShardConsumerArgument argument) {
return new InitializeTask(
argument.shardInfo(),
argument.shardRecordProcessor(),
argument.checkpoint(),
argument.recordProcessorCheckpointer(),
argument.initialPositionInStream(),
argument.recordsPublisher(),
argument.taskBackoffTimeMillis(),
argument.metricsFactory());
}
@Override
public ConsumerTask createBlockOnParentTask(ShardConsumerArgument argument) {
return new BlockOnParentShardTask(
argument.shardInfo(),
argument.leaseCoordinator().leaseRefresher(),
argument.parentShardPollIntervalMillis());
}
@Override
public ConsumerTask createShutdownNotificationTask(ShardConsumerArgument argument, ShardConsumer consumer) {
return new ShutdownNotificationTask(
argument.shardRecordProcessor(),
argument.recordProcessorCheckpointer(),
consumer.shutdownNotification(),
argument.shardInfo(),
consumer.shardConsumerArgument().leaseCoordinator());
}
}

View file

@ -86,6 +86,8 @@ public class ShardConsumer {
private ProcessRecordsInput shardEndProcessRecordsInput; private ProcessRecordsInput shardEndProcessRecordsInput;
private final ConsumerTaskFactory taskFactory;
public ShardConsumer( public ShardConsumer(
RecordsPublisher recordsPublisher, RecordsPublisher recordsPublisher,
ExecutorService executorService, ExecutorService executorService,
@ -103,7 +105,8 @@ public class ShardConsumer {
ConsumerStates.INITIAL_STATE, ConsumerStates.INITIAL_STATE,
8, 8,
taskExecutionListener, taskExecutionListener,
readTimeoutsToIgnoreBeforeWarning); readTimeoutsToIgnoreBeforeWarning,
new KinesisConsumerTaskFactory());
} }
// //
@ -119,6 +122,30 @@ public class ShardConsumer {
int bufferSize, int bufferSize,
TaskExecutionListener taskExecutionListener, TaskExecutionListener taskExecutionListener,
int readTimeoutsToIgnoreBeforeWarning) { int readTimeoutsToIgnoreBeforeWarning) {
this(
recordsPublisher,
executorService,
shardInfo,
logWarningForTaskAfterMillis,
shardConsumerArgument,
initialState,
bufferSize,
taskExecutionListener,
readTimeoutsToIgnoreBeforeWarning,
new KinesisConsumerTaskFactory());
}
public ShardConsumer(
RecordsPublisher recordsPublisher,
ExecutorService executorService,
ShardInfo shardInfo,
Optional<Long> logWarningForTaskAfterMillis,
ShardConsumerArgument shardConsumerArgument,
ConsumerState initialState,
int bufferSize,
TaskExecutionListener taskExecutionListener,
int readTimeoutsToIgnoreBeforeWarning,
ConsumerTaskFactory taskFactory) {
this.recordsPublisher = recordsPublisher; this.recordsPublisher = recordsPublisher;
this.executorService = executorService; this.executorService = executorService;
this.shardInfo = shardInfo; this.shardInfo = shardInfo;
@ -134,6 +161,7 @@ public class ShardConsumer {
if (this.shardInfo.isCompleted()) { if (this.shardInfo.isCompleted()) {
markForShutdown(ShutdownReason.SHARD_END); markForShutdown(ShutdownReason.SHARD_END);
} }
this.taskFactory = taskFactory;
} }
synchronized void handleInput(ProcessRecordsInput input, Subscription subscription) { synchronized void handleInput(ProcessRecordsInput input, Subscription subscription) {
@ -345,7 +373,7 @@ public class ShardConsumer {
.taskType(currentState.taskType()) .taskType(currentState.taskType())
.build(); .build();
taskExecutionListener.beforeTaskExecution(taskExecutionListenerInput); taskExecutionListener.beforeTaskExecution(taskExecutionListenerInput);
ConsumerTask task = currentState.createTask(shardConsumerArgument, ShardConsumer.this, input); ConsumerTask task = currentState.createTask(shardConsumerArgument, ShardConsumer.this, input, taskFactory);
if (task != null) { if (task != null) {
taskDispatchedAt = Instant.now(); taskDispatchedAt = Instant.now();
currentTask = task; currentTask = task;

View file

@ -127,6 +127,7 @@ public class ConsumerStatesTest {
private long idleTimeInMillis = 1000L; private long idleTimeInMillis = 1000L;
private Optional<Long> logWarningForTaskAfterMillis = Optional.empty(); private Optional<Long> logWarningForTaskAfterMillis = Optional.empty();
private SchemaRegistryDecoder schemaRegistryDecoder = null; private SchemaRegistryDecoder schemaRegistryDecoder = null;
private final ConsumerTaskFactory taskFactory = new KinesisConsumerTaskFactory();
@Before @Before
public void setup() { public void setup() {
@ -177,7 +178,7 @@ public class ConsumerStatesTest {
ConsumerState state = ShardConsumerState.WAITING_ON_PARENT_SHARDS.consumerState(); ConsumerState state = ShardConsumerState.WAITING_ON_PARENT_SHARDS.consumerState();
when(leaseCoordinator.leaseRefresher()).thenReturn(leaseRefresher); when(leaseCoordinator.leaseRefresher()).thenReturn(leaseRefresher);
ConsumerTask task = state.createTask(argument, consumer, null); ConsumerTask task = state.createTask(argument, consumer, null, taskFactory);
assertThat(task, taskWith(BlockOnParentShardTask.class, ShardInfo.class, "shardInfo", equalTo(shardInfo))); assertThat(task, taskWith(BlockOnParentShardTask.class, ShardInfo.class, "shardInfo", equalTo(shardInfo)));
assertThat( assertThat(
@ -209,7 +210,7 @@ public class ConsumerStatesTest {
@Test @Test
public void initializingStateTest() { public void initializingStateTest() {
ConsumerState state = ShardConsumerState.INITIALIZING.consumerState(); ConsumerState state = ShardConsumerState.INITIALIZING.consumerState();
ConsumerTask task = state.createTask(argument, consumer, null); ConsumerTask task = state.createTask(argument, consumer, null, taskFactory);
assertThat(task, initTask(ShardInfo.class, "shardInfo", equalTo(shardInfo))); assertThat(task, initTask(ShardInfo.class, "shardInfo", equalTo(shardInfo)));
assertThat(task, initTask(ShardRecordProcessor.class, "shardRecordProcessor", equalTo(shardRecordProcessor))); assertThat(task, initTask(ShardRecordProcessor.class, "shardRecordProcessor", equalTo(shardRecordProcessor)));
@ -242,7 +243,7 @@ public class ConsumerStatesTest {
public void processingStateTestSynchronous() { public void processingStateTestSynchronous() {
when(leaseCoordinator.leaseStatsRecorder()).thenReturn(leaseStatsRecorder); when(leaseCoordinator.leaseStatsRecorder()).thenReturn(leaseStatsRecorder);
ConsumerState state = ShardConsumerState.PROCESSING.consumerState(); ConsumerState state = ShardConsumerState.PROCESSING.consumerState();
ConsumerTask task = state.createTask(argument, consumer, null); ConsumerTask task = state.createTask(argument, consumer, null, taskFactory);
assertThat(task, procTask(ShardInfo.class, "shardInfo", equalTo(shardInfo))); assertThat(task, procTask(ShardInfo.class, "shardInfo", equalTo(shardInfo)));
assertThat(task, procTask(ShardRecordProcessor.class, "shardRecordProcessor", equalTo(shardRecordProcessor))); assertThat(task, procTask(ShardRecordProcessor.class, "shardRecordProcessor", equalTo(shardRecordProcessor)));
@ -274,7 +275,7 @@ public class ConsumerStatesTest {
public void processingStateTestAsynchronous() { public void processingStateTestAsynchronous() {
when(leaseCoordinator.leaseStatsRecorder()).thenReturn(leaseStatsRecorder); when(leaseCoordinator.leaseStatsRecorder()).thenReturn(leaseStatsRecorder);
ConsumerState state = ShardConsumerState.PROCESSING.consumerState(); ConsumerState state = ShardConsumerState.PROCESSING.consumerState();
ConsumerTask task = state.createTask(argument, consumer, null); ConsumerTask task = state.createTask(argument, consumer, null, taskFactory);
assertThat(task, procTask(ShardInfo.class, "shardInfo", equalTo(shardInfo))); assertThat(task, procTask(ShardInfo.class, "shardInfo", equalTo(shardInfo)));
assertThat(task, procTask(ShardRecordProcessor.class, "shardRecordProcessor", equalTo(shardRecordProcessor))); assertThat(task, procTask(ShardRecordProcessor.class, "shardRecordProcessor", equalTo(shardRecordProcessor)));
@ -306,7 +307,7 @@ public class ConsumerStatesTest {
public void processingStateRecordsFetcher() { public void processingStateRecordsFetcher() {
when(leaseCoordinator.leaseStatsRecorder()).thenReturn(leaseStatsRecorder); when(leaseCoordinator.leaseStatsRecorder()).thenReturn(leaseStatsRecorder);
ConsumerState state = ShardConsumerState.PROCESSING.consumerState(); ConsumerState state = ShardConsumerState.PROCESSING.consumerState();
ConsumerTask task = state.createTask(argument, consumer, null); ConsumerTask task = state.createTask(argument, consumer, null, taskFactory);
assertThat(task, procTask(ShardInfo.class, "shardInfo", equalTo(shardInfo))); assertThat(task, procTask(ShardInfo.class, "shardInfo", equalTo(shardInfo)));
assertThat(task, procTask(ShardRecordProcessor.class, "shardRecordProcessor", equalTo(shardRecordProcessor))); assertThat(task, procTask(ShardRecordProcessor.class, "shardRecordProcessor", equalTo(shardRecordProcessor)));
@ -339,7 +340,7 @@ public class ConsumerStatesTest {
ConsumerState state = ShardConsumerState.SHUTDOWN_REQUESTED.consumerState(); ConsumerState state = ShardConsumerState.SHUTDOWN_REQUESTED.consumerState();
consumer.gracefulShutdown(shutdownNotification); consumer.gracefulShutdown(shutdownNotification);
ConsumerTask task = state.createTask(argument, consumer, null); ConsumerTask task = state.createTask(argument, consumer, null, taskFactory);
assertThat( assertThat(
task, task,
@ -373,7 +374,7 @@ public class ConsumerStatesTest {
public void shutdownRequestCompleteStateTest() { public void shutdownRequestCompleteStateTest() {
ConsumerState state = ConsumerStates.SHUTDOWN_REQUEST_COMPLETION_STATE; ConsumerState state = ConsumerStates.SHUTDOWN_REQUEST_COMPLETION_STATE;
assertThat(state.createTask(argument, consumer, null), nullValue()); assertThat(state.createTask(argument, consumer, null, taskFactory), nullValue());
assertThat(state.successTransition(), equalTo(state)); assertThat(state.successTransition(), equalTo(state));
@ -409,7 +410,7 @@ public class ConsumerStatesTest {
childShards.add(leftChild); childShards.add(leftChild);
childShards.add(rightChild); childShards.add(rightChild);
when(processRecordsInput.childShards()).thenReturn(childShards); when(processRecordsInput.childShards()).thenReturn(childShards);
ConsumerTask task = state.createTask(argument, consumer, processRecordsInput); ConsumerTask task = state.createTask(argument, consumer, processRecordsInput, taskFactory);
assertThat(task, shutdownTask(ShardInfo.class, "shardInfo", equalTo(shardInfo))); assertThat(task, shutdownTask(ShardInfo.class, "shardInfo", equalTo(shardInfo)));
assertThat( assertThat(
@ -443,7 +444,7 @@ public class ConsumerStatesTest {
ConsumerState state = ShardConsumerState.SHUTDOWN_COMPLETE.consumerState(); ConsumerState state = ShardConsumerState.SHUTDOWN_COMPLETE.consumerState();
assertThat(state.createTask(argument, consumer, null), nullValue()); assertThat(state.createTask(argument, consumer, null, taskFactory), nullValue());
assertThat(state.successTransition(), equalTo(state)); assertThat(state.successTransition(), equalTo(state));
for (ShutdownReason reason : ShutdownReason.values()) { for (ShutdownReason reason : ShutdownReason.values()) {

View file

@ -332,7 +332,7 @@ public class ShardConsumerTest {
verify(cache.subscription, times(3)).request(anyLong()); verify(cache.subscription, times(3)).request(anyLong());
verify(cache.subscription).cancel(); verify(cache.subscription).cancel();
verify(processingState, times(2)).createTask(eq(shardConsumerArgument), eq(consumer), any()); verify(processingState, times(2)).createTask(eq(shardConsumerArgument), eq(consumer), any(), any());
verify(taskExecutionListener, times(1)).beforeTaskExecution(initialTaskInput); verify(taskExecutionListener, times(1)).beforeTaskExecution(initialTaskInput);
verify(taskExecutionListener, times(2)).beforeTaskExecution(processTaskInput); verify(taskExecutionListener, times(2)).beforeTaskExecution(processTaskInput);
verify(taskExecutionListener, times(1)).beforeTaskExecution(shutdownTaskInput); verify(taskExecutionListener, times(1)).beforeTaskExecution(shutdownTaskInput);
@ -394,7 +394,7 @@ public class ShardConsumerTest {
verify(cache.subscription, times(1)).request(anyLong()); verify(cache.subscription, times(1)).request(anyLong());
verify(cache.subscription).cancel(); verify(cache.subscription).cancel();
verify(processingState, times(1)).createTask(eq(shardConsumerArgument), eq(consumer), any()); verify(processingState, times(1)).createTask(eq(shardConsumerArgument), eq(consumer), any(), any());
verify(taskExecutionListener, times(1)).beforeTaskExecution(initialTaskInput); verify(taskExecutionListener, times(1)).beforeTaskExecution(initialTaskInput);
verify(taskExecutionListener, times(1)).beforeTaskExecution(processTaskInput); verify(taskExecutionListener, times(1)).beforeTaskExecution(processTaskInput);
verify(taskExecutionListener, times(1)).beforeTaskExecution(shutdownTaskInput); verify(taskExecutionListener, times(1)).beforeTaskExecution(shutdownTaskInput);
@ -437,14 +437,14 @@ public class ShardConsumerTest {
cache.publish(); cache.publish();
awaitAndResetBarrier(taskCallBarrier); awaitAndResetBarrier(taskCallBarrier);
verify(processingState).createTask(any(), any(), any()); verify(processingState).createTask(any(), any(), any(), any());
verify(processingTask).call(); verify(processingTask).call();
cache.awaitRequest(); cache.awaitRequest();
cache.publish(); cache.publish();
awaitAndResetBarrier(taskCallBarrier); awaitAndResetBarrier(taskCallBarrier);
verify(processingState, times(2)).createTask(any(), any(), any()); verify(processingState, times(2)).createTask(any(), any(), any(), any());
verify(processingTask, times(2)).call(); verify(processingTask, times(2)).call();
cache.awaitRequest(); cache.awaitRequest();
@ -460,7 +460,7 @@ public class ShardConsumerTest {
shutdownComplete = consumer.shutdownComplete().get(); shutdownComplete = consumer.shutdownComplete().get();
} while (!shutdownComplete); } while (!shutdownComplete);
verify(processingState, times(3)).createTask(any(), any(), any()); verify(processingState, times(3)).createTask(any(), any(), any(), any());
verify(processingTask, times(3)).call(); verify(processingTask, times(3)).call();
verify(processingState).shutdownTransition(eq(ShutdownReason.LEASE_LOST)); verify(processingState).shutdownTransition(eq(ShutdownReason.LEASE_LOST));
verify(shutdownState).shutdownTransition(eq(ShutdownReason.LEASE_LOST)); verify(shutdownState).shutdownTransition(eq(ShutdownReason.LEASE_LOST));
@ -487,7 +487,7 @@ public class ShardConsumerTest {
public final void testInitializationStateUponFailure() throws Exception { public final void testInitializationStateUponFailure() throws Exception {
final ShardConsumer consumer = createShardConsumer(recordsPublisher); final ShardConsumer consumer = createShardConsumer(recordsPublisher);
when(initialState.createTask(eq(shardConsumerArgument), eq(consumer), any())) when(initialState.createTask(eq(shardConsumerArgument), eq(consumer), any(), any()))
.thenReturn(initializeTask); .thenReturn(initializeTask);
when(initializeTask.call()).thenReturn(new TaskResult(new Exception("Bad"))); when(initializeTask.call()).thenReturn(new TaskResult(new Exception("Bad")));
when(initializeTask.taskType()).thenReturn(TaskType.INITIALIZE); when(initializeTask.taskType()).thenReturn(TaskType.INITIALIZE);
@ -505,7 +505,7 @@ public class ShardConsumerTest {
awaitAndResetBarrier(taskBarrier); awaitAndResetBarrier(taskBarrier);
} }
verify(initialState, times(5)).createTask(eq(shardConsumerArgument), eq(consumer), any()); verify(initialState, times(5)).createTask(eq(shardConsumerArgument), eq(consumer), any(), any());
verify(initialState, never()).successTransition(); verify(initialState, never()).successTransition();
verify(initialState, never()).shutdownTransition(any()); verify(initialState, never()).shutdownTransition(any());
} }
@ -665,7 +665,7 @@ public class ShardConsumerTest {
public void testErrorThrowableInInitialization() throws Exception { public void testErrorThrowableInInitialization() throws Exception {
final ShardConsumer consumer = createShardConsumer(recordsPublisher); final ShardConsumer consumer = createShardConsumer(recordsPublisher);
when(initialState.createTask(any(), any(), any())).thenReturn(initializeTask); when(initialState.createTask(any(), any(), any(), any())).thenReturn(initializeTask);
when(initialState.taskType()).thenReturn(TaskType.INITIALIZE); when(initialState.taskType()).thenReturn(TaskType.INITIALIZE);
when(initializeTask.call()).thenAnswer(i -> { when(initializeTask.call()).thenAnswer(i -> {
throw new Error("Error"); throw new Error("Error");
@ -692,13 +692,13 @@ public class ShardConsumerTest {
mockSuccessfulProcessing(taskBarrier); mockSuccessfulProcessing(taskBarrier);
when(processingState.shutdownTransition(eq(ShutdownReason.REQUESTED))).thenReturn(shutdownRequestedState); when(processingState.shutdownTransition(eq(ShutdownReason.REQUESTED))).thenReturn(shutdownRequestedState);
when(shutdownRequestedState.createTask(any(), any(), any())).thenReturn(shutdownRequestedTask); when(shutdownRequestedState.createTask(any(), any(), any(), any())).thenReturn(shutdownRequestedTask);
when(shutdownRequestedState.taskType()).thenReturn(TaskType.SHUTDOWN_NOTIFICATION); when(shutdownRequestedState.taskType()).thenReturn(TaskType.SHUTDOWN_NOTIFICATION);
when(shutdownRequestedTask.call()).thenReturn(new TaskResult(null)); when(shutdownRequestedTask.call()).thenReturn(new TaskResult(null));
when(shutdownRequestedState.shutdownTransition(eq(ShutdownReason.REQUESTED))) when(shutdownRequestedState.shutdownTransition(eq(ShutdownReason.REQUESTED)))
.thenReturn(shutdownRequestedAwaitState); .thenReturn(shutdownRequestedAwaitState);
when(shutdownRequestedAwaitState.createTask(any(), any(), any())).thenReturn(null); when(shutdownRequestedAwaitState.createTask(any(), any(), any(), any())).thenReturn(null);
when(shutdownRequestedAwaitState.shutdownTransition(eq(ShutdownReason.LEASE_LOST))) when(shutdownRequestedAwaitState.shutdownTransition(eq(ShutdownReason.LEASE_LOST)))
.thenReturn(shutdownState); .thenReturn(shutdownState);
when(shutdownRequestedAwaitState.taskType()).thenReturn(TaskType.SHUTDOWN_COMPLETE); when(shutdownRequestedAwaitState.taskType()).thenReturn(TaskType.SHUTDOWN_COMPLETE);
@ -733,11 +733,11 @@ public class ShardConsumerTest {
shutdownComplete = consumer.shutdownComplete().get(); shutdownComplete = consumer.shutdownComplete().get();
assertTrue(shutdownComplete); assertTrue(shutdownComplete);
verify(processingState, times(2)).createTask(any(), any(), any()); verify(processingState, times(2)).createTask(any(), any(), any(), any());
verify(shutdownRequestedState, never()).shutdownTransition(eq(ShutdownReason.LEASE_LOST)); verify(shutdownRequestedState, never()).shutdownTransition(eq(ShutdownReason.LEASE_LOST));
verify(shutdownRequestedState).createTask(any(), any(), any()); verify(shutdownRequestedState).createTask(any(), any(), any(), any());
verify(shutdownRequestedState).shutdownTransition(eq(ShutdownReason.REQUESTED)); verify(shutdownRequestedState).shutdownTransition(eq(ShutdownReason.REQUESTED));
verify(shutdownRequestedAwaitState).createTask(any(), any(), any()); verify(shutdownRequestedAwaitState).createTask(any(), any(), any(), any());
verify(shutdownRequestedAwaitState).shutdownTransition(eq(ShutdownReason.LEASE_LOST)); verify(shutdownRequestedAwaitState).shutdownTransition(eq(ShutdownReason.LEASE_LOST));
verify(taskExecutionListener, times(1)).beforeTaskExecution(initialTaskInput); verify(taskExecutionListener, times(1)).beforeTaskExecution(initialTaskInput);
verify(taskExecutionListener, times(2)).beforeTaskExecution(processTaskInput); verify(taskExecutionListener, times(2)).beforeTaskExecution(processTaskInput);
@ -948,7 +948,7 @@ public class ShardConsumerTest {
when(mockState.state()).thenReturn(ShardConsumerState.WAITING_ON_PARENT_SHARDS); when(mockState.state()).thenReturn(ShardConsumerState.WAITING_ON_PARENT_SHARDS);
when(mockState.taskType()).thenReturn(TaskType.BLOCK_ON_PARENT_SHARDS); when(mockState.taskType()).thenReturn(TaskType.BLOCK_ON_PARENT_SHARDS);
final ConsumerTask mockTask = mock(ConsumerTask.class); final ConsumerTask mockTask = mock(ConsumerTask.class);
when(mockState.createTask(any(), any(), any())).thenReturn(mockTask); when(mockState.createTask(any(), any(), any(), any())).thenReturn(mockTask);
// Simulate successful BlockedOnParent task execution // Simulate successful BlockedOnParent task execution
// and successful Initialize task execution // and successful Initialize task execution
when(mockTask.call()).thenReturn(new TaskResult(false)); when(mockTask.call()).thenReturn(new TaskResult(false));
@ -993,7 +993,7 @@ public class ShardConsumerTest {
reset(mockState); reset(mockState);
when(mockState.taskType()).thenReturn(TaskType.PROCESS); when(mockState.taskType()).thenReturn(TaskType.PROCESS);
final ConsumerTask mockProcessTask = mock(ConsumerTask.class); final ConsumerTask mockProcessTask = mock(ConsumerTask.class);
when(mockState.createTask(any(), any(), any())).thenReturn(mockProcessTask); when(mockState.createTask(any(), any(), any(), any())).thenReturn(mockProcessTask);
when(mockProcessTask.call()).then(input -> { when(mockProcessTask.call()).then(input -> {
// first we want to wait for subscribe to be called, // first we want to wait for subscribe to be called,
// but we cannot control the timing, so wait for 10 seconds // but we cannot control the timing, so wait for 10 seconds
@ -1045,7 +1045,8 @@ public class ShardConsumerTest {
} }
private void mockSuccessfulShutdown(CyclicBarrier taskArriveBarrier, CyclicBarrier taskDepartBarrier) { private void mockSuccessfulShutdown(CyclicBarrier taskArriveBarrier, CyclicBarrier taskDepartBarrier) {
when(shutdownState.createTask(eq(shardConsumerArgument), any(), any())).thenReturn(shutdownTask); when(shutdownState.createTask(eq(shardConsumerArgument), any(), any(), any()))
.thenReturn(shutdownTask);
when(shutdownState.taskType()).thenReturn(TaskType.SHUTDOWN); when(shutdownState.taskType()).thenReturn(TaskType.SHUTDOWN);
when(shutdownTask.call()).thenAnswer(i -> { when(shutdownTask.call()).thenAnswer(i -> {
awaitBarrier(taskArriveBarrier); awaitBarrier(taskArriveBarrier);
@ -1063,7 +1064,7 @@ public class ShardConsumerTest {
} }
private void mockSuccessfulProcessing(CyclicBarrier taskCallBarrier, CyclicBarrier taskInterlockBarrier) { private void mockSuccessfulProcessing(CyclicBarrier taskCallBarrier, CyclicBarrier taskInterlockBarrier) {
when(processingState.createTask(eq(shardConsumerArgument), any(), any())) when(processingState.createTask(eq(shardConsumerArgument), any(), any(), any()))
.thenReturn(processingTask); .thenReturn(processingTask);
when(processingState.taskType()).thenReturn(TaskType.PROCESS); when(processingState.taskType()).thenReturn(TaskType.PROCESS);
when(processingTask.taskType()).thenReturn(TaskType.PROCESS); when(processingTask.taskType()).thenReturn(TaskType.PROCESS);
@ -1088,7 +1089,8 @@ public class ShardConsumerTest {
} }
private void mockSuccessfulInitialize(CyclicBarrier taskCallBarrier, CyclicBarrier taskInterlockBarrier) { private void mockSuccessfulInitialize(CyclicBarrier taskCallBarrier, CyclicBarrier taskInterlockBarrier) {
when(initialState.createTask(eq(shardConsumerArgument), any(), any())).thenReturn(initializeTask); when(initialState.createTask(eq(shardConsumerArgument), any(), any(), any()))
.thenReturn(initializeTask);
when(initialState.taskType()).thenReturn(TaskType.INITIALIZE); when(initialState.taskType()).thenReturn(TaskType.INITIALIZE);
when(initializeTask.taskType()).thenReturn(TaskType.INITIALIZE); when(initializeTask.taskType()).thenReturn(TaskType.INITIALIZE);
when(initializeTask.call()).thenAnswer(i -> { when(initializeTask.call()).thenAnswer(i -> {
@ -1107,7 +1109,7 @@ public class ShardConsumerTest {
} }
private void mockSuccessfulUnblockOnParents() { private void mockSuccessfulUnblockOnParents() {
when(blockedOnParentsState.createTask(eq(shardConsumerArgument), any(), any())) when(blockedOnParentsState.createTask(eq(shardConsumerArgument), any(), any(), any()))
.thenReturn(blockedOnParentsTask); .thenReturn(blockedOnParentsTask);
when(blockedOnParentsState.taskType()).thenReturn(TaskType.BLOCK_ON_PARENT_SHARDS); when(blockedOnParentsState.taskType()).thenReturn(TaskType.BLOCK_ON_PARENT_SHARDS);
when(blockedOnParentsTask.call()).thenAnswer(i -> blockOnParentsTaskResult); when(blockedOnParentsTask.call()).thenAnswer(i -> blockOnParentsTaskResult);