Merge pull request #10 from ychunxue/ltr_base_ShardEnd

ShardEnd shard sync with Child Shards
This commit is contained in:
ychunxue 2020-04-09 10:50:44 -07:00 committed by GitHub
commit b1a3d215d0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 467 additions and 166 deletions

View file

@ -39,6 +39,7 @@ import org.apache.commons.lang3.StringUtils;
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import software.amazon.awssdk.services.kinesis.model.ChildShard;
import software.amazon.awssdk.services.kinesis.model.Shard;
import software.amazon.awssdk.services.kinesis.model.ShardFilter;
import software.amazon.awssdk.services.kinesis.model.ShardFilterType;
@ -763,6 +764,41 @@ public class HierarchicalShardSyncer {
}
}
/**
 * Builds a new KCL lease for a child shard discovered at shard end.
 * Dispatches to the multi-stream or single-stream lease builder depending on
 * the syncer's configured mode.
 *
 * @param childShard       child shard reported by the service at shard end
 * @param streamIdentifier identifier of the stream the child shard belongs to
 * @return a freshly populated (not yet persisted) lease for the child shard
 * @throws InvalidStateException if the child shard has no parent shards
 */
public synchronized Lease createLeaseForChildShard(final ChildShard childShard, final StreamIdentifier streamIdentifier) throws InvalidStateException {
    final MultiStreamArgs multiStreamArgs = new MultiStreamArgs(isMultiStreamMode, streamIdentifier);
    if (multiStreamArgs.isMultiStreamMode()) {
        return newKCLMultiStreamLeaseForChildShard(childShard, streamIdentifier);
    }
    return newKCLLeaseForChildShard(childShard);
}
/**
 * Creates a single-stream KCL lease for a child shard reached at shard end.
 * The lease is checkpointed at TRIM_HORIZON so the child shard is consumed
 * from its beginning.
 *
 * @param childShard child shard reported by the service at shard end
 * @return a new lease keyed by the child shard's id
 * @throws InvalidStateException if the child shard carries no parent shards
 */
private static Lease newKCLLeaseForChildShard(final ChildShard childShard) throws InvalidStateException {
    // Validate up front: a child shard must reference its parents so lease
    // ordering (parents complete before children start) can be enforced.
    if (CollectionUtils.isNullOrEmpty(childShard.parentShards())) {
        // Fixed: original message was missing the space before "because".
        throw new InvalidStateException("Unable to populate new lease for child shard " + childShard.shardId()
                + " because parent shards cannot be found.");
    }
    final Lease newLease = new Lease();
    newLease.leaseKey(childShard.shardId());
    newLease.parentShardIds(childShard.parentShards());
    newLease.checkpoint(ExtendedSequenceNumber.TRIM_HORIZON);
    newLease.ownerSwitchesSinceCheckpoint(0L);
    return newLease;
}
/**
 * Creates a multi-stream KCL lease for a child shard reached at shard end.
 * The lease key embeds the serialized stream identifier, and the lease is
 * checkpointed at TRIM_HORIZON so the child shard is consumed from its
 * beginning.
 *
 * @param childShard       child shard reported by the service at shard end
 * @param streamIdentifier identifier of the stream the child shard belongs to
 * @return a new multi-stream lease for the child shard
 * @throws InvalidStateException if the child shard carries no parent shards
 */
private static Lease newKCLMultiStreamLeaseForChildShard(final ChildShard childShard, final StreamIdentifier streamIdentifier) throws InvalidStateException {
    // Validate up front: a child shard must reference its parents so lease
    // ordering (parents complete before children start) can be enforced.
    if (CollectionUtils.isNullOrEmpty(childShard.parentShards())) {
        // Fixed: original message was missing the space before "because".
        throw new InvalidStateException("Unable to populate new lease for child shard " + childShard.shardId()
                + " because parent shards cannot be found.");
    }
    final MultiStreamLease newLease = new MultiStreamLease();
    newLease.leaseKey(MultiStreamLease.getLeaseKey(streamIdentifier.serialize(), childShard.shardId()));
    newLease.parentShardIds(childShard.parentShards());
    newLease.checkpoint(ExtendedSequenceNumber.TRIM_HORIZON);
    newLease.ownerSwitchesSinceCheckpoint(0L);
    newLease.streamIdentifier(streamIdentifier.serialize());
    newLease.shardId(childShard.shardId());
    return newLease;
}
/**
* Helper method to create a new Lease POJO for a shard.
* Note: Package level access only for testing purposes

View file

@ -0,0 +1,27 @@
/*
* Copyright 2019 Amazon.com, Inc. or its affiliates.
* Licensed under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package software.amazon.kinesis.leases.exceptions;
/**
 * Exception type for all exceptions thrown by the customer implemented code.
 */
public class CustomerApplicationException extends Exception {

    /**
     * Wraps a failure raised by customer code.
     *
     * @param e the underlying cause
     */
    public CustomerApplicationException(Throwable e) {
        super(e);
    }

    /**
     * Wraps a failure raised by customer code with additional context.
     *
     * @param message context describing where the failure occurred
     * @param e       the underlying cause
     */
    public CustomerApplicationException(String message, Throwable e) {
        super(message, e);
    }

    /**
     * Reports a customer-code failure with a message and no cause.
     *
     * @param message context describing the failure
     */
    public CustomerApplicationException(String message) {
        super(message);
    }
}

View file

@ -496,7 +496,8 @@ class ConsumerStates {
argument.taskBackoffTimeMillis(),
argument.recordsPublisher(),
argument.hierarchicalShardSyncer(),
argument.metricsFactory());
argument.metricsFactory(),
input == null ? null : input.childShards());
}
@Override

View file

@ -16,6 +16,7 @@ package software.amazon.kinesis.lifecycle;
import java.time.Duration;
import java.time.Instant;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
@ -32,6 +33,7 @@ import lombok.Getter;
import lombok.NonNull;
import lombok.experimental.Accessors;
import lombok.extern.slf4j.Slf4j;
import software.amazon.awssdk.services.kinesis.model.ChildShard;
import software.amazon.kinesis.annotations.KinesisClientInternalApi;
import software.amazon.kinesis.exceptions.internal.BlockedOnParentShardException;
import software.amazon.kinesis.leases.ShardInfo;
@ -86,6 +88,8 @@ public class ShardConsumer {
private final ShardConsumerSubscriber subscriber;
private ProcessRecordsInput shardEndProcessRecordsInput;
@Deprecated
public ShardConsumer(RecordsPublisher recordsPublisher, ExecutorService executorService, ShardInfo shardInfo,
Optional<Long> logWarningForTaskAfterMillis, ShardConsumerArgument shardConsumerArgument,
@ -148,6 +152,7 @@ public class ShardConsumer {
processData(input);
if (taskOutcome == TaskOutcome.END_OF_SHARD) {
markForShutdown(ShutdownReason.SHARD_END);
shardEndProcessRecordsInput = input;
subscription.cancel();
return;
}
@ -305,7 +310,7 @@ public class ShardConsumer {
return true;
}
executeTask(null);
executeTask(shardEndProcessRecordsInput);
return false;
}
}, executorService);

View file

@ -20,6 +20,7 @@ import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import software.amazon.awssdk.services.kinesis.model.ChildShard;
import software.amazon.awssdk.services.kinesis.model.Shard;
import software.amazon.awssdk.utils.CollectionUtils;
import software.amazon.kinesis.annotations.KinesisClientInternalApi;
@ -30,6 +31,10 @@ import software.amazon.kinesis.leases.LeaseCoordinator;
import software.amazon.kinesis.leases.ShardDetector;
import software.amazon.kinesis.leases.ShardInfo;
import software.amazon.kinesis.leases.HierarchicalShardSyncer;
import software.amazon.kinesis.leases.exceptions.CustomerApplicationException;
import software.amazon.kinesis.leases.exceptions.DependencyException;
import software.amazon.kinesis.leases.exceptions.InvalidStateException;
import software.amazon.kinesis.leases.exceptions.ProvisionedThroughputException;
import software.amazon.kinesis.lifecycle.events.LeaseLostInput;
import software.amazon.kinesis.lifecycle.events.ShardEndedInput;
import software.amazon.kinesis.metrics.MetricsFactory;
@ -66,8 +71,6 @@ public class ShutdownTask implements ConsumerTask {
@NonNull
private final InitialPositionInStreamExtended initialPositionInStream;
private final boolean cleanupLeasesOfCompletedShards;
private final boolean garbageCollectLeases = false;
private final boolean isLeaseTableEmpty = false;
private final boolean ignoreUnexpectedChildShards;
@NonNull
private final LeaseCoordinator leaseCoordinator;
@ -81,6 +84,8 @@ public class ShutdownTask implements ConsumerTask {
private final TaskType taskType = TaskType.SHUTDOWN;
private final List<ChildShard> childShards;
private static final Function<ShardInfo, String> shardInfoIdProvider = shardInfo -> shardInfo
.streamIdentifierSerOpt().map(s -> s + ":" + shardInfo.shardId()).orElse(shardInfo.shardId());
/*
@ -95,85 +100,48 @@ public class ShutdownTask implements ConsumerTask {
final MetricsScope scope = MetricsUtil.createMetricsWithOperation(metricsFactory, SHUTDOWN_TASK_OPERATION);
Exception exception;
boolean applicationException = false;
try {
try {
ShutdownReason localReason = reason;
List<Shard> latestShards = null;
/*
* Revalidate if the current shard is closed before shutting down the shard consumer with reason SHARD_END
* If current shard is not closed, shut down the shard consumer with reason LEASE_LOST that allows
* active workers to contend for the lease of this shard.
*/
if (localReason == ShutdownReason.SHARD_END) {
latestShards = shardDetector.listShards();
log.debug("Invoking shutdown() for shard {}, concurrencyToken {}. Shutdown reason: {}",
shardInfoIdProvider.apply(shardInfo), shardInfo.concurrencyToken(), reason);
//If latestShards is empty, should also shutdown the ShardConsumer without checkpoint with SHARD_END
if (CollectionUtils.isNullOrEmpty(latestShards) || !isShardInContextParentOfAny(latestShards)) {
localReason = ShutdownReason.LEASE_LOST;
dropLease();
log.info("Forcing the lease to be lost before shutting down the consumer for Shard: " + shardInfoIdProvider.apply(shardInfo));
final long startTime = System.currentTimeMillis();
if (reason == ShutdownReason.SHARD_END) {
// Create new lease for the child shards if they don't exist.
if (!CollectionUtils.isNullOrEmpty(childShards)) {
createLeasesForChildShardsIfNotExist();
} else {
log.warn("Shard {} no longer exists. Shutting down consumer with SHARD_END reason without creating leases for child shards.", shardInfoIdProvider.apply(shardInfo));
}
}
// If we reached end of the shard, set sequence number to SHARD_END.
if (localReason == ShutdownReason.SHARD_END) {
recordProcessorCheckpointer
.sequenceNumberAtShardEnd(recordProcessorCheckpointer.largestPermittedCheckpointValue());
recordProcessorCheckpointer.largestPermittedCheckpointValue(ExtendedSequenceNumber.SHARD_END);
// Call the shardRecordsProcessor to checkpoint with SHARD_END sequence number.
// The shardEnded is implemented by customer. We should validate if the SHARD_END checkpointing is successful after calling shardEnded.
throwOnApplicationException(() -> applicationCheckpointAndVerification(), scope, startTime);
} else {
throwOnApplicationException(() -> shardRecordProcessor.leaseLost(LeaseLostInput.builder().build()), scope, startTime);
}
log.debug("Invoking shutdown() for shard {}, concurrencyToken {}. Shutdown reason: {}",
shardInfoIdProvider.apply(shardInfo), shardInfo.concurrencyToken(), localReason);
final ShutdownInput shutdownInput = ShutdownInput.builder().shutdownReason(localReason)
.checkpointer(recordProcessorCheckpointer).build();
final long startTime = System.currentTimeMillis();
try {
if (localReason == ShutdownReason.SHARD_END) {
shardRecordProcessor.shardEnded(ShardEndedInput.builder().checkpointer(recordProcessorCheckpointer).build());
ExtendedSequenceNumber lastCheckpointValue = recordProcessorCheckpointer.lastCheckpointValue();
if (lastCheckpointValue == null
|| !lastCheckpointValue.equals(ExtendedSequenceNumber.SHARD_END)) {
throw new IllegalArgumentException("Application didn't checkpoint at end of shard "
+ shardInfoIdProvider.apply(shardInfo) + ". Application must checkpoint upon shard end. " +
"See ShardRecordProcessor.shardEnded javadocs for more information.");
}
} else {
shardRecordProcessor.leaseLost(LeaseLostInput.builder().build());
}
log.debug("Shutting down retrieval strategy.");
recordsPublisher.shutdown();
log.debug("Record processor completed shutdown() for shard {}", shardInfoIdProvider.apply(shardInfo));
} catch (Exception e) {
applicationException = true;
throw e;
} finally {
MetricsUtil.addLatency(scope, RECORD_PROCESSOR_SHUTDOWN_METRIC, startTime, MetricsLevel.SUMMARY);
}
if (localReason == ShutdownReason.SHARD_END) {
log.debug("Looking for child shards of shard {}", shardInfoIdProvider.apply(shardInfo));
// create leases for the child shards
hierarchicalShardSyncer.checkAndCreateLeaseForNewShards(shardDetector, leaseCoordinator.leaseRefresher(),
initialPositionInStream, latestShards, cleanupLeasesOfCompletedShards, ignoreUnexpectedChildShards, scope, garbageCollectLeases,
isLeaseTableEmpty);
log.debug("Finished checking for child shards of shard {}", shardInfoIdProvider.apply(shardInfo));
}
log.debug("Shutting down retrieval strategy for shard {}.", shardInfoIdProvider.apply(shardInfo));
recordsPublisher.shutdown();
log.debug("Record processor completed shutdown() for shard {}", shardInfoIdProvider.apply(shardInfo));
return new TaskResult(null);
} catch (Exception e) {
if (applicationException) {
log.error("Application exception. ", e);
if (e instanceof CustomerApplicationException) {
log.error("Shard {}: Application exception. ", shardInfoIdProvider.apply(shardInfo), e);
} else {
log.error("Caught exception: ", e);
log.error("Shard {}: Caught exception: ", shardInfoIdProvider.apply(shardInfo), e);
}
exception = e;
// backoff if we encounter an exception.
try {
Thread.sleep(this.backoffTimeMillis);
} catch (InterruptedException ie) {
log.debug("Interrupted sleep", ie);
log.debug("Shard {}: Interrupted sleep", shardInfoIdProvider.apply(shardInfo), ie);
}
}
} finally {
@ -181,7 +149,37 @@ public class ShutdownTask implements ConsumerTask {
}
return new TaskResult(exception);
}
/**
 * Invokes the application's {@code shardEnded} callback and then verifies the
 * application honored its contract by checkpointing at SHARD_END.
 *
 * @throws IllegalArgumentException if the last recorded checkpoint is absent
 *         or is not SHARD_END
 */
private void applicationCheckpointAndVerification() {
    shardRecordProcessor.shardEnded(ShardEndedInput.builder().checkpointer(recordProcessorCheckpointer).build());
    // SHARD_END.equals(null) is false, so this single check covers both the
    // missing-checkpoint and wrong-checkpoint cases.
    final boolean checkpointedAtShardEnd =
            ExtendedSequenceNumber.SHARD_END.equals(recordProcessorCheckpointer.lastCheckpointValue());
    if (!checkpointedAtShardEnd) {
        throw new IllegalArgumentException("Application didn't checkpoint at end of shard "
                + shardInfoIdProvider.apply(shardInfo) + ". Application must checkpoint upon shard end. " +
                "See ShardRecordProcessor.shardEnded javadocs for more information.");
    }
}
/**
 * Runs a customer-supplied callback, converting any failure into a
 * {@link CustomerApplicationException}; shutdown latency is recorded whether
 * or not the callback succeeds.
 *
 * @param action       customer callback to execute
 * @param metricsScope scope receiving the latency metric
 * @param startTime    epoch millis when the shutdown work began
 * @throws CustomerApplicationException if the callback throws
 */
private void throwOnApplicationException(Runnable action, MetricsScope metricsScope, final long startTime) throws CustomerApplicationException {
    try {
        action.run();
    } catch (Exception cause) {
        final String context =
                "Customer application throws exception for shard " + shardInfoIdProvider.apply(shardInfo) + ": ";
        throw new CustomerApplicationException(context, cause);
    } finally {
        // Latency is emitted on both success and failure paths.
        MetricsUtil.addLatency(metricsScope, RECORD_PROCESSOR_SHUTDOWN_METRIC, startTime, MetricsLevel.SUMMARY);
    }
}
/**
 * Ensures each child shard has a lease: skips child shards this worker
 * already holds a lease for, and persists a new lease for the rest (the
 * refresher itself is also conditional, so concurrent creators are safe).
 *
 * @throws DependencyException            if the lease store call fails
 * @throws InvalidStateException          if a child shard lacks parents or the table is invalid
 * @throws ProvisionedThroughputException if the lease store is throttled
 */
private void createLeasesForChildShardsIfNotExist()
        throws DependencyException, InvalidStateException, ProvisionedThroughputException {
    for (final ChildShard childShard : childShards) {
        final boolean alreadyHeld = leaseCoordinator.getCurrentlyHeldLease(childShard.shardId()) != null;
        if (alreadyHeld) {
            continue;
        }
        final Lease leaseToCreate =
                hierarchicalShardSyncer.createLeaseForChildShard(childShard, shardDetector.streamIdentifier());
        leaseCoordinator.leaseRefresher().createLeaseIfNotExists(leaseToCreate);
    }
}
/*

View file

@ -23,6 +23,7 @@ import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.ToString;
import lombok.experimental.Accessors;
import software.amazon.awssdk.services.kinesis.model.ChildShard;
import software.amazon.kinesis.processor.ShardRecordProcessor;
import software.amazon.kinesis.processor.RecordProcessorCheckpointer;
import software.amazon.kinesis.retrieval.KinesisClientRecord;
@ -66,6 +67,11 @@ public class ProcessRecordsInput {
* This value does not include the {@link #timeSpentInCache()}.
*/
private Long millisBehindLatest;
/**
* A list of child shards if the current GetRecords request reached the shard end.
* If not at the shard end, this should be an empty list.
*/
private List<ChildShard> childShards;
/**
* How long the records spent waiting to be dispatched to the {@link ShardRecordProcessor}

View file

@ -33,11 +33,13 @@ import software.amazon.awssdk.services.kinesis.model.SubscribeToShardEventStream
import software.amazon.awssdk.services.kinesis.model.SubscribeToShardRequest;
import software.amazon.awssdk.services.kinesis.model.SubscribeToShardResponse;
import software.amazon.awssdk.services.kinesis.model.SubscribeToShardResponseHandler;
import software.amazon.awssdk.utils.CollectionUtils;
import software.amazon.awssdk.utils.Either;
import software.amazon.kinesis.annotations.KinesisClientInternalApi;
import software.amazon.kinesis.common.InitialPositionInStreamExtended;
import software.amazon.kinesis.common.KinesisRequestsBuilder;
import software.amazon.kinesis.common.RequestDetails;
import software.amazon.kinesis.leases.exceptions.InvalidStateException;
import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput;
import software.amazon.kinesis.retrieval.BatchUniqueIdentifier;
import software.amazon.kinesis.retrieval.IteratorBuilder;
@ -398,7 +400,7 @@ public class FanOutRecordsPublisher implements RecordsPublisher {
// The ack received for this onNext event will be ignored by the publisher as the global flow object should
// be either null or renewed when the ack's flow identifier is evaluated.
FanoutRecordsRetrieved response = new FanoutRecordsRetrieved(
ProcessRecordsInput.builder().records(Collections.emptyList()).isAtShardEnd(true).build(), null,
ProcessRecordsInput.builder().records(Collections.emptyList()).isAtShardEnd(true).childShards(Collections.emptyList()).build(), null,
triggeringFlow != null ? triggeringFlow.getSubscribeToShardId() : shardId + "-no-flow-found");
subscriber.onNext(response);
subscriber.onComplete();
@ -477,15 +479,28 @@ public class FanOutRecordsPublisher implements RecordsPublisher {
return;
}
List<KinesisClientRecord> records = recordBatchEvent.records().stream().map(KinesisClientRecord::fromRecord)
.collect(Collectors.toList());
ProcessRecordsInput input = ProcessRecordsInput.builder().cacheEntryTime(Instant.now())
.millisBehindLatest(recordBatchEvent.millisBehindLatest())
.isAtShardEnd(recordBatchEvent.continuationSequenceNumber() == null).records(records).build();
FanoutRecordsRetrieved recordsRetrieved = new FanoutRecordsRetrieved(input,
recordBatchEvent.continuationSequenceNumber(), triggeringFlow.subscribeToShardId);
try {
// If recordBatchEvent is not valid event, RuntimeException will be thrown here and trigger the errorOccurred call.
// Since the triggeringFlow is active flow, it will then trigger the handleFlowError call.
// Since the exception is not ResourceNotFoundException, it will trigger onError in the ShardConsumerSubscriber.
// The ShardConsumerSubscriber will finally cancel the subscription.
if (!isValidEvent(recordBatchEvent)) {
throw new InvalidStateException("RecordBatchEvent for flow " + triggeringFlow.toString() + " is invalid."
+ " event.continuationSequenceNumber: " + recordBatchEvent.continuationSequenceNumber()
+ ". event.childShards: " + recordBatchEvent.childShards());
}
List<KinesisClientRecord> records = recordBatchEvent.records().stream().map(KinesisClientRecord::fromRecord)
.collect(Collectors.toList());
ProcessRecordsInput input = ProcessRecordsInput.builder()
.cacheEntryTime(Instant.now())
.millisBehindLatest(recordBatchEvent.millisBehindLatest())
.isAtShardEnd(recordBatchEvent.continuationSequenceNumber() == null)
.records(records)
.childShards(recordBatchEvent.childShards())
.build();
FanoutRecordsRetrieved recordsRetrieved = new FanoutRecordsRetrieved(input,
recordBatchEvent.continuationSequenceNumber(), triggeringFlow.subscribeToShardId);
bufferCurrentEventAndScheduleIfRequired(recordsRetrieved, triggeringFlow);
} catch (Throwable t) {
log.warn("{}: Unable to buffer or schedule onNext for subscriber. Failing publisher." +
@ -495,6 +510,11 @@ public class FanOutRecordsPublisher implements RecordsPublisher {
}
}
/**
 * Checks structural consistency of a SubscribeToShard event: at shard end
 * (no continuation sequence number) the event must carry a non-empty
 * child-shard list; mid-shard it must carry an empty (but non-null) one.
 *
 * @param event event received from the subscription stream
 * @return true if the event's fields are mutually consistent
 */
private boolean isValidEvent(SubscribeToShardEvent event) {
    final boolean atShardEnd = event.continuationSequenceNumber() == null;
    if (atShardEnd) {
        return !CollectionUtils.isNullOrEmpty(event.childShards());
    }
    return event.childShards() != null && event.childShards().isEmpty();
}
private void updateAvailableQueueSpaceAndRequestUpstream(RecordFlow triggeringFlow) {
if (availableQueueSpace <= 0) {
log.debug(

View file

@ -67,6 +67,7 @@ public class BlockingRecordsPublisher implements RecordsPublisher {
return ProcessRecordsInput.builder()
.records(records)
.millisBehindLatest(getRecordsResult.millisBehindLatest())
.childShards(getRecordsResult.childShards())
.build();
}

View file

@ -36,6 +36,7 @@ import software.amazon.awssdk.services.kinesis.model.GetShardIteratorRequest;
import software.amazon.awssdk.services.kinesis.model.GetShardIteratorResponse;
import software.amazon.awssdk.services.kinesis.model.KinesisException;
import software.amazon.awssdk.services.kinesis.model.ResourceNotFoundException;
import software.amazon.awssdk.utils.CollectionUtils;
import software.amazon.kinesis.annotations.KinesisClientInternalApi;
import software.amazon.kinesis.common.FutureUtils;
import software.amazon.kinesis.common.InitialPositionInStreamExtended;
@ -133,8 +134,12 @@ public class KinesisDataFetcher {
final DataFetcherResult TERMINAL_RESULT = new DataFetcherResult() {
@Override
public GetRecordsResponse getResult() {
return GetRecordsResponse.builder().millisBehindLatest(null).records(Collections.emptyList())
.nextShardIterator(null).build();
return GetRecordsResponse.builder()
.millisBehindLatest(null)
.records(Collections.emptyList())
.nextShardIterator(null)
.childShards(Collections.emptyList())
.build();
}
@Override
@ -281,6 +286,11 @@ public class KinesisDataFetcher {
try {
final GetRecordsResponse response = FutureUtils.resolveOrCancelFuture(kinesisClient.getRecords(request),
maxFutureWait);
if (!isValidResponse(response)) {
throw new RetryableRetrievalException("GetRecords response is not valid for shard: " + streamAndShardId
+ ". nextShardIterator: " + response.nextShardIterator()
+ ". childShards: " + response.childShards() + ". Will retry GetRecords with the same nextIterator.");
}
success = true;
return response;
} catch (ExecutionException e) {
@ -298,6 +308,11 @@ public class KinesisDataFetcher {
}
}
/**
 * Checks structural consistency of a GetRecords response: at shard end
 * (null next iterator) the response must carry a non-empty child-shard list;
 * mid-shard it must carry an empty (but non-null) one.
 *
 * @param response GetRecords response from the service
 * @return true if the response's fields are mutually consistent
 */
private boolean isValidResponse(GetRecordsResponse response) {
    final boolean atShardEnd = response.nextShardIterator() == null;
    if (atShardEnd) {
        return !CollectionUtils.isNullOrEmpty(response.childShards());
    }
    return response.childShards() != null && response.childShards().isEmpty();
}
private AWSExceptionManager createExceptionManager() {
final AWSExceptionManager exceptionManager = new AWSExceptionManager();
exceptionManager.add(ResourceNotFoundException.class, t -> t);

View file

@ -162,7 +162,7 @@ public class PrefetchRecordsPublisher implements RecordsPublisher {
} else {
log.info(
"{}: No record batch found while evicting from the prefetch queue. This indicates the prefetch buffer"
+ "was reset.", streamAndShardId);
+ " was reset.", streamAndShardId);
}
return result;
}
@ -437,6 +437,7 @@ public class PrefetchRecordsPublisher implements RecordsPublisher {
.millisBehindLatest(getRecordsResult.millisBehindLatest())
.cacheEntryTime(lastSuccessfulCall)
.isAtShardEnd(getRecordsRetrievalStrategy.getDataFetcher().isShardEndReached())
.childShards(getRecordsResult.childShards())
.build();
PrefetchRecordsRetrieved recordsRetrieved = new PrefetchRecordsRetrieved(processRecordsInput,

View file

@ -25,6 +25,8 @@ import static org.mockito.Mockito.when;
import static software.amazon.kinesis.lifecycle.ConsumerStates.ShardConsumerState;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.ExecutorService;
@ -33,13 +35,13 @@ import org.hamcrest.Description;
import org.hamcrest.Matcher;
import org.hamcrest.TypeSafeDiagnosingMatcher;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.runners.MockitoJUnitRunner;
import software.amazon.awssdk.services.kinesis.KinesisAsyncClient;
import software.amazon.awssdk.services.kinesis.model.ChildShard;
import software.amazon.kinesis.checkpoint.ShardRecordProcessorCheckpointer;
import software.amazon.kinesis.common.InitialPositionInStream;
import software.amazon.kinesis.common.InitialPositionInStreamExtended;
@ -49,6 +51,7 @@ import software.amazon.kinesis.leases.LeaseRefresher;
import software.amazon.kinesis.leases.ShardDetector;
import software.amazon.kinesis.leases.ShardInfo;
import software.amazon.kinesis.leases.HierarchicalShardSyncer;
import software.amazon.kinesis.leases.ShardObjectHelper;
import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput;
import software.amazon.kinesis.metrics.MetricsFactory;
import software.amazon.kinesis.processor.Checkpointer;
@ -57,6 +60,7 @@ import software.amazon.kinesis.processor.ShardRecordProcessor;
import software.amazon.kinesis.retrieval.AggregatorUtil;
import software.amazon.kinesis.retrieval.RecordsPublisher;
@RunWith(MockitoJUnitRunner.class)
public class ConsumerStatesTest {
private static final String STREAM_NAME = "TestStream";
@ -300,13 +304,27 @@ public class ConsumerStatesTest {
}
// TODO: Fix this test
@Ignore
@Test
public void shuttingDownStateTest() {
consumer.markForShutdown(ShutdownReason.SHARD_END);
ConsumerState state = ShardConsumerState.SHUTTING_DOWN.consumerState();
ConsumerTask task = state.createTask(argument, consumer, null);
List<ChildShard> childShards = new ArrayList<>();
List<String> parentShards = new ArrayList<>();
parentShards.add("shardId-000000000000");
ChildShard leftChild = ChildShard.builder()
.shardId("shardId-000000000001")
.parentShards(parentShards)
.hashKeyRange(ShardObjectHelper.newHashKeyRange("0", "49"))
.build();
ChildShard rightChild = ChildShard.builder()
.shardId("shardId-000000000002")
.parentShards(parentShards)
.hashKeyRange(ShardObjectHelper.newHashKeyRange("50", "99"))
.build();
childShards.add(leftChild);
childShards.add(rightChild);
when(processRecordsInput.childShards()).thenReturn(childShards);
ConsumerTask task = state.createTask(argument, consumer, processRecordsInput);
assertThat(task, shutdownTask(ShardInfo.class, "shardInfo", equalTo(shardInfo)));
assertThat(task,
@ -315,8 +333,6 @@ public class ConsumerStatesTest {
equalTo(recordProcessorCheckpointer)));
assertThat(task, shutdownTask(ShutdownReason.class, "reason", equalTo(reason)));
assertThat(task, shutdownTask(LeaseCoordinator.class, "leaseCoordinator", equalTo(leaseCoordinator)));
assertThat(task, shutdownTask(InitialPositionInStreamExtended.class, "initialPositionInStream",
equalTo(initialPositionInStream)));
assertThat(task,
shutdownTask(Boolean.class, "cleanupLeasesOfCompletedShards", equalTo(cleanupLeasesOfCompletedShards)));
assertThat(task, shutdownTask(Long.class, "backoffTimeMillis", equalTo(taskBackoffTimeMillis)));

View file

@ -26,6 +26,7 @@ import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
@ -33,9 +34,11 @@ import java.util.List;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Matchers;
import org.mockito.Mock;
import org.mockito.runners.MockitoJUnitRunner;
import software.amazon.awssdk.services.kinesis.model.ChildShard;
import software.amazon.awssdk.services.kinesis.model.SequenceNumberRange;
import software.amazon.awssdk.services.kinesis.model.Shard;
import software.amazon.kinesis.checkpoint.ShardRecordProcessorCheckpointer;
@ -43,11 +46,16 @@ import software.amazon.kinesis.common.InitialPositionInStream;
import software.amazon.kinesis.common.InitialPositionInStreamExtended;
import software.amazon.kinesis.exceptions.internal.KinesisClientLibIOException;
import software.amazon.kinesis.leases.HierarchicalShardSyncer;
import software.amazon.kinesis.leases.Lease;
import software.amazon.kinesis.leases.LeaseCoordinator;
import software.amazon.kinesis.leases.LeaseRefresher;
import software.amazon.kinesis.leases.ShardDetector;
import software.amazon.kinesis.leases.ShardInfo;
import software.amazon.kinesis.leases.ShardObjectHelper;
import software.amazon.kinesis.leases.exceptions.CustomerApplicationException;
import software.amazon.kinesis.leases.exceptions.DependencyException;
import software.amazon.kinesis.leases.exceptions.InvalidStateException;
import software.amazon.kinesis.leases.exceptions.ProvisionedThroughputException;
import software.amazon.kinesis.lifecycle.events.LeaseLostInput;
import software.amazon.kinesis.lifecycle.events.ShardEndedInput;
import software.amazon.kinesis.metrics.MetricsFactory;
@ -104,7 +112,7 @@ public class ShutdownTaskTest {
task = new ShutdownTask(shardInfo, shardDetector, shardRecordProcessor, recordProcessorCheckpointer,
SHARD_END_SHUTDOWN_REASON, INITIAL_POSITION_TRIM_HORIZON, cleanupLeasesOfCompletedShards,
ignoreUnexpectedChildShards, leaseCoordinator, TASK_BACKOFF_TIME_MILLIS, recordsPublisher,
hierarchicalShardSyncer, NULL_METRICS_FACTORY);
hierarchicalShardSyncer, NULL_METRICS_FACTORY, constructChildShards());
}
/**
@ -113,12 +121,12 @@ public class ShutdownTaskTest {
*/
@Test
public final void testCallWhenApplicationDoesNotCheckpoint() {
when(shardDetector.listShards()).thenReturn(constructShardListGraphA());
when(recordProcessorCheckpointer.lastCheckpointValue()).thenReturn(new ExtendedSequenceNumber("3298"));
when(leaseCoordinator.leaseRefresher()).thenReturn(leaseRefresher);
final TaskResult result = task.call();
assertNotNull(result.getException());
assertTrue(result.getException() instanceof IllegalArgumentException);
assertTrue(result.getException() instanceof CustomerApplicationException);
}
/**
@ -126,28 +134,18 @@ public class ShutdownTaskTest {
* This test is for the scenario that checkAndCreateLeaseForNewShards throws an exception.
*/
@Test
public final void testCallWhenSyncingShardsThrows() throws Exception {
final boolean garbageCollectLeases = false;
final boolean isLeaseTableEmpty = false;
List<Shard> latestShards = constructShardListGraphA();
when(shardDetector.listShards()).thenReturn(latestShards);
public final void testCallWhenCreatingNewLeasesThrows() throws Exception {
when(recordProcessorCheckpointer.lastCheckpointValue()).thenReturn(ExtendedSequenceNumber.SHARD_END);
when(leaseCoordinator.leaseRefresher()).thenReturn(leaseRefresher);
doAnswer((invocation) -> {
throw new KinesisClientLibIOException("KinesisClientLibIOException");
}).when(hierarchicalShardSyncer)
.checkAndCreateLeaseForNewShards(shardDetector, leaseRefresher, INITIAL_POSITION_TRIM_HORIZON,
latestShards, cleanupLeasesOfCompletedShards, ignoreUnexpectedChildShards,
NULL_METRICS_FACTORY.createMetrics(), garbageCollectLeases, isLeaseTableEmpty);
when(leaseRefresher.createLeaseIfNotExists(Matchers.any(Lease.class))).thenThrow(new KinesisClientLibIOException("KinesisClientLibIOException"));
final TaskResult result = task.call();
assertNotNull(result.getException());
assertTrue(result.getException() instanceof KinesisClientLibIOException);
verify(recordsPublisher).shutdown();
verify(shardRecordProcessor).shardEnded(ShardEndedInput.builder().checkpointer(recordProcessorCheckpointer).build());
verify(recordsPublisher, never()).shutdown();
verify(shardRecordProcessor, never()).shardEnded(ShardEndedInput.builder().checkpointer(recordProcessorCheckpointer).build());
verify(shardRecordProcessor, never()).leaseLost(LeaseLostInput.builder().build());
verify(leaseCoordinator, never()).dropLease(Matchers.any(Lease.class));
}
/**
@ -155,24 +153,24 @@ public class ShutdownTaskTest {
* This test is for the scenario that ShutdownTask is created for ShardConsumer reaching the Shard End.
*/
@Test
public final void testCallWhenTrueShardEnd() {
public final void testCallWhenTrueShardEnd() throws DependencyException, InvalidStateException, ProvisionedThroughputException {
shardInfo = new ShardInfo("shardId-0", concurrencyToken, Collections.emptySet(),
ExtendedSequenceNumber.LATEST);
task = new ShutdownTask(shardInfo, shardDetector, shardRecordProcessor, recordProcessorCheckpointer,
SHARD_END_SHUTDOWN_REASON, INITIAL_POSITION_TRIM_HORIZON, cleanupLeasesOfCompletedShards,
ignoreUnexpectedChildShards, leaseCoordinator, TASK_BACKOFF_TIME_MILLIS, recordsPublisher,
hierarchicalShardSyncer, NULL_METRICS_FACTORY);
hierarchicalShardSyncer, NULL_METRICS_FACTORY, constructChildShards());
when(shardDetector.listShards()).thenReturn(constructShardListGraphA());
when(recordProcessorCheckpointer.lastCheckpointValue()).thenReturn(ExtendedSequenceNumber.SHARD_END);
when(leaseCoordinator.leaseRefresher()).thenReturn(leaseRefresher);
final TaskResult result = task.call();
assertNull(result.getException());
verify(recordsPublisher).shutdown();
verify(shardRecordProcessor).shardEnded(ShardEndedInput.builder().checkpointer(recordProcessorCheckpointer).build());
verify(shardRecordProcessor, never()).leaseLost(LeaseLostInput.builder().build());
verify(shardDetector, times(1)).listShards();
verify(leaseCoordinator, never()).getAssignments();
verify(leaseRefresher, times(2)).createLeaseIfNotExists(Matchers.any(Lease.class));
verify(leaseCoordinator, never()).dropLease(Matchers.any(Lease.class));
}
/**
@@ -180,23 +178,25 @@ public class ShutdownTaskTest {
* This test is for the scenario that a ShutdownTask is created for detecting a false Shard End.
*/
@Test
public final void testCallWhenFalseShardEnd() {
public final void testCallWhenShardNotFound() throws DependencyException, InvalidStateException, ProvisionedThroughputException {
shardInfo = new ShardInfo("shardId-4", concurrencyToken, Collections.emptySet(),
ExtendedSequenceNumber.LATEST);
task = new ShutdownTask(shardInfo, shardDetector, shardRecordProcessor, recordProcessorCheckpointer,
SHARD_END_SHUTDOWN_REASON, INITIAL_POSITION_TRIM_HORIZON, cleanupLeasesOfCompletedShards,
ignoreUnexpectedChildShards, leaseCoordinator, TASK_BACKOFF_TIME_MILLIS, recordsPublisher,
hierarchicalShardSyncer, NULL_METRICS_FACTORY);
hierarchicalShardSyncer, NULL_METRICS_FACTORY, new ArrayList<>());
when(shardDetector.listShards()).thenReturn(constructShardListGraphA());
when(recordProcessorCheckpointer.lastCheckpointValue()).thenReturn(ExtendedSequenceNumber.SHARD_END);
when(leaseCoordinator.leaseRefresher()).thenReturn(leaseRefresher);
final TaskResult result = task.call();
assertNull(result.getException());
verify(recordsPublisher).shutdown();
verify(shardRecordProcessor, never()).shardEnded(ShardEndedInput.builder().checkpointer(recordProcessorCheckpointer).build());
verify(shardRecordProcessor).leaseLost(LeaseLostInput.builder().build());
verify(shardDetector, times(1)).listShards();
verify(leaseCoordinator).getCurrentlyHeldLease(shardInfo.shardId());
verify(shardRecordProcessor).shardEnded(ShardEndedInput.builder().checkpointer(recordProcessorCheckpointer).build());
verify(shardRecordProcessor, never()).leaseLost(LeaseLostInput.builder().build());
verify(leaseCoordinator, never()).getCurrentlyHeldLease(shardInfo.shardId());
verify(leaseRefresher, never()).createLeaseIfNotExists(Matchers.any(Lease.class));
verify(leaseCoordinator, never()).dropLease(Matchers.any(Lease.class));
}
/**
@@ -204,23 +204,22 @@ public class ShutdownTaskTest {
* This test is for the scenario that a ShutdownTask is created for the ShardConsumer losing the lease.
*/
@Test
public final void testCallWhenLeaseLost() {
public final void testCallWhenLeaseLost() throws DependencyException, InvalidStateException, ProvisionedThroughputException {
shardInfo = new ShardInfo("shardId-4", concurrencyToken, Collections.emptySet(),
ExtendedSequenceNumber.LATEST);
task = new ShutdownTask(shardInfo, shardDetector, shardRecordProcessor, recordProcessorCheckpointer,
LEASE_LOST_SHUTDOWN_REASON, INITIAL_POSITION_TRIM_HORIZON, cleanupLeasesOfCompletedShards,
ignoreUnexpectedChildShards, leaseCoordinator, TASK_BACKOFF_TIME_MILLIS, recordsPublisher,
hierarchicalShardSyncer, NULL_METRICS_FACTORY);
when(shardDetector.listShards()).thenReturn(constructShardListGraphA());
hierarchicalShardSyncer, NULL_METRICS_FACTORY, new ArrayList<>());
final TaskResult result = task.call();
assertNull(result.getException());
verify(recordsPublisher).shutdown();
verify(shardRecordProcessor, never()).shardEnded(ShardEndedInput.builder().checkpointer(recordProcessorCheckpointer).build());
verify(shardRecordProcessor).leaseLost(LeaseLostInput.builder().build());
verify(shardDetector, never()).listShards();
verify(leaseCoordinator, never()).getAssignments();
verify(leaseRefresher, never()).createLeaseIfNotExists(Matchers.any(Lease.class));
verify(leaseCoordinator, never()).dropLease(Matchers.any(Lease.class));
}
/**
@@ -231,45 +230,23 @@ public class ShutdownTaskTest {
assertEquals(TaskType.SHUTDOWN, task.taskType());
}
/*
* Helper method to construct a shard list for graph A. Graph A is defined below. Shard structure (y-axis is
* epochs): 0 1 2 3 4 5 - shards till
* \ / \ / | |
* 6 7 4 5 - shards from epoch 103 - 205
* \ / | /\
* 8 4 9 10 - shards from epoch 206 (open - no ending sequenceNumber)
*/
private List<Shard> constructShardListGraphA() {
final SequenceNumberRange range0 = ShardObjectHelper.newSequenceNumberRange("11", "102");
final SequenceNumberRange range1 = ShardObjectHelper.newSequenceNumberRange("11", null);
final SequenceNumberRange range2 = ShardObjectHelper.newSequenceNumberRange("11", "205");
final SequenceNumberRange range3 = ShardObjectHelper.newSequenceNumberRange("103", "205");
final SequenceNumberRange range4 = ShardObjectHelper.newSequenceNumberRange("206", null);
return Arrays.asList(
ShardObjectHelper.newShard("shardId-0", null, null, range0,
ShardObjectHelper.newHashKeyRange("0", "99")),
ShardObjectHelper.newShard("shardId-1", null, null, range0,
ShardObjectHelper.newHashKeyRange("100", "199")),
ShardObjectHelper.newShard("shardId-2", null, null, range0,
ShardObjectHelper.newHashKeyRange("200", "299")),
ShardObjectHelper.newShard("shardId-3", null, null, range0,
ShardObjectHelper.newHashKeyRange("300", "399")),
ShardObjectHelper.newShard("shardId-4", null, null, range1,
ShardObjectHelper.newHashKeyRange("400", "499")),
ShardObjectHelper.newShard("shardId-5", null, null, range2,
ShardObjectHelper.newHashKeyRange("500", ShardObjectHelper.MAX_HASH_KEY)),
ShardObjectHelper.newShard("shardId-6", "shardId-0", "shardId-1", range3,
ShardObjectHelper.newHashKeyRange("0", "199")),
ShardObjectHelper.newShard("shardId-7", "shardId-2", "shardId-3", range3,
ShardObjectHelper.newHashKeyRange("200", "399")),
ShardObjectHelper.newShard("shardId-8", "shardId-6", "shardId-7", range4,
ShardObjectHelper.newHashKeyRange("0", "399")),
ShardObjectHelper.newShard("shardId-9", "shardId-5", null, range4,
ShardObjectHelper.newHashKeyRange("500", "799")),
ShardObjectHelper.newShard("shardId-10", null, "shardId-5", range4,
ShardObjectHelper.newHashKeyRange("800", ShardObjectHelper.MAX_HASH_KEY)));
private List<ChildShard> constructChildShards() {
List<ChildShard> childShards = new ArrayList<>();
List<String> parentShards = new ArrayList<>();
parentShards.add(shardId);
ChildShard leftChild = ChildShard.builder()
.shardId("ShardId-1")
.parentShards(parentShards)
.hashKeyRange(ShardObjectHelper.newHashKeyRange("0", "49"))
.build();
ChildShard rightChild = ChildShard.builder()
.shardId("ShardId-2")
.parentShards(parentShards)
.hashKeyRange(ShardObjectHelper.newHashKeyRange("50", "99"))
.build();
childShards.add(leftChild);
childShards.add(rightChild);
return childShards;
}
}

View file

@ -26,6 +26,7 @@ import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.core.async.SdkPublisher;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.services.kinesis.KinesisAsyncClient;
import software.amazon.awssdk.services.kinesis.model.ChildShard;
import software.amazon.awssdk.services.kinesis.model.Record;
import software.amazon.awssdk.services.kinesis.model.ResourceNotFoundException;
import software.amazon.awssdk.services.kinesis.model.ShardIteratorType;
@ -35,6 +36,7 @@ import software.amazon.awssdk.services.kinesis.model.SubscribeToShardEventStream
import software.amazon.awssdk.services.kinesis.model.SubscribeToShardRequest;
import software.amazon.kinesis.common.InitialPositionInStream;
import software.amazon.kinesis.common.InitialPositionInStreamExtended;
import software.amazon.kinesis.leases.ShardObjectHelper;
import software.amazon.kinesis.lifecycle.ShardConsumerNotifyingSubscriber;
import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput;
import software.amazon.kinesis.retrieval.BatchUniqueIdentifier;
@ -47,6 +49,7 @@ import software.amazon.kinesis.utils.SubscribeToShardRequestMatcher;
import java.nio.ByteBuffer;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Optional;
@ -89,6 +92,7 @@ public class FanOutRecordsPublisherTest {
private static final String SHARD_ID = "Shard-001";
private static final String CONSUMER_ARN = "arn:consumer";
private static final String CONTINUATION_SEQUENCE_NUMBER = "continuationSequenceNumber";
@Mock
private KinesisAsyncClient kinesisClient;
@ -148,7 +152,12 @@ public class FanOutRecordsPublisherTest {
List<KinesisClientRecordMatcher> matchers = records.stream().map(KinesisClientRecordMatcher::new)
.collect(Collectors.toList());
batchEvent = SubscribeToShardEvent.builder().millisBehindLatest(100L).records(records).build();
batchEvent = SubscribeToShardEvent.builder()
.millisBehindLatest(100L)
.records(records)
.continuationSequenceNumber("test")
.childShards(Collections.emptyList())
.build();
captor.getValue().onNext(batchEvent);
captor.getValue().onNext(batchEvent);
@ -166,6 +175,73 @@ public class FanOutRecordsPublisherTest {
}
@Test
public void InvalidEventTest() throws Exception {
FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN);
ArgumentCaptor<FanOutRecordsPublisher.RecordSubscription> captor = ArgumentCaptor
.forClass(FanOutRecordsPublisher.RecordSubscription.class);
ArgumentCaptor<FanOutRecordsPublisher.RecordFlow> flowCaptor = ArgumentCaptor
.forClass(FanOutRecordsPublisher.RecordFlow.class);
doNothing().when(publisher).subscribe(captor.capture());
source.start(ExtendedSequenceNumber.LATEST,
InitialPositionInStreamExtended.newInitialPosition(InitialPositionInStream.LATEST));
List<ProcessRecordsInput> receivedInput = new ArrayList<>();
source.subscribe(new ShardConsumerNotifyingSubscriber(new Subscriber<RecordsRetrieved>() {
Subscription subscription;
@Override public void onSubscribe(Subscription s) {
subscription = s;
subscription.request(1);
}
@Override public void onNext(RecordsRetrieved input) {
receivedInput.add(input.processRecordsInput());
subscription.request(1);
}
@Override public void onError(Throwable t) {
log.error("Caught throwable in subscriber", t);
fail("Caught throwable in subscriber");
}
@Override public void onComplete() {
fail("OnComplete called when not expected");
}
}, source));
verify(kinesisClient).subscribeToShard(any(SubscribeToShardRequest.class), flowCaptor.capture());
flowCaptor.getValue().onEventStream(publisher);
captor.getValue().onSubscribe(subscription);
List<Record> records = Stream.of(1, 2, 3).map(this::makeRecord).collect(Collectors.toList());
List<KinesisClientRecordMatcher> matchers = records.stream().map(KinesisClientRecordMatcher::new)
.collect(Collectors.toList());
batchEvent = SubscribeToShardEvent.builder().millisBehindLatest(100L).records(records).continuationSequenceNumber(CONTINUATION_SEQUENCE_NUMBER).build();
SubscribeToShardEvent invalidEvent = SubscribeToShardEvent.builder().millisBehindLatest(100L).records(records).childShards(Collections.emptyList()).build();
captor.getValue().onNext(batchEvent);
captor.getValue().onNext(invalidEvent);
captor.getValue().onNext(batchEvent);
// When the second request failed with invalid event, it should stop sending requests and cancel the flow.
verify(subscription, times(2)).request(1);
assertThat(receivedInput.size(), equalTo(1));
receivedInput.stream().map(ProcessRecordsInput::records).forEach(clientRecordsList -> {
assertThat(clientRecordsList.size(), equalTo(matchers.size()));
for (int i = 0; i < clientRecordsList.size(); ++i) {
assertThat(clientRecordsList.get(i), matchers.get(i));
}
});
}
@Test
public void testIfAllEventsReceivedWhenNoTasksRejectedByExecutor() throws Exception {
FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN);
@ -225,7 +301,9 @@ public class FanOutRecordsPublisherTest {
SubscribeToShardEvent.builder()
.millisBehindLatest(100L)
.continuationSequenceNumber(contSeqNum)
.records(records).build())
.records(records)
.childShards(Collections.emptyList())
.build())
.forEach(batchEvent -> captor.getValue().onNext(batchEvent));
verify(subscription, times(4)).request(1);
@ -301,7 +379,9 @@ public class FanOutRecordsPublisherTest {
SubscribeToShardEvent.builder()
.millisBehindLatest(100L)
.continuationSequenceNumber(contSeqNum)
.records(records).build())
.records(records)
.childShards(Collections.emptyList())
.build())
.forEach(batchEvent -> captor.getValue().onNext(batchEvent));
verify(subscription, times(2)).request(1);
@ -334,6 +414,7 @@ public class FanOutRecordsPublisherTest {
.millisBehindLatest(100L)
.continuationSequenceNumber(contSeqNum + "")
.records(records)
.childShards(Collections.emptyList())
.build());
CountDownLatch servicePublisherTaskCompletionLatch = new CountDownLatch(2);
@ -436,6 +517,7 @@ public class FanOutRecordsPublisherTest {
.millisBehindLatest(100L)
.continuationSequenceNumber(contSeqNum + "")
.records(records)
.childShards(Collections.emptyList())
.build());
CountDownLatch servicePublisherTaskCompletionLatch = new CountDownLatch(2);
@ -536,13 +618,30 @@ public class FanOutRecordsPublisherTest {
.millisBehindLatest(100L)
.continuationSequenceNumber(contSeqNum + "")
.records(records)
.childShards(Collections.emptyList())
.build());
List<ChildShard> childShards = new ArrayList<>();
List<String> parentShards = new ArrayList<>();
parentShards.add(SHARD_ID);
ChildShard leftChild = ChildShard.builder()
.shardId("Shard-002")
.parentShards(parentShards)
.hashKeyRange(ShardObjectHelper.newHashKeyRange("0", "49"))
.build();
ChildShard rightChild = ChildShard.builder()
.shardId("Shard-003")
.parentShards(parentShards)
.hashKeyRange(ShardObjectHelper.newHashKeyRange("50", "99"))
.build();
childShards.add(leftChild);
childShards.add(rightChild);
Consumer<Integer> servicePublisherShardEndAction = contSeqNum -> captor.getValue().onNext(
SubscribeToShardEvent.builder()
.millisBehindLatest(100L)
.continuationSequenceNumber(null)
.records(records)
.childShards(childShards)
.build());
CountDownLatch servicePublisherTaskCompletionLatch = new CountDownLatch(2);
@ -648,6 +747,7 @@ public class FanOutRecordsPublisherTest {
.millisBehindLatest(100L)
.continuationSequenceNumber(contSeqNum + "")
.records(records)
.childShards(Collections.emptyList())
.build());
CountDownLatch servicePublisherTaskCompletionLatch = new CountDownLatch(2);
@ -750,6 +850,7 @@ public class FanOutRecordsPublisherTest {
.millisBehindLatest(100L)
.continuationSequenceNumber(contSeqNum + "")
.records(records)
.childShards(Collections.emptyList())
.build());
CountDownLatch servicePublisherTaskCompletionLatch = new CountDownLatch(2);
@ -842,6 +943,7 @@ public class FanOutRecordsPublisherTest {
.millisBehindLatest(100L)
.continuationSequenceNumber(contSeqNum + "")
.records(records)
.childShards(Collections.emptyList())
.build());
CountDownLatch servicePublisherTaskCompletionLatch = new CountDownLatch(1);
@ -1004,7 +1106,12 @@ public class FanOutRecordsPublisherTest {
List<KinesisClientRecordMatcher> matchers = records.stream().map(KinesisClientRecordMatcher::new)
.collect(Collectors.toList());
batchEvent = SubscribeToShardEvent.builder().millisBehindLatest(100L).records(records).build();
batchEvent = SubscribeToShardEvent.builder()
.millisBehindLatest(100L)
.records(records)
.continuationSequenceNumber(CONTINUATION_SEQUENCE_NUMBER)
.childShards(Collections.emptyList())
.build();
captor.getValue().onNext(batchEvent);
captor.getValue().onNext(batchEvent);
@ -1098,7 +1205,7 @@ public class FanOutRecordsPublisherTest {
.collect(Collectors.toList());
batchEvent = SubscribeToShardEvent.builder().millisBehindLatest(100L).records(records)
.continuationSequenceNumber("3").build();
.continuationSequenceNumber("3").childShards(Collections.emptyList()).build();
captor.getValue().onNext(batchEvent);
captor.getValue().onComplete();
@ -1126,7 +1233,7 @@ public class FanOutRecordsPublisherTest {
.collect(Collectors.toList());
batchEvent = SubscribeToShardEvent.builder().millisBehindLatest(100L).records(nextRecords)
.continuationSequenceNumber("6").build();
.continuationSequenceNumber("6").childShards(Collections.emptyList()).build();
nextSubscribeCaptor.getValue().onNext(batchEvent);
verify(subscription, times(4)).request(1);

View file

@ -30,6 +30,7 @@ import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Date;
@ -53,6 +54,7 @@ import org.mockito.runners.MockitoJUnitRunner;
import software.amazon.awssdk.core.exception.SdkException;
import software.amazon.awssdk.services.kinesis.KinesisAsyncClient;
import software.amazon.awssdk.services.kinesis.model.ChildShard;
import software.amazon.awssdk.services.kinesis.model.GetRecordsRequest;
import software.amazon.awssdk.services.kinesis.model.GetRecordsResponse;
import software.amazon.awssdk.services.kinesis.model.GetShardIteratorRequest;
@ -65,6 +67,7 @@ import software.amazon.kinesis.checkpoint.SentinelCheckpoint;
import software.amazon.kinesis.common.InitialPositionInStream;
import software.amazon.kinesis.common.InitialPositionInStreamExtended;
import software.amazon.kinesis.exceptions.KinesisClientLibException;
import software.amazon.kinesis.leases.ShardObjectHelper;
import software.amazon.kinesis.metrics.MetricsFactory;
import software.amazon.kinesis.metrics.NullMetricsFactory;
import software.amazon.kinesis.processor.Checkpointer;
@ -330,8 +333,31 @@ public class KinesisDataFetcherTest {
private CompletableFuture<GetRecordsResponse> makeGetRecordsResponse(String nextIterator, List<Record> records)
throws InterruptedException, ExecutionException {
List<ChildShard> childShards = new ArrayList<>();
if(nextIterator == null) {
childShards = createChildShards();
}
return CompletableFuture.completedFuture(GetRecordsResponse.builder().nextShardIterator(nextIterator)
.records(CollectionUtils.isNullOrEmpty(records) ? Collections.emptyList() : records).build());
.records(CollectionUtils.isNullOrEmpty(records) ? Collections.emptyList() : records).childShards(childShards).build());
}
private List<ChildShard> createChildShards() {
List<ChildShard> childShards = new ArrayList<>();
List<String> parentShards = new ArrayList<>();
parentShards.add(SHARD_ID);
ChildShard leftChild = ChildShard.builder()
.shardId("Shard-2")
.parentShards(parentShards)
.hashKeyRange(ShardObjectHelper.newHashKeyRange("0", "49"))
.build();
ChildShard rightChild = ChildShard.builder()
.shardId("Shard-3")
.parentShards(parentShards)
.hashKeyRange(ShardObjectHelper.newHashKeyRange("50", "99"))
.build();
childShards.add(leftChild);
childShards.add(rightChild);
return childShards;
}
@Test
@ -342,6 +368,7 @@ public class KinesisDataFetcherTest {
final String initialIterator = "InitialIterator";
final String nextIterator1 = "NextIteratorOne";
final String nextIterator2 = "NextIteratorTwo";
final String nextIterator3 = "NextIteratorThree";
final CompletableFuture<GetRecordsResponse> nonAdvancingResult1 = makeGetRecordsResponse(initialIterator, null);
final CompletableFuture<GetRecordsResponse> nonAdvancingResult2 = makeGetRecordsResponse(nextIterator1, null);
final CompletableFuture<GetRecordsResponse> finalNonAdvancingResult = makeGetRecordsResponse(nextIterator2,

View file

@ -22,6 +22,7 @@ import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyLong;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
@ -48,6 +49,7 @@ import org.mockito.runners.MockitoJUnitRunner;
import org.mockito.stubbing.Answer;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.kinesis.model.GetRecordsRequest;
import software.amazon.kinesis.common.InitialPositionInStreamExtended;
import lombok.extern.slf4j.Slf4j;
@ -86,6 +88,7 @@ public class PrefetchRecordsPublisherIntegrationTest {
private String operation = "ProcessTask";
private String streamName = "streamName";
private String shardId = "shardId-000000000000";
private String nextShardIterator = "testNextShardIterator";
@Mock
private KinesisAsyncClient kinesisClient;
@ -249,7 +252,7 @@ public class PrefetchRecordsPublisherIntegrationTest {
@Override
public DataFetcherResult getRecords() {
GetRecordsResponse getRecordsResult = GetRecordsResponse.builder().records(new ArrayList<>(records)).millisBehindLatest(1000L).build();
GetRecordsResponse getRecordsResult = GetRecordsResponse.builder().records(new ArrayList<>(records)).nextShardIterator(nextShardIterator).millisBehindLatest(1000L).build();
return new AdvancingResult(getRecordsResult);
}

View file

@ -21,6 +21,7 @@ import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Matchers.any;
@ -71,11 +72,13 @@ import io.reactivex.Flowable;
import io.reactivex.schedulers.Schedulers;
import lombok.extern.slf4j.Slf4j;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.kinesis.model.ChildShard;
import software.amazon.awssdk.services.kinesis.model.ExpiredIteratorException;
import software.amazon.awssdk.services.kinesis.model.GetRecordsResponse;
import software.amazon.awssdk.services.kinesis.model.Record;
import software.amazon.kinesis.common.InitialPositionInStreamExtended;
import software.amazon.kinesis.common.RequestDetails;
import software.amazon.kinesis.leases.ShardObjectHelper;
import software.amazon.kinesis.common.StreamIdentifier;
import software.amazon.kinesis.lifecycle.ShardConsumerNotifyingSubscriber;
import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput;
@ -99,6 +102,7 @@ public class PrefetchRecordsPublisherTest {
private static final int MAX_SIZE = 5;
private static final int MAX_RECORDS_COUNT = 15000;
private static final long IDLE_MILLIS_BETWEEN_CALLS = 0L;
private static final String NEXT_SHARD_ITERATOR = "testNextShardIterator";
@Mock
private GetRecordsRetrievalStrategy getRecordsRetrievalStrategy;
@ -136,7 +140,7 @@ public class PrefetchRecordsPublisherTest {
"shardId");
spyQueue = spy(getRecordsCache.getPublisherSession().prefetchRecordsQueue());
records = spy(new ArrayList<>());
getRecordsResponse = GetRecordsResponse.builder().records(records).build();
getRecordsResponse = GetRecordsResponse.builder().records(records).nextShardIterator(NEXT_SHARD_ITERATOR).childShards(new ArrayList<>()).build();
when(getRecordsRetrievalStrategy.getRecords(eq(MAX_RECORDS_PER_CALL))).thenReturn(getRecordsResponse);
}
@ -155,11 +159,67 @@ public class PrefetchRecordsPublisherTest {
.processRecordsInput();
assertEquals(expectedRecords, result.records());
assertEquals(new ArrayList<>(), result.childShards());
verify(executorService).execute(any());
verify(getRecordsRetrievalStrategy, atLeast(1)).getRecords(eq(MAX_RECORDS_PER_CALL));
}
@Test
public void testGetRecordsWithInvalidResponse() {
record = Record.builder().data(createByteBufferWithSize(SIZE_512_KB)).build();
when(records.size()).thenReturn(1000);
GetRecordsResponse response = GetRecordsResponse.builder().records(records).build();
when(getRecordsRetrievalStrategy.getRecords(eq(MAX_RECORDS_PER_CALL))).thenReturn(response);
when(dataFetcher.isShardEndReached()).thenReturn(false);
getRecordsCache.start(sequenceNumber, initialPosition);
try {
ProcessRecordsInput result = blockUntilRecordsAvailable(() -> evictPublishedEvent(getRecordsCache, "shardId"), 1000L)
.processRecordsInput();
} catch (Exception e) {
assertEquals("No records found", e.getMessage());
}
}
@Test
public void testGetRecordsWithShardEnd() {
records = new ArrayList<>();
final List<KinesisClientRecord> expectedRecords = new ArrayList<>();
List<ChildShard> childShards = new ArrayList<>();
List<String> parentShards = new ArrayList<>();
parentShards.add("shardId");
ChildShard leftChild = ChildShard.builder()
.shardId("shardId-000000000001")
.parentShards(parentShards)
.hashKeyRange(ShardObjectHelper.newHashKeyRange("0", "49"))
.build();
ChildShard rightChild = ChildShard.builder()
.shardId("shardId-000000000002")
.parentShards(parentShards)
.hashKeyRange(ShardObjectHelper.newHashKeyRange("50", "99"))
.build();
childShards.add(leftChild);
childShards.add(rightChild);
GetRecordsResponse response = GetRecordsResponse.builder().records(records).childShards(childShards).build();
when(getRecordsRetrievalStrategy.getRecords(eq(MAX_RECORDS_PER_CALL))).thenReturn(response);
when(dataFetcher.isShardEndReached()).thenReturn(true);
getRecordsCache.start(sequenceNumber, initialPosition);
ProcessRecordsInput result = blockUntilRecordsAvailable(() -> evictPublishedEvent(getRecordsCache, "shardId"), 1000L)
.processRecordsInput();
assertEquals(expectedRecords, result.records());
assertEquals(childShards, result.childShards());
assertTrue(result.isAtShardEnd());
}
// TODO: Broken test
@Test
@Ignore
@ -270,7 +330,7 @@ public class PrefetchRecordsPublisherTest {
@Test
public void testRetryableRetrievalExceptionContinues() {
GetRecordsResponse response = GetRecordsResponse.builder().millisBehindLatest(100L).records(Collections.emptyList()).build();
GetRecordsResponse response = GetRecordsResponse.builder().millisBehindLatest(100L).records(Collections.emptyList()).nextShardIterator(NEXT_SHARD_ITERATOR).build();
when(getRecordsRetrievalStrategy.getRecords(anyInt())).thenThrow(new RetryableRetrievalException("Timeout", new TimeoutException("Timeout"))).thenReturn(response);
getRecordsCache.start(sequenceNumber, initialPosition);
@ -293,7 +353,7 @@ public class PrefetchRecordsPublisherTest {
when(getRecordsRetrievalStrategy.getRecords(anyInt())).thenAnswer( i -> GetRecordsResponse.builder().records(
Record.builder().data(SdkBytes.fromByteArray(new byte[] { 1, 2, 3 })).sequenceNumber(++sequenceNumberInResponse[0] + "").build())
.build());
.nextShardIterator(NEXT_SHARD_ITERATOR).build());
getRecordsCache.start(sequenceNumber, initialPosition);
@ -384,7 +444,7 @@ public class PrefetchRecordsPublisherTest {
// to the subscriber.
GetRecordsResponse response = GetRecordsResponse.builder().records(
Record.builder().data(SdkBytes.fromByteArray(new byte[] { 1, 2, 3 })).sequenceNumber("123").build())
.build();
.nextShardIterator(NEXT_SHARD_ITERATOR).build();
when(getRecordsRetrievalStrategy.getRecords(anyInt())).thenReturn(response);
getRecordsCache.start(sequenceNumber, initialPosition);

View file

@ -45,6 +45,7 @@ public class ProcessRecordsInputMatcher extends TypeSafeDiagnosingMatcher<Proces
matchers.put("millisBehindLatest",
nullOrEquals(template.millisBehindLatest(), ProcessRecordsInput::millisBehindLatest));
matchers.put("records", nullOrEquals(template.records(), ProcessRecordsInput::records));
matchers.put("childShards", nullOrEquals(template.childShards(), ProcessRecordsInput::childShards));
this.template = template;
}