Merge pull request #10 from ychunxue/ltr_base_ShardEnd

ShardEnd shard sync with Child Shards
This commit is contained in:
ychunxue 2020-04-09 10:50:44 -07:00 committed by GitHub
commit b1a3d215d0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 467 additions and 166 deletions

View file

@ -39,6 +39,7 @@ import org.apache.commons.lang3.StringUtils;
import lombok.NonNull; import lombok.NonNull;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import software.amazon.awssdk.services.kinesis.model.ChildShard;
import software.amazon.awssdk.services.kinesis.model.Shard; import software.amazon.awssdk.services.kinesis.model.Shard;
import software.amazon.awssdk.services.kinesis.model.ShardFilter; import software.amazon.awssdk.services.kinesis.model.ShardFilter;
import software.amazon.awssdk.services.kinesis.model.ShardFilterType; import software.amazon.awssdk.services.kinesis.model.ShardFilterType;
@ -763,6 +764,41 @@ public class HierarchicalShardSyncer {
} }
} }
public synchronized Lease createLeaseForChildShard(final ChildShard childShard, final StreamIdentifier streamIdentifier) throws InvalidStateException {
final MultiStreamArgs multiStreamArgs = new MultiStreamArgs(isMultiStreamMode, streamIdentifier);
return multiStreamArgs.isMultiStreamMode() ? newKCLMultiStreamLeaseForChildShard(childShard, streamIdentifier)
: newKCLLeaseForChildShard(childShard);
}
private static Lease newKCLLeaseForChildShard(final ChildShard childShard) throws InvalidStateException {
Lease newLease = new Lease();
newLease.leaseKey(childShard.shardId());
if (!CollectionUtils.isNullOrEmpty(childShard.parentShards())) {
newLease.parentShardIds(childShard.parentShards());
} else {
throw new InvalidStateException("Unable to populate new lease for child shard " + childShard.shardId() + "because parent shards cannot be found.");
}
newLease.checkpoint(ExtendedSequenceNumber.TRIM_HORIZON);
newLease.ownerSwitchesSinceCheckpoint(0L);
return newLease;
}
private static Lease newKCLMultiStreamLeaseForChildShard(final ChildShard childShard, final StreamIdentifier streamIdentifier) throws InvalidStateException {
MultiStreamLease newLease = new MultiStreamLease();
newLease.leaseKey(MultiStreamLease.getLeaseKey(streamIdentifier.serialize(), childShard.shardId()));
if (!CollectionUtils.isNullOrEmpty(childShard.parentShards())) {
newLease.parentShardIds(childShard.parentShards());
} else {
throw new InvalidStateException("Unable to populate new lease for child shard " + childShard.shardId() + "because parent shards cannot be found.");
}
newLease.checkpoint(ExtendedSequenceNumber.TRIM_HORIZON);
newLease.ownerSwitchesSinceCheckpoint(0L);
newLease.streamIdentifier(streamIdentifier.serialize());
newLease.shardId(childShard.shardId());
return newLease;
}
/** /**
* Helper method to create a new Lease POJO for a shard. * Helper method to create a new Lease POJO for a shard.
* Note: Package level access only for testing purposes * Note: Package level access only for testing purposes

View file

@ -0,0 +1,27 @@
/*
* Copyright 2019 Amazon.com, Inc. or its affiliates.
* Licensed under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package software.amazon.kinesis.leases.exceptions;
/**
* Exception type for all exceptions thrown by the customer implemented code.
*/
public class CustomerApplicationException extends Exception {
public CustomerApplicationException(Throwable e) { super(e);}
public CustomerApplicationException(String message, Throwable e) { super(message, e);}
public CustomerApplicationException(String message) { super(message);}
}

View file

@ -496,7 +496,8 @@ class ConsumerStates {
argument.taskBackoffTimeMillis(), argument.taskBackoffTimeMillis(),
argument.recordsPublisher(), argument.recordsPublisher(),
argument.hierarchicalShardSyncer(), argument.hierarchicalShardSyncer(),
argument.metricsFactory()); argument.metricsFactory(),
input == null ? null : input.childShards());
} }
@Override @Override

View file

@ -16,6 +16,7 @@ package software.amazon.kinesis.lifecycle;
import java.time.Duration; import java.time.Duration;
import java.time.Instant; import java.time.Instant;
import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutionException;
@ -32,6 +33,7 @@ import lombok.Getter;
import lombok.NonNull; import lombok.NonNull;
import lombok.experimental.Accessors; import lombok.experimental.Accessors;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import software.amazon.awssdk.services.kinesis.model.ChildShard;
import software.amazon.kinesis.annotations.KinesisClientInternalApi; import software.amazon.kinesis.annotations.KinesisClientInternalApi;
import software.amazon.kinesis.exceptions.internal.BlockedOnParentShardException; import software.amazon.kinesis.exceptions.internal.BlockedOnParentShardException;
import software.amazon.kinesis.leases.ShardInfo; import software.amazon.kinesis.leases.ShardInfo;
@ -86,6 +88,8 @@ public class ShardConsumer {
private final ShardConsumerSubscriber subscriber; private final ShardConsumerSubscriber subscriber;
private ProcessRecordsInput shardEndProcessRecordsInput;
@Deprecated @Deprecated
public ShardConsumer(RecordsPublisher recordsPublisher, ExecutorService executorService, ShardInfo shardInfo, public ShardConsumer(RecordsPublisher recordsPublisher, ExecutorService executorService, ShardInfo shardInfo,
Optional<Long> logWarningForTaskAfterMillis, ShardConsumerArgument shardConsumerArgument, Optional<Long> logWarningForTaskAfterMillis, ShardConsumerArgument shardConsumerArgument,
@ -148,6 +152,7 @@ public class ShardConsumer {
processData(input); processData(input);
if (taskOutcome == TaskOutcome.END_OF_SHARD) { if (taskOutcome == TaskOutcome.END_OF_SHARD) {
markForShutdown(ShutdownReason.SHARD_END); markForShutdown(ShutdownReason.SHARD_END);
shardEndProcessRecordsInput = input;
subscription.cancel(); subscription.cancel();
return; return;
} }
@ -305,7 +310,7 @@ public class ShardConsumer {
return true; return true;
} }
executeTask(null); executeTask(shardEndProcessRecordsInput);
return false; return false;
} }
}, executorService); }, executorService);

View file

@ -20,6 +20,7 @@ import lombok.NonNull;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import software.amazon.awssdk.services.kinesis.model.ChildShard;
import software.amazon.awssdk.services.kinesis.model.Shard; import software.amazon.awssdk.services.kinesis.model.Shard;
import software.amazon.awssdk.utils.CollectionUtils; import software.amazon.awssdk.utils.CollectionUtils;
import software.amazon.kinesis.annotations.KinesisClientInternalApi; import software.amazon.kinesis.annotations.KinesisClientInternalApi;
@ -30,6 +31,10 @@ import software.amazon.kinesis.leases.LeaseCoordinator;
import software.amazon.kinesis.leases.ShardDetector; import software.amazon.kinesis.leases.ShardDetector;
import software.amazon.kinesis.leases.ShardInfo; import software.amazon.kinesis.leases.ShardInfo;
import software.amazon.kinesis.leases.HierarchicalShardSyncer; import software.amazon.kinesis.leases.HierarchicalShardSyncer;
import software.amazon.kinesis.leases.exceptions.CustomerApplicationException;
import software.amazon.kinesis.leases.exceptions.DependencyException;
import software.amazon.kinesis.leases.exceptions.InvalidStateException;
import software.amazon.kinesis.leases.exceptions.ProvisionedThroughputException;
import software.amazon.kinesis.lifecycle.events.LeaseLostInput; import software.amazon.kinesis.lifecycle.events.LeaseLostInput;
import software.amazon.kinesis.lifecycle.events.ShardEndedInput; import software.amazon.kinesis.lifecycle.events.ShardEndedInput;
import software.amazon.kinesis.metrics.MetricsFactory; import software.amazon.kinesis.metrics.MetricsFactory;
@ -66,8 +71,6 @@ public class ShutdownTask implements ConsumerTask {
@NonNull @NonNull
private final InitialPositionInStreamExtended initialPositionInStream; private final InitialPositionInStreamExtended initialPositionInStream;
private final boolean cleanupLeasesOfCompletedShards; private final boolean cleanupLeasesOfCompletedShards;
private final boolean garbageCollectLeases = false;
private final boolean isLeaseTableEmpty = false;
private final boolean ignoreUnexpectedChildShards; private final boolean ignoreUnexpectedChildShards;
@NonNull @NonNull
private final LeaseCoordinator leaseCoordinator; private final LeaseCoordinator leaseCoordinator;
@ -81,6 +84,8 @@ public class ShutdownTask implements ConsumerTask {
private final TaskType taskType = TaskType.SHUTDOWN; private final TaskType taskType = TaskType.SHUTDOWN;
private final List<ChildShard> childShards;
private static final Function<ShardInfo, String> shardInfoIdProvider = shardInfo -> shardInfo private static final Function<ShardInfo, String> shardInfoIdProvider = shardInfo -> shardInfo
.streamIdentifierSerOpt().map(s -> s + ":" + shardInfo.shardId()).orElse(shardInfo.shardId()); .streamIdentifierSerOpt().map(s -> s + ":" + shardInfo.shardId()).orElse(shardInfo.shardId());
/* /*
@ -95,85 +100,48 @@ public class ShutdownTask implements ConsumerTask {
final MetricsScope scope = MetricsUtil.createMetricsWithOperation(metricsFactory, SHUTDOWN_TASK_OPERATION); final MetricsScope scope = MetricsUtil.createMetricsWithOperation(metricsFactory, SHUTDOWN_TASK_OPERATION);
Exception exception; Exception exception;
boolean applicationException = false;
try { try {
try { try {
ShutdownReason localReason = reason; log.debug("Invoking shutdown() for shard {}, concurrencyToken {}. Shutdown reason: {}",
List<Shard> latestShards = null; shardInfoIdProvider.apply(shardInfo), shardInfo.concurrencyToken(), reason);
/*
* Revalidate if the current shard is closed before shutting down the shard consumer with reason SHARD_END
* If current shard is not closed, shut down the shard consumer with reason LEASE_LOST that allows
* active workers to contend for the lease of this shard.
*/
if (localReason == ShutdownReason.SHARD_END) {
latestShards = shardDetector.listShards();
//If latestShards is empty, should also shutdown the ShardConsumer without checkpoint with SHARD_END final long startTime = System.currentTimeMillis();
if (CollectionUtils.isNullOrEmpty(latestShards) || !isShardInContextParentOfAny(latestShards)) { if (reason == ShutdownReason.SHARD_END) {
localReason = ShutdownReason.LEASE_LOST; // Create new lease for the child shards if they don't exist.
dropLease(); if (!CollectionUtils.isNullOrEmpty(childShards)) {
log.info("Forcing the lease to be lost before shutting down the consumer for Shard: " + shardInfoIdProvider.apply(shardInfo)); createLeasesForChildShardsIfNotExist();
} } else {
log.warn("Shard {} no longer exists. Shutting down consumer with SHARD_END reason without creating leases for child shards.", shardInfoIdProvider.apply(shardInfo));
} }
// If we reached end of the shard, set sequence number to SHARD_END.
if (localReason == ShutdownReason.SHARD_END) {
recordProcessorCheckpointer recordProcessorCheckpointer
.sequenceNumberAtShardEnd(recordProcessorCheckpointer.largestPermittedCheckpointValue()); .sequenceNumberAtShardEnd(recordProcessorCheckpointer.largestPermittedCheckpointValue());
recordProcessorCheckpointer.largestPermittedCheckpointValue(ExtendedSequenceNumber.SHARD_END); recordProcessorCheckpointer.largestPermittedCheckpointValue(ExtendedSequenceNumber.SHARD_END);
// Call the shardRecordsProcessor to checkpoint with SHARD_END sequence number.
// The shardEnded is implemented by customer. We should validate if the SHARD_END checkpointing is successful after calling shardEnded.
throwOnApplicationException(() -> applicationCheckpointAndVerification(), scope, startTime);
} else {
throwOnApplicationException(() -> shardRecordProcessor.leaseLost(LeaseLostInput.builder().build()), scope, startTime);
} }
log.debug("Invoking shutdown() for shard {}, concurrencyToken {}. Shutdown reason: {}", log.debug("Shutting down retrieval strategy for shard {}.", shardInfoIdProvider.apply(shardInfo));
shardInfoIdProvider.apply(shardInfo), shardInfo.concurrencyToken(), localReason);
final ShutdownInput shutdownInput = ShutdownInput.builder().shutdownReason(localReason)
.checkpointer(recordProcessorCheckpointer).build();
final long startTime = System.currentTimeMillis();
try {
if (localReason == ShutdownReason.SHARD_END) {
shardRecordProcessor.shardEnded(ShardEndedInput.builder().checkpointer(recordProcessorCheckpointer).build());
ExtendedSequenceNumber lastCheckpointValue = recordProcessorCheckpointer.lastCheckpointValue();
if (lastCheckpointValue == null
|| !lastCheckpointValue.equals(ExtendedSequenceNumber.SHARD_END)) {
throw new IllegalArgumentException("Application didn't checkpoint at end of shard "
+ shardInfoIdProvider.apply(shardInfo) + ". Application must checkpoint upon shard end. " +
"See ShardRecordProcessor.shardEnded javadocs for more information.");
}
} else {
shardRecordProcessor.leaseLost(LeaseLostInput.builder().build());
}
log.debug("Shutting down retrieval strategy.");
recordsPublisher.shutdown(); recordsPublisher.shutdown();
log.debug("Record processor completed shutdown() for shard {}", shardInfoIdProvider.apply(shardInfo)); log.debug("Record processor completed shutdown() for shard {}", shardInfoIdProvider.apply(shardInfo));
} catch (Exception e) {
applicationException = true;
throw e;
} finally {
MetricsUtil.addLatency(scope, RECORD_PROCESSOR_SHUTDOWN_METRIC, startTime, MetricsLevel.SUMMARY);
}
if (localReason == ShutdownReason.SHARD_END) {
log.debug("Looking for child shards of shard {}", shardInfoIdProvider.apply(shardInfo));
// create leases for the child shards
hierarchicalShardSyncer.checkAndCreateLeaseForNewShards(shardDetector, leaseCoordinator.leaseRefresher(),
initialPositionInStream, latestShards, cleanupLeasesOfCompletedShards, ignoreUnexpectedChildShards, scope, garbageCollectLeases,
isLeaseTableEmpty);
log.debug("Finished checking for child shards of shard {}", shardInfoIdProvider.apply(shardInfo));
}
return new TaskResult(null); return new TaskResult(null);
} catch (Exception e) { } catch (Exception e) {
if (applicationException) { if (e instanceof CustomerApplicationException) {
log.error("Application exception. ", e); log.error("Shard {}: Application exception. ", shardInfoIdProvider.apply(shardInfo), e);
} else { } else {
log.error("Caught exception: ", e); log.error("Shard {}: Caught exception: ", shardInfoIdProvider.apply(shardInfo), e);
} }
exception = e; exception = e;
// backoff if we encounter an exception. // backoff if we encounter an exception.
try { try {
Thread.sleep(this.backoffTimeMillis); Thread.sleep(this.backoffTimeMillis);
} catch (InterruptedException ie) { } catch (InterruptedException ie) {
log.debug("Interrupted sleep", ie); log.debug("Shard {}: Interrupted sleep", shardInfoIdProvider.apply(shardInfo), ie);
} }
} }
} finally { } finally {
@ -181,7 +149,37 @@ public class ShutdownTask implements ConsumerTask {
} }
return new TaskResult(exception); return new TaskResult(exception);
}
private void applicationCheckpointAndVerification() {
shardRecordProcessor.shardEnded(ShardEndedInput.builder().checkpointer(recordProcessorCheckpointer).build());
final ExtendedSequenceNumber lastCheckpointValue = recordProcessorCheckpointer.lastCheckpointValue();
if (lastCheckpointValue == null
|| !lastCheckpointValue.equals(ExtendedSequenceNumber.SHARD_END)) {
throw new IllegalArgumentException("Application didn't checkpoint at end of shard "
+ shardInfoIdProvider.apply(shardInfo) + ". Application must checkpoint upon shard end. " +
"See ShardRecordProcessor.shardEnded javadocs for more information.");
}
}
private void throwOnApplicationException(Runnable action, MetricsScope metricsScope, final long startTime) throws CustomerApplicationException {
try {
action.run();
} catch (Exception e) {
throw new CustomerApplicationException("Customer application throws exception for shard " + shardInfoIdProvider.apply(shardInfo) +": ", e);
} finally {
MetricsUtil.addLatency(metricsScope, RECORD_PROCESSOR_SHUTDOWN_METRIC, startTime, MetricsLevel.SUMMARY);
}
}
private void createLeasesForChildShardsIfNotExist()
throws DependencyException, InvalidStateException, ProvisionedThroughputException {
for(ChildShard childShard : childShards) {
if(leaseCoordinator.getCurrentlyHeldLease(childShard.shardId()) == null) {
final Lease leaseToCreate = hierarchicalShardSyncer.createLeaseForChildShard(childShard, shardDetector.streamIdentifier());
leaseCoordinator.leaseRefresher().createLeaseIfNotExists(leaseToCreate);
}
}
} }
/* /*

View file

@ -23,6 +23,7 @@ import lombok.EqualsAndHashCode;
import lombok.Getter; import lombok.Getter;
import lombok.ToString; import lombok.ToString;
import lombok.experimental.Accessors; import lombok.experimental.Accessors;
import software.amazon.awssdk.services.kinesis.model.ChildShard;
import software.amazon.kinesis.processor.ShardRecordProcessor; import software.amazon.kinesis.processor.ShardRecordProcessor;
import software.amazon.kinesis.processor.RecordProcessorCheckpointer; import software.amazon.kinesis.processor.RecordProcessorCheckpointer;
import software.amazon.kinesis.retrieval.KinesisClientRecord; import software.amazon.kinesis.retrieval.KinesisClientRecord;
@ -66,6 +67,11 @@ public class ProcessRecordsInput {
* This value does not include the {@link #timeSpentInCache()}. * This value does not include the {@link #timeSpentInCache()}.
*/ */
private Long millisBehindLatest; private Long millisBehindLatest;
/**
* A list of child shards if the current GetRecords request reached the shard end.
* If not at the shard end, this should be an empty list.
*/
private List<ChildShard> childShards;
/** /**
* How long the records spent waiting to be dispatched to the {@link ShardRecordProcessor} * How long the records spent waiting to be dispatched to the {@link ShardRecordProcessor}

View file

@ -33,11 +33,13 @@ import software.amazon.awssdk.services.kinesis.model.SubscribeToShardEventStream
import software.amazon.awssdk.services.kinesis.model.SubscribeToShardRequest; import software.amazon.awssdk.services.kinesis.model.SubscribeToShardRequest;
import software.amazon.awssdk.services.kinesis.model.SubscribeToShardResponse; import software.amazon.awssdk.services.kinesis.model.SubscribeToShardResponse;
import software.amazon.awssdk.services.kinesis.model.SubscribeToShardResponseHandler; import software.amazon.awssdk.services.kinesis.model.SubscribeToShardResponseHandler;
import software.amazon.awssdk.utils.CollectionUtils;
import software.amazon.awssdk.utils.Either; import software.amazon.awssdk.utils.Either;
import software.amazon.kinesis.annotations.KinesisClientInternalApi; import software.amazon.kinesis.annotations.KinesisClientInternalApi;
import software.amazon.kinesis.common.InitialPositionInStreamExtended; import software.amazon.kinesis.common.InitialPositionInStreamExtended;
import software.amazon.kinesis.common.KinesisRequestsBuilder; import software.amazon.kinesis.common.KinesisRequestsBuilder;
import software.amazon.kinesis.common.RequestDetails; import software.amazon.kinesis.common.RequestDetails;
import software.amazon.kinesis.leases.exceptions.InvalidStateException;
import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput; import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput;
import software.amazon.kinesis.retrieval.BatchUniqueIdentifier; import software.amazon.kinesis.retrieval.BatchUniqueIdentifier;
import software.amazon.kinesis.retrieval.IteratorBuilder; import software.amazon.kinesis.retrieval.IteratorBuilder;
@ -398,7 +400,7 @@ public class FanOutRecordsPublisher implements RecordsPublisher {
// The ack received for this onNext event will be ignored by the publisher as the global flow object should // The ack received for this onNext event will be ignored by the publisher as the global flow object should
// be either null or renewed when the ack's flow identifier is evaluated. // be either null or renewed when the ack's flow identifier is evaluated.
FanoutRecordsRetrieved response = new FanoutRecordsRetrieved( FanoutRecordsRetrieved response = new FanoutRecordsRetrieved(
ProcessRecordsInput.builder().records(Collections.emptyList()).isAtShardEnd(true).build(), null, ProcessRecordsInput.builder().records(Collections.emptyList()).isAtShardEnd(true).childShards(Collections.emptyList()).build(), null,
triggeringFlow != null ? triggeringFlow.getSubscribeToShardId() : shardId + "-no-flow-found"); triggeringFlow != null ? triggeringFlow.getSubscribeToShardId() : shardId + "-no-flow-found");
subscriber.onNext(response); subscriber.onNext(response);
subscriber.onComplete(); subscriber.onComplete();
@ -477,15 +479,28 @@ public class FanOutRecordsPublisher implements RecordsPublisher {
return; return;
} }
try {
// If recordBatchEvent is not valid event, RuntimeException will be thrown here and trigger the errorOccurred call.
// Since the triggeringFlow is active flow, it will then trigger the handleFlowError call.
// Since the exception is not ResourceNotFoundException, it will trigger onError in the ShardConsumerSubscriber.
// The ShardConsumerSubscriber will finally cancel the subscription.
if (!isValidEvent(recordBatchEvent)) {
throw new InvalidStateException("RecordBatchEvent for flow " + triggeringFlow.toString() + " is invalid."
+ " event.continuationSequenceNumber: " + recordBatchEvent.continuationSequenceNumber()
+ ". event.childShards: " + recordBatchEvent.childShards());
}
List<KinesisClientRecord> records = recordBatchEvent.records().stream().map(KinesisClientRecord::fromRecord) List<KinesisClientRecord> records = recordBatchEvent.records().stream().map(KinesisClientRecord::fromRecord)
.collect(Collectors.toList()); .collect(Collectors.toList());
ProcessRecordsInput input = ProcessRecordsInput.builder().cacheEntryTime(Instant.now()) ProcessRecordsInput input = ProcessRecordsInput.builder()
.cacheEntryTime(Instant.now())
.millisBehindLatest(recordBatchEvent.millisBehindLatest()) .millisBehindLatest(recordBatchEvent.millisBehindLatest())
.isAtShardEnd(recordBatchEvent.continuationSequenceNumber() == null).records(records).build(); .isAtShardEnd(recordBatchEvent.continuationSequenceNumber() == null)
.records(records)
.childShards(recordBatchEvent.childShards())
.build();
FanoutRecordsRetrieved recordsRetrieved = new FanoutRecordsRetrieved(input, FanoutRecordsRetrieved recordsRetrieved = new FanoutRecordsRetrieved(input,
recordBatchEvent.continuationSequenceNumber(), triggeringFlow.subscribeToShardId); recordBatchEvent.continuationSequenceNumber(), triggeringFlow.subscribeToShardId);
try {
bufferCurrentEventAndScheduleIfRequired(recordsRetrieved, triggeringFlow); bufferCurrentEventAndScheduleIfRequired(recordsRetrieved, triggeringFlow);
} catch (Throwable t) { } catch (Throwable t) {
log.warn("{}: Unable to buffer or schedule onNext for subscriber. Failing publisher." + log.warn("{}: Unable to buffer or schedule onNext for subscriber. Failing publisher." +
@ -495,6 +510,11 @@ public class FanOutRecordsPublisher implements RecordsPublisher {
} }
} }
private boolean isValidEvent(SubscribeToShardEvent event) {
return event.continuationSequenceNumber() == null ? !CollectionUtils.isNullOrEmpty(event.childShards())
: event.childShards() != null && event.childShards().isEmpty();
}
private void updateAvailableQueueSpaceAndRequestUpstream(RecordFlow triggeringFlow) { private void updateAvailableQueueSpaceAndRequestUpstream(RecordFlow triggeringFlow) {
if (availableQueueSpace <= 0) { if (availableQueueSpace <= 0) {
log.debug( log.debug(

View file

@ -67,6 +67,7 @@ public class BlockingRecordsPublisher implements RecordsPublisher {
return ProcessRecordsInput.builder() return ProcessRecordsInput.builder()
.records(records) .records(records)
.millisBehindLatest(getRecordsResult.millisBehindLatest()) .millisBehindLatest(getRecordsResult.millisBehindLatest())
.childShards(getRecordsResult.childShards())
.build(); .build();
} }

View file

@ -36,6 +36,7 @@ import software.amazon.awssdk.services.kinesis.model.GetShardIteratorRequest;
import software.amazon.awssdk.services.kinesis.model.GetShardIteratorResponse; import software.amazon.awssdk.services.kinesis.model.GetShardIteratorResponse;
import software.amazon.awssdk.services.kinesis.model.KinesisException; import software.amazon.awssdk.services.kinesis.model.KinesisException;
import software.amazon.awssdk.services.kinesis.model.ResourceNotFoundException; import software.amazon.awssdk.services.kinesis.model.ResourceNotFoundException;
import software.amazon.awssdk.utils.CollectionUtils;
import software.amazon.kinesis.annotations.KinesisClientInternalApi; import software.amazon.kinesis.annotations.KinesisClientInternalApi;
import software.amazon.kinesis.common.FutureUtils; import software.amazon.kinesis.common.FutureUtils;
import software.amazon.kinesis.common.InitialPositionInStreamExtended; import software.amazon.kinesis.common.InitialPositionInStreamExtended;
@ -133,8 +134,12 @@ public class KinesisDataFetcher {
final DataFetcherResult TERMINAL_RESULT = new DataFetcherResult() { final DataFetcherResult TERMINAL_RESULT = new DataFetcherResult() {
@Override @Override
public GetRecordsResponse getResult() { public GetRecordsResponse getResult() {
return GetRecordsResponse.builder().millisBehindLatest(null).records(Collections.emptyList()) return GetRecordsResponse.builder()
.nextShardIterator(null).build(); .millisBehindLatest(null)
.records(Collections.emptyList())
.nextShardIterator(null)
.childShards(Collections.emptyList())
.build();
} }
@Override @Override
@ -281,6 +286,11 @@ public class KinesisDataFetcher {
try { try {
final GetRecordsResponse response = FutureUtils.resolveOrCancelFuture(kinesisClient.getRecords(request), final GetRecordsResponse response = FutureUtils.resolveOrCancelFuture(kinesisClient.getRecords(request),
maxFutureWait); maxFutureWait);
if (!isValidResponse(response)) {
throw new RetryableRetrievalException("GetRecords response is not valid for shard: " + streamAndShardId
+ ". nextShardIterator: " + response.nextShardIterator()
+ ". childShards: " + response.childShards() + ". Will retry GetRecords with the same nextIterator.");
}
success = true; success = true;
return response; return response;
} catch (ExecutionException e) { } catch (ExecutionException e) {
@ -298,6 +308,11 @@ public class KinesisDataFetcher {
} }
} }
private boolean isValidResponse(GetRecordsResponse response) {
return response.nextShardIterator() == null ? !CollectionUtils.isNullOrEmpty(response.childShards())
: response.childShards() != null && response.childShards().isEmpty();
}
private AWSExceptionManager createExceptionManager() { private AWSExceptionManager createExceptionManager() {
final AWSExceptionManager exceptionManager = new AWSExceptionManager(); final AWSExceptionManager exceptionManager = new AWSExceptionManager();
exceptionManager.add(ResourceNotFoundException.class, t -> t); exceptionManager.add(ResourceNotFoundException.class, t -> t);

View file

@ -162,7 +162,7 @@ public class PrefetchRecordsPublisher implements RecordsPublisher {
} else { } else {
log.info( log.info(
"{}: No record batch found while evicting from the prefetch queue. This indicates the prefetch buffer" "{}: No record batch found while evicting from the prefetch queue. This indicates the prefetch buffer"
+ "was reset.", streamAndShardId); + " was reset.", streamAndShardId);
} }
return result; return result;
} }
@ -437,6 +437,7 @@ public class PrefetchRecordsPublisher implements RecordsPublisher {
.millisBehindLatest(getRecordsResult.millisBehindLatest()) .millisBehindLatest(getRecordsResult.millisBehindLatest())
.cacheEntryTime(lastSuccessfulCall) .cacheEntryTime(lastSuccessfulCall)
.isAtShardEnd(getRecordsRetrievalStrategy.getDataFetcher().isShardEndReached()) .isAtShardEnd(getRecordsRetrievalStrategy.getDataFetcher().isShardEndReached())
.childShards(getRecordsResult.childShards())
.build(); .build();
PrefetchRecordsRetrieved recordsRetrieved = new PrefetchRecordsRetrieved(processRecordsInput, PrefetchRecordsRetrieved recordsRetrieved = new PrefetchRecordsRetrieved(processRecordsInput,

View file

@ -25,6 +25,8 @@ import static org.mockito.Mockito.when;
import static software.amazon.kinesis.lifecycle.ConsumerStates.ShardConsumerState; import static software.amazon.kinesis.lifecycle.ConsumerStates.ShardConsumerState;
import java.lang.reflect.Field; import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
@ -33,13 +35,13 @@ import org.hamcrest.Description;
import org.hamcrest.Matcher; import org.hamcrest.Matcher;
import org.hamcrest.TypeSafeDiagnosingMatcher; import org.hamcrest.TypeSafeDiagnosingMatcher;
import org.junit.Before; import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.runners.MockitoJUnitRunner; import org.mockito.runners.MockitoJUnitRunner;
import software.amazon.awssdk.services.kinesis.KinesisAsyncClient; import software.amazon.awssdk.services.kinesis.KinesisAsyncClient;
import software.amazon.awssdk.services.kinesis.model.ChildShard;
import software.amazon.kinesis.checkpoint.ShardRecordProcessorCheckpointer; import software.amazon.kinesis.checkpoint.ShardRecordProcessorCheckpointer;
import software.amazon.kinesis.common.InitialPositionInStream; import software.amazon.kinesis.common.InitialPositionInStream;
import software.amazon.kinesis.common.InitialPositionInStreamExtended; import software.amazon.kinesis.common.InitialPositionInStreamExtended;
@ -49,6 +51,7 @@ import software.amazon.kinesis.leases.LeaseRefresher;
import software.amazon.kinesis.leases.ShardDetector; import software.amazon.kinesis.leases.ShardDetector;
import software.amazon.kinesis.leases.ShardInfo; import software.amazon.kinesis.leases.ShardInfo;
import software.amazon.kinesis.leases.HierarchicalShardSyncer; import software.amazon.kinesis.leases.HierarchicalShardSyncer;
import software.amazon.kinesis.leases.ShardObjectHelper;
import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput; import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput;
import software.amazon.kinesis.metrics.MetricsFactory; import software.amazon.kinesis.metrics.MetricsFactory;
import software.amazon.kinesis.processor.Checkpointer; import software.amazon.kinesis.processor.Checkpointer;
@ -57,6 +60,7 @@ import software.amazon.kinesis.processor.ShardRecordProcessor;
import software.amazon.kinesis.retrieval.AggregatorUtil; import software.amazon.kinesis.retrieval.AggregatorUtil;
import software.amazon.kinesis.retrieval.RecordsPublisher; import software.amazon.kinesis.retrieval.RecordsPublisher;
@RunWith(MockitoJUnitRunner.class) @RunWith(MockitoJUnitRunner.class)
public class ConsumerStatesTest { public class ConsumerStatesTest {
private static final String STREAM_NAME = "TestStream"; private static final String STREAM_NAME = "TestStream";
@ -300,13 +304,27 @@ public class ConsumerStatesTest {
} }
// TODO: Fix this test
@Ignore
@Test @Test
public void shuttingDownStateTest() { public void shuttingDownStateTest() {
consumer.markForShutdown(ShutdownReason.SHARD_END); consumer.markForShutdown(ShutdownReason.SHARD_END);
ConsumerState state = ShardConsumerState.SHUTTING_DOWN.consumerState(); ConsumerState state = ShardConsumerState.SHUTTING_DOWN.consumerState();
ConsumerTask task = state.createTask(argument, consumer, null); List<ChildShard> childShards = new ArrayList<>();
List<String> parentShards = new ArrayList<>();
parentShards.add("shardId-000000000000");
ChildShard leftChild = ChildShard.builder()
.shardId("shardId-000000000001")
.parentShards(parentShards)
.hashKeyRange(ShardObjectHelper.newHashKeyRange("0", "49"))
.build();
ChildShard rightChild = ChildShard.builder()
.shardId("shardId-000000000002")
.parentShards(parentShards)
.hashKeyRange(ShardObjectHelper.newHashKeyRange("50", "99"))
.build();
childShards.add(leftChild);
childShards.add(rightChild);
when(processRecordsInput.childShards()).thenReturn(childShards);
ConsumerTask task = state.createTask(argument, consumer, processRecordsInput);
assertThat(task, shutdownTask(ShardInfo.class, "shardInfo", equalTo(shardInfo))); assertThat(task, shutdownTask(ShardInfo.class, "shardInfo", equalTo(shardInfo)));
assertThat(task, assertThat(task,
@ -315,8 +333,6 @@ public class ConsumerStatesTest {
equalTo(recordProcessorCheckpointer))); equalTo(recordProcessorCheckpointer)));
assertThat(task, shutdownTask(ShutdownReason.class, "reason", equalTo(reason))); assertThat(task, shutdownTask(ShutdownReason.class, "reason", equalTo(reason)));
assertThat(task, shutdownTask(LeaseCoordinator.class, "leaseCoordinator", equalTo(leaseCoordinator))); assertThat(task, shutdownTask(LeaseCoordinator.class, "leaseCoordinator", equalTo(leaseCoordinator)));
assertThat(task, shutdownTask(InitialPositionInStreamExtended.class, "initialPositionInStream",
equalTo(initialPositionInStream)));
assertThat(task, assertThat(task,
shutdownTask(Boolean.class, "cleanupLeasesOfCompletedShards", equalTo(cleanupLeasesOfCompletedShards))); shutdownTask(Boolean.class, "cleanupLeasesOfCompletedShards", equalTo(cleanupLeasesOfCompletedShards)));
assertThat(task, shutdownTask(Long.class, "backoffTimeMillis", equalTo(taskBackoffTimeMillis))); assertThat(task, shutdownTask(Long.class, "backoffTimeMillis", equalTo(taskBackoffTimeMillis)));

View file

@ -26,6 +26,7 @@ import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
@ -33,9 +34,11 @@ import java.util.List;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.mockito.Matchers;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.runners.MockitoJUnitRunner; import org.mockito.runners.MockitoJUnitRunner;
import software.amazon.awssdk.services.kinesis.model.ChildShard;
import software.amazon.awssdk.services.kinesis.model.SequenceNumberRange; import software.amazon.awssdk.services.kinesis.model.SequenceNumberRange;
import software.amazon.awssdk.services.kinesis.model.Shard; import software.amazon.awssdk.services.kinesis.model.Shard;
import software.amazon.kinesis.checkpoint.ShardRecordProcessorCheckpointer; import software.amazon.kinesis.checkpoint.ShardRecordProcessorCheckpointer;
@ -43,11 +46,16 @@ import software.amazon.kinesis.common.InitialPositionInStream;
import software.amazon.kinesis.common.InitialPositionInStreamExtended; import software.amazon.kinesis.common.InitialPositionInStreamExtended;
import software.amazon.kinesis.exceptions.internal.KinesisClientLibIOException; import software.amazon.kinesis.exceptions.internal.KinesisClientLibIOException;
import software.amazon.kinesis.leases.HierarchicalShardSyncer; import software.amazon.kinesis.leases.HierarchicalShardSyncer;
import software.amazon.kinesis.leases.Lease;
import software.amazon.kinesis.leases.LeaseCoordinator; import software.amazon.kinesis.leases.LeaseCoordinator;
import software.amazon.kinesis.leases.LeaseRefresher; import software.amazon.kinesis.leases.LeaseRefresher;
import software.amazon.kinesis.leases.ShardDetector; import software.amazon.kinesis.leases.ShardDetector;
import software.amazon.kinesis.leases.ShardInfo; import software.amazon.kinesis.leases.ShardInfo;
import software.amazon.kinesis.leases.ShardObjectHelper; import software.amazon.kinesis.leases.ShardObjectHelper;
import software.amazon.kinesis.leases.exceptions.CustomerApplicationException;
import software.amazon.kinesis.leases.exceptions.DependencyException;
import software.amazon.kinesis.leases.exceptions.InvalidStateException;
import software.amazon.kinesis.leases.exceptions.ProvisionedThroughputException;
import software.amazon.kinesis.lifecycle.events.LeaseLostInput; import software.amazon.kinesis.lifecycle.events.LeaseLostInput;
import software.amazon.kinesis.lifecycle.events.ShardEndedInput; import software.amazon.kinesis.lifecycle.events.ShardEndedInput;
import software.amazon.kinesis.metrics.MetricsFactory; import software.amazon.kinesis.metrics.MetricsFactory;
@ -104,7 +112,7 @@ public class ShutdownTaskTest {
task = new ShutdownTask(shardInfo, shardDetector, shardRecordProcessor, recordProcessorCheckpointer, task = new ShutdownTask(shardInfo, shardDetector, shardRecordProcessor, recordProcessorCheckpointer,
SHARD_END_SHUTDOWN_REASON, INITIAL_POSITION_TRIM_HORIZON, cleanupLeasesOfCompletedShards, SHARD_END_SHUTDOWN_REASON, INITIAL_POSITION_TRIM_HORIZON, cleanupLeasesOfCompletedShards,
ignoreUnexpectedChildShards, leaseCoordinator, TASK_BACKOFF_TIME_MILLIS, recordsPublisher, ignoreUnexpectedChildShards, leaseCoordinator, TASK_BACKOFF_TIME_MILLIS, recordsPublisher,
hierarchicalShardSyncer, NULL_METRICS_FACTORY); hierarchicalShardSyncer, NULL_METRICS_FACTORY, constructChildShards());
} }
/** /**
@ -113,12 +121,12 @@ public class ShutdownTaskTest {
*/ */
@Test @Test
public final void testCallWhenApplicationDoesNotCheckpoint() { public final void testCallWhenApplicationDoesNotCheckpoint() {
when(shardDetector.listShards()).thenReturn(constructShardListGraphA());
when(recordProcessorCheckpointer.lastCheckpointValue()).thenReturn(new ExtendedSequenceNumber("3298")); when(recordProcessorCheckpointer.lastCheckpointValue()).thenReturn(new ExtendedSequenceNumber("3298"));
when(leaseCoordinator.leaseRefresher()).thenReturn(leaseRefresher);
final TaskResult result = task.call(); final TaskResult result = task.call();
assertNotNull(result.getException()); assertNotNull(result.getException());
assertTrue(result.getException() instanceof IllegalArgumentException); assertTrue(result.getException() instanceof CustomerApplicationException);
} }
/** /**
@ -126,28 +134,18 @@ public class ShutdownTaskTest {
* This test is for the scenario that checkAndCreateLeaseForNewShards throws an exception. * This test is for the scenario that checkAndCreateLeaseForNewShards throws an exception.
*/ */
@Test @Test
public final void testCallWhenSyncingShardsThrows() throws Exception { public final void testCallWhenCreatingNewLeasesThrows() throws Exception {
final boolean garbageCollectLeases = false;
final boolean isLeaseTableEmpty = false;
List<Shard> latestShards = constructShardListGraphA();
when(shardDetector.listShards()).thenReturn(latestShards);
when(recordProcessorCheckpointer.lastCheckpointValue()).thenReturn(ExtendedSequenceNumber.SHARD_END); when(recordProcessorCheckpointer.lastCheckpointValue()).thenReturn(ExtendedSequenceNumber.SHARD_END);
when(leaseCoordinator.leaseRefresher()).thenReturn(leaseRefresher); when(leaseCoordinator.leaseRefresher()).thenReturn(leaseRefresher);
when(leaseRefresher.createLeaseIfNotExists(Matchers.any(Lease.class))).thenThrow(new KinesisClientLibIOException("KinesisClientLibIOException"));
doAnswer((invocation) -> {
throw new KinesisClientLibIOException("KinesisClientLibIOException");
}).when(hierarchicalShardSyncer)
.checkAndCreateLeaseForNewShards(shardDetector, leaseRefresher, INITIAL_POSITION_TRIM_HORIZON,
latestShards, cleanupLeasesOfCompletedShards, ignoreUnexpectedChildShards,
NULL_METRICS_FACTORY.createMetrics(), garbageCollectLeases, isLeaseTableEmpty);
final TaskResult result = task.call(); final TaskResult result = task.call();
assertNotNull(result.getException()); assertNotNull(result.getException());
assertTrue(result.getException() instanceof KinesisClientLibIOException); assertTrue(result.getException() instanceof KinesisClientLibIOException);
verify(recordsPublisher).shutdown(); verify(recordsPublisher, never()).shutdown();
verify(shardRecordProcessor).shardEnded(ShardEndedInput.builder().checkpointer(recordProcessorCheckpointer).build()); verify(shardRecordProcessor, never()).shardEnded(ShardEndedInput.builder().checkpointer(recordProcessorCheckpointer).build());
verify(shardRecordProcessor, never()).leaseLost(LeaseLostInput.builder().build()); verify(shardRecordProcessor, never()).leaseLost(LeaseLostInput.builder().build());
verify(leaseCoordinator, never()).dropLease(Matchers.any(Lease.class));
} }
/** /**
@ -155,24 +153,24 @@ public class ShutdownTaskTest {
* This test is for the scenario that ShutdownTask is created for ShardConsumer reaching the Shard End. * This test is for the scenario that ShutdownTask is created for ShardConsumer reaching the Shard End.
*/ */
@Test @Test
public final void testCallWhenTrueShardEnd() { public final void testCallWhenTrueShardEnd() throws DependencyException, InvalidStateException, ProvisionedThroughputException {
shardInfo = new ShardInfo("shardId-0", concurrencyToken, Collections.emptySet(), shardInfo = new ShardInfo("shardId-0", concurrencyToken, Collections.emptySet(),
ExtendedSequenceNumber.LATEST); ExtendedSequenceNumber.LATEST);
task = new ShutdownTask(shardInfo, shardDetector, shardRecordProcessor, recordProcessorCheckpointer, task = new ShutdownTask(shardInfo, shardDetector, shardRecordProcessor, recordProcessorCheckpointer,
SHARD_END_SHUTDOWN_REASON, INITIAL_POSITION_TRIM_HORIZON, cleanupLeasesOfCompletedShards, SHARD_END_SHUTDOWN_REASON, INITIAL_POSITION_TRIM_HORIZON, cleanupLeasesOfCompletedShards,
ignoreUnexpectedChildShards, leaseCoordinator, TASK_BACKOFF_TIME_MILLIS, recordsPublisher, ignoreUnexpectedChildShards, leaseCoordinator, TASK_BACKOFF_TIME_MILLIS, recordsPublisher,
hierarchicalShardSyncer, NULL_METRICS_FACTORY); hierarchicalShardSyncer, NULL_METRICS_FACTORY, constructChildShards());
when(shardDetector.listShards()).thenReturn(constructShardListGraphA());
when(recordProcessorCheckpointer.lastCheckpointValue()).thenReturn(ExtendedSequenceNumber.SHARD_END); when(recordProcessorCheckpointer.lastCheckpointValue()).thenReturn(ExtendedSequenceNumber.SHARD_END);
when(leaseCoordinator.leaseRefresher()).thenReturn(leaseRefresher);
final TaskResult result = task.call(); final TaskResult result = task.call();
assertNull(result.getException()); assertNull(result.getException());
verify(recordsPublisher).shutdown(); verify(recordsPublisher).shutdown();
verify(shardRecordProcessor).shardEnded(ShardEndedInput.builder().checkpointer(recordProcessorCheckpointer).build()); verify(shardRecordProcessor).shardEnded(ShardEndedInput.builder().checkpointer(recordProcessorCheckpointer).build());
verify(shardRecordProcessor, never()).leaseLost(LeaseLostInput.builder().build()); verify(shardRecordProcessor, never()).leaseLost(LeaseLostInput.builder().build());
verify(shardDetector, times(1)).listShards(); verify(leaseRefresher, times(2)).createLeaseIfNotExists(Matchers.any(Lease.class));
verify(leaseCoordinator, never()).getAssignments(); verify(leaseCoordinator, never()).dropLease(Matchers.any(Lease.class));
} }
/** /**
@ -180,23 +178,25 @@ public class ShutdownTaskTest {
* This test is for the scenario that a ShutdownTask is created for detecting a false Shard End. * This test is for the scenario that a ShutdownTask is created for detecting a false Shard End.
*/ */
@Test @Test
public final void testCallWhenFalseShardEnd() { public final void testCallWhenShardNotFound() throws DependencyException, InvalidStateException, ProvisionedThroughputException {
shardInfo = new ShardInfo("shardId-4", concurrencyToken, Collections.emptySet(), shardInfo = new ShardInfo("shardId-4", concurrencyToken, Collections.emptySet(),
ExtendedSequenceNumber.LATEST); ExtendedSequenceNumber.LATEST);
task = new ShutdownTask(shardInfo, shardDetector, shardRecordProcessor, recordProcessorCheckpointer, task = new ShutdownTask(shardInfo, shardDetector, shardRecordProcessor, recordProcessorCheckpointer,
SHARD_END_SHUTDOWN_REASON, INITIAL_POSITION_TRIM_HORIZON, cleanupLeasesOfCompletedShards, SHARD_END_SHUTDOWN_REASON, INITIAL_POSITION_TRIM_HORIZON, cleanupLeasesOfCompletedShards,
ignoreUnexpectedChildShards, leaseCoordinator, TASK_BACKOFF_TIME_MILLIS, recordsPublisher, ignoreUnexpectedChildShards, leaseCoordinator, TASK_BACKOFF_TIME_MILLIS, recordsPublisher,
hierarchicalShardSyncer, NULL_METRICS_FACTORY); hierarchicalShardSyncer, NULL_METRICS_FACTORY, new ArrayList<>());
when(shardDetector.listShards()).thenReturn(constructShardListGraphA()); when(recordProcessorCheckpointer.lastCheckpointValue()).thenReturn(ExtendedSequenceNumber.SHARD_END);
when(leaseCoordinator.leaseRefresher()).thenReturn(leaseRefresher);
final TaskResult result = task.call(); final TaskResult result = task.call();
assertNull(result.getException()); assertNull(result.getException());
verify(recordsPublisher).shutdown(); verify(recordsPublisher).shutdown();
verify(shardRecordProcessor, never()).shardEnded(ShardEndedInput.builder().checkpointer(recordProcessorCheckpointer).build()); verify(shardRecordProcessor).shardEnded(ShardEndedInput.builder().checkpointer(recordProcessorCheckpointer).build());
verify(shardRecordProcessor).leaseLost(LeaseLostInput.builder().build()); verify(shardRecordProcessor, never()).leaseLost(LeaseLostInput.builder().build());
verify(shardDetector, times(1)).listShards(); verify(leaseCoordinator, never()).getCurrentlyHeldLease(shardInfo.shardId());
verify(leaseCoordinator).getCurrentlyHeldLease(shardInfo.shardId()); verify(leaseRefresher, never()).createLeaseIfNotExists(Matchers.any(Lease.class));
verify(leaseCoordinator, never()).dropLease(Matchers.any(Lease.class));
} }
/** /**
@ -204,23 +204,22 @@ public class ShutdownTaskTest {
* This test is for the scenario that a ShutdownTask is created for the ShardConsumer losing the lease. * This test is for the scenario that a ShutdownTask is created for the ShardConsumer losing the lease.
*/ */
@Test @Test
public final void testCallWhenLeaseLost() { public final void testCallWhenLeaseLost() throws DependencyException, InvalidStateException, ProvisionedThroughputException {
shardInfo = new ShardInfo("shardId-4", concurrencyToken, Collections.emptySet(), shardInfo = new ShardInfo("shardId-4", concurrencyToken, Collections.emptySet(),
ExtendedSequenceNumber.LATEST); ExtendedSequenceNumber.LATEST);
task = new ShutdownTask(shardInfo, shardDetector, shardRecordProcessor, recordProcessorCheckpointer, task = new ShutdownTask(shardInfo, shardDetector, shardRecordProcessor, recordProcessorCheckpointer,
LEASE_LOST_SHUTDOWN_REASON, INITIAL_POSITION_TRIM_HORIZON, cleanupLeasesOfCompletedShards, LEASE_LOST_SHUTDOWN_REASON, INITIAL_POSITION_TRIM_HORIZON, cleanupLeasesOfCompletedShards,
ignoreUnexpectedChildShards, leaseCoordinator, TASK_BACKOFF_TIME_MILLIS, recordsPublisher, ignoreUnexpectedChildShards, leaseCoordinator, TASK_BACKOFF_TIME_MILLIS, recordsPublisher,
hierarchicalShardSyncer, NULL_METRICS_FACTORY); hierarchicalShardSyncer, NULL_METRICS_FACTORY, new ArrayList<>());
when(shardDetector.listShards()).thenReturn(constructShardListGraphA());
final TaskResult result = task.call(); final TaskResult result = task.call();
assertNull(result.getException()); assertNull(result.getException());
verify(recordsPublisher).shutdown(); verify(recordsPublisher).shutdown();
verify(shardRecordProcessor, never()).shardEnded(ShardEndedInput.builder().checkpointer(recordProcessorCheckpointer).build()); verify(shardRecordProcessor, never()).shardEnded(ShardEndedInput.builder().checkpointer(recordProcessorCheckpointer).build());
verify(shardRecordProcessor).leaseLost(LeaseLostInput.builder().build()); verify(shardRecordProcessor).leaseLost(LeaseLostInput.builder().build());
verify(shardDetector, never()).listShards();
verify(leaseCoordinator, never()).getAssignments(); verify(leaseCoordinator, never()).getAssignments();
verify(leaseRefresher, never()).createLeaseIfNotExists(Matchers.any(Lease.class));
verify(leaseCoordinator, never()).dropLease(Matchers.any(Lease.class));
} }
/** /**
@ -231,45 +230,23 @@ public class ShutdownTaskTest {
assertEquals(TaskType.SHUTDOWN, task.taskType()); assertEquals(TaskType.SHUTDOWN, task.taskType());
} }
private List<ChildShard> constructChildShards() {
/* List<ChildShard> childShards = new ArrayList<>();
* Helper method to construct a shard list for graph A. Graph A is defined below. Shard structure (y-axis is List<String> parentShards = new ArrayList<>();
* epochs): 0 1 2 3 4 5 - shards till parentShards.add(shardId);
* \ / \ / | | ChildShard leftChild = ChildShard.builder()
* 6 7 4 5 - shards from epoch 103 - 205 .shardId("ShardId-1")
* \ / | /\ .parentShards(parentShards)
* 8 4 9 10 - shards from epoch 206 (open - no ending sequenceNumber) .hashKeyRange(ShardObjectHelper.newHashKeyRange("0", "49"))
*/ .build();
private List<Shard> constructShardListGraphA() { ChildShard rightChild = ChildShard.builder()
final SequenceNumberRange range0 = ShardObjectHelper.newSequenceNumberRange("11", "102"); .shardId("ShardId-2")
final SequenceNumberRange range1 = ShardObjectHelper.newSequenceNumberRange("11", null); .parentShards(parentShards)
final SequenceNumberRange range2 = ShardObjectHelper.newSequenceNumberRange("11", "205"); .hashKeyRange(ShardObjectHelper.newHashKeyRange("50", "99"))
final SequenceNumberRange range3 = ShardObjectHelper.newSequenceNumberRange("103", "205"); .build();
final SequenceNumberRange range4 = ShardObjectHelper.newSequenceNumberRange("206", null); childShards.add(leftChild);
childShards.add(rightChild);
return Arrays.asList( return childShards;
ShardObjectHelper.newShard("shardId-0", null, null, range0,
ShardObjectHelper.newHashKeyRange("0", "99")),
ShardObjectHelper.newShard("shardId-1", null, null, range0,
ShardObjectHelper.newHashKeyRange("100", "199")),
ShardObjectHelper.newShard("shardId-2", null, null, range0,
ShardObjectHelper.newHashKeyRange("200", "299")),
ShardObjectHelper.newShard("shardId-3", null, null, range0,
ShardObjectHelper.newHashKeyRange("300", "399")),
ShardObjectHelper.newShard("shardId-4", null, null, range1,
ShardObjectHelper.newHashKeyRange("400", "499")),
ShardObjectHelper.newShard("shardId-5", null, null, range2,
ShardObjectHelper.newHashKeyRange("500", ShardObjectHelper.MAX_HASH_KEY)),
ShardObjectHelper.newShard("shardId-6", "shardId-0", "shardId-1", range3,
ShardObjectHelper.newHashKeyRange("0", "199")),
ShardObjectHelper.newShard("shardId-7", "shardId-2", "shardId-3", range3,
ShardObjectHelper.newHashKeyRange("200", "399")),
ShardObjectHelper.newShard("shardId-8", "shardId-6", "shardId-7", range4,
ShardObjectHelper.newHashKeyRange("0", "399")),
ShardObjectHelper.newShard("shardId-9", "shardId-5", null, range4,
ShardObjectHelper.newHashKeyRange("500", "799")),
ShardObjectHelper.newShard("shardId-10", null, "shardId-5", range4,
ShardObjectHelper.newHashKeyRange("800", ShardObjectHelper.MAX_HASH_KEY)));
} }
} }

View file

@ -26,6 +26,7 @@ import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.core.async.SdkPublisher; import software.amazon.awssdk.core.async.SdkPublisher;
import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.services.kinesis.KinesisAsyncClient; import software.amazon.awssdk.services.kinesis.KinesisAsyncClient;
import software.amazon.awssdk.services.kinesis.model.ChildShard;
import software.amazon.awssdk.services.kinesis.model.Record; import software.amazon.awssdk.services.kinesis.model.Record;
import software.amazon.awssdk.services.kinesis.model.ResourceNotFoundException; import software.amazon.awssdk.services.kinesis.model.ResourceNotFoundException;
import software.amazon.awssdk.services.kinesis.model.ShardIteratorType; import software.amazon.awssdk.services.kinesis.model.ShardIteratorType;
@ -35,6 +36,7 @@ import software.amazon.awssdk.services.kinesis.model.SubscribeToShardEventStream
import software.amazon.awssdk.services.kinesis.model.SubscribeToShardRequest; import software.amazon.awssdk.services.kinesis.model.SubscribeToShardRequest;
import software.amazon.kinesis.common.InitialPositionInStream; import software.amazon.kinesis.common.InitialPositionInStream;
import software.amazon.kinesis.common.InitialPositionInStreamExtended; import software.amazon.kinesis.common.InitialPositionInStreamExtended;
import software.amazon.kinesis.leases.ShardObjectHelper;
import software.amazon.kinesis.lifecycle.ShardConsumerNotifyingSubscriber; import software.amazon.kinesis.lifecycle.ShardConsumerNotifyingSubscriber;
import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput; import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput;
import software.amazon.kinesis.retrieval.BatchUniqueIdentifier; import software.amazon.kinesis.retrieval.BatchUniqueIdentifier;
@ -47,6 +49,7 @@ import software.amazon.kinesis.utils.SubscribeToShardRequestMatcher;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.time.Instant; import java.time.Instant;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
@ -89,6 +92,7 @@ public class FanOutRecordsPublisherTest {
private static final String SHARD_ID = "Shard-001"; private static final String SHARD_ID = "Shard-001";
private static final String CONSUMER_ARN = "arn:consumer"; private static final String CONSUMER_ARN = "arn:consumer";
private static final String CONTINUATION_SEQUENCE_NUMBER = "continuationSequenceNumber";
@Mock @Mock
private KinesisAsyncClient kinesisClient; private KinesisAsyncClient kinesisClient;
@ -148,7 +152,12 @@ public class FanOutRecordsPublisherTest {
List<KinesisClientRecordMatcher> matchers = records.stream().map(KinesisClientRecordMatcher::new) List<KinesisClientRecordMatcher> matchers = records.stream().map(KinesisClientRecordMatcher::new)
.collect(Collectors.toList()); .collect(Collectors.toList());
batchEvent = SubscribeToShardEvent.builder().millisBehindLatest(100L).records(records).build(); batchEvent = SubscribeToShardEvent.builder()
.millisBehindLatest(100L)
.records(records)
.continuationSequenceNumber("test")
.childShards(Collections.emptyList())
.build();
captor.getValue().onNext(batchEvent); captor.getValue().onNext(batchEvent);
captor.getValue().onNext(batchEvent); captor.getValue().onNext(batchEvent);
@ -166,6 +175,73 @@ public class FanOutRecordsPublisherTest {
} }
@Test
public void InvalidEventTest() throws Exception {
FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN);
ArgumentCaptor<FanOutRecordsPublisher.RecordSubscription> captor = ArgumentCaptor
.forClass(FanOutRecordsPublisher.RecordSubscription.class);
ArgumentCaptor<FanOutRecordsPublisher.RecordFlow> flowCaptor = ArgumentCaptor
.forClass(FanOutRecordsPublisher.RecordFlow.class);
doNothing().when(publisher).subscribe(captor.capture());
source.start(ExtendedSequenceNumber.LATEST,
InitialPositionInStreamExtended.newInitialPosition(InitialPositionInStream.LATEST));
List<ProcessRecordsInput> receivedInput = new ArrayList<>();
source.subscribe(new ShardConsumerNotifyingSubscriber(new Subscriber<RecordsRetrieved>() {
Subscription subscription;
@Override public void onSubscribe(Subscription s) {
subscription = s;
subscription.request(1);
}
@Override public void onNext(RecordsRetrieved input) {
receivedInput.add(input.processRecordsInput());
subscription.request(1);
}
@Override public void onError(Throwable t) {
log.error("Caught throwable in subscriber", t);
fail("Caught throwable in subscriber");
}
@Override public void onComplete() {
fail("OnComplete called when not expected");
}
}, source));
verify(kinesisClient).subscribeToShard(any(SubscribeToShardRequest.class), flowCaptor.capture());
flowCaptor.getValue().onEventStream(publisher);
captor.getValue().onSubscribe(subscription);
List<Record> records = Stream.of(1, 2, 3).map(this::makeRecord).collect(Collectors.toList());
List<KinesisClientRecordMatcher> matchers = records.stream().map(KinesisClientRecordMatcher::new)
.collect(Collectors.toList());
batchEvent = SubscribeToShardEvent.builder().millisBehindLatest(100L).records(records).continuationSequenceNumber(CONTINUATION_SEQUENCE_NUMBER).build();
SubscribeToShardEvent invalidEvent = SubscribeToShardEvent.builder().millisBehindLatest(100L).records(records).childShards(Collections.emptyList()).build();
captor.getValue().onNext(batchEvent);
captor.getValue().onNext(invalidEvent);
captor.getValue().onNext(batchEvent);
// When the second request failed with invalid event, it should stop sending requests and cancel the flow.
verify(subscription, times(2)).request(1);
assertThat(receivedInput.size(), equalTo(1));
receivedInput.stream().map(ProcessRecordsInput::records).forEach(clientRecordsList -> {
assertThat(clientRecordsList.size(), equalTo(matchers.size()));
for (int i = 0; i < clientRecordsList.size(); ++i) {
assertThat(clientRecordsList.get(i), matchers.get(i));
}
});
}
@Test @Test
public void testIfAllEventsReceivedWhenNoTasksRejectedByExecutor() throws Exception { public void testIfAllEventsReceivedWhenNoTasksRejectedByExecutor() throws Exception {
FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN); FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN);
@ -225,7 +301,9 @@ public class FanOutRecordsPublisherTest {
SubscribeToShardEvent.builder() SubscribeToShardEvent.builder()
.millisBehindLatest(100L) .millisBehindLatest(100L)
.continuationSequenceNumber(contSeqNum) .continuationSequenceNumber(contSeqNum)
.records(records).build()) .records(records)
.childShards(Collections.emptyList())
.build())
.forEach(batchEvent -> captor.getValue().onNext(batchEvent)); .forEach(batchEvent -> captor.getValue().onNext(batchEvent));
verify(subscription, times(4)).request(1); verify(subscription, times(4)).request(1);
@ -301,7 +379,9 @@ public class FanOutRecordsPublisherTest {
SubscribeToShardEvent.builder() SubscribeToShardEvent.builder()
.millisBehindLatest(100L) .millisBehindLatest(100L)
.continuationSequenceNumber(contSeqNum) .continuationSequenceNumber(contSeqNum)
.records(records).build()) .records(records)
.childShards(Collections.emptyList())
.build())
.forEach(batchEvent -> captor.getValue().onNext(batchEvent)); .forEach(batchEvent -> captor.getValue().onNext(batchEvent));
verify(subscription, times(2)).request(1); verify(subscription, times(2)).request(1);
@ -334,6 +414,7 @@ public class FanOutRecordsPublisherTest {
.millisBehindLatest(100L) .millisBehindLatest(100L)
.continuationSequenceNumber(contSeqNum + "") .continuationSequenceNumber(contSeqNum + "")
.records(records) .records(records)
.childShards(Collections.emptyList())
.build()); .build());
CountDownLatch servicePublisherTaskCompletionLatch = new CountDownLatch(2); CountDownLatch servicePublisherTaskCompletionLatch = new CountDownLatch(2);
@ -436,6 +517,7 @@ public class FanOutRecordsPublisherTest {
.millisBehindLatest(100L) .millisBehindLatest(100L)
.continuationSequenceNumber(contSeqNum + "") .continuationSequenceNumber(contSeqNum + "")
.records(records) .records(records)
.childShards(Collections.emptyList())
.build()); .build());
CountDownLatch servicePublisherTaskCompletionLatch = new CountDownLatch(2); CountDownLatch servicePublisherTaskCompletionLatch = new CountDownLatch(2);
@ -536,13 +618,30 @@ public class FanOutRecordsPublisherTest {
.millisBehindLatest(100L) .millisBehindLatest(100L)
.continuationSequenceNumber(contSeqNum + "") .continuationSequenceNumber(contSeqNum + "")
.records(records) .records(records)
.childShards(Collections.emptyList())
.build()); .build());
List<ChildShard> childShards = new ArrayList<>();
List<String> parentShards = new ArrayList<>();
parentShards.add(SHARD_ID);
ChildShard leftChild = ChildShard.builder()
.shardId("Shard-002")
.parentShards(parentShards)
.hashKeyRange(ShardObjectHelper.newHashKeyRange("0", "49"))
.build();
ChildShard rightChild = ChildShard.builder()
.shardId("Shard-003")
.parentShards(parentShards)
.hashKeyRange(ShardObjectHelper.newHashKeyRange("50", "99"))
.build();
childShards.add(leftChild);
childShards.add(rightChild);
Consumer<Integer> servicePublisherShardEndAction = contSeqNum -> captor.getValue().onNext( Consumer<Integer> servicePublisherShardEndAction = contSeqNum -> captor.getValue().onNext(
SubscribeToShardEvent.builder() SubscribeToShardEvent.builder()
.millisBehindLatest(100L) .millisBehindLatest(100L)
.continuationSequenceNumber(null) .continuationSequenceNumber(null)
.records(records) .records(records)
.childShards(childShards)
.build()); .build());
CountDownLatch servicePublisherTaskCompletionLatch = new CountDownLatch(2); CountDownLatch servicePublisherTaskCompletionLatch = new CountDownLatch(2);
@ -648,6 +747,7 @@ public class FanOutRecordsPublisherTest {
.millisBehindLatest(100L) .millisBehindLatest(100L)
.continuationSequenceNumber(contSeqNum + "") .continuationSequenceNumber(contSeqNum + "")
.records(records) .records(records)
.childShards(Collections.emptyList())
.build()); .build());
CountDownLatch servicePublisherTaskCompletionLatch = new CountDownLatch(2); CountDownLatch servicePublisherTaskCompletionLatch = new CountDownLatch(2);
@ -750,6 +850,7 @@ public class FanOutRecordsPublisherTest {
.millisBehindLatest(100L) .millisBehindLatest(100L)
.continuationSequenceNumber(contSeqNum + "") .continuationSequenceNumber(contSeqNum + "")
.records(records) .records(records)
.childShards(Collections.emptyList())
.build()); .build());
CountDownLatch servicePublisherTaskCompletionLatch = new CountDownLatch(2); CountDownLatch servicePublisherTaskCompletionLatch = new CountDownLatch(2);
@ -842,6 +943,7 @@ public class FanOutRecordsPublisherTest {
.millisBehindLatest(100L) .millisBehindLatest(100L)
.continuationSequenceNumber(contSeqNum + "") .continuationSequenceNumber(contSeqNum + "")
.records(records) .records(records)
.childShards(Collections.emptyList())
.build()); .build());
CountDownLatch servicePublisherTaskCompletionLatch = new CountDownLatch(1); CountDownLatch servicePublisherTaskCompletionLatch = new CountDownLatch(1);
@ -1004,7 +1106,12 @@ public class FanOutRecordsPublisherTest {
List<KinesisClientRecordMatcher> matchers = records.stream().map(KinesisClientRecordMatcher::new) List<KinesisClientRecordMatcher> matchers = records.stream().map(KinesisClientRecordMatcher::new)
.collect(Collectors.toList()); .collect(Collectors.toList());
batchEvent = SubscribeToShardEvent.builder().millisBehindLatest(100L).records(records).build(); batchEvent = SubscribeToShardEvent.builder()
.millisBehindLatest(100L)
.records(records)
.continuationSequenceNumber(CONTINUATION_SEQUENCE_NUMBER)
.childShards(Collections.emptyList())
.build();
captor.getValue().onNext(batchEvent); captor.getValue().onNext(batchEvent);
captor.getValue().onNext(batchEvent); captor.getValue().onNext(batchEvent);
@ -1098,7 +1205,7 @@ public class FanOutRecordsPublisherTest {
.collect(Collectors.toList()); .collect(Collectors.toList());
batchEvent = SubscribeToShardEvent.builder().millisBehindLatest(100L).records(records) batchEvent = SubscribeToShardEvent.builder().millisBehindLatest(100L).records(records)
.continuationSequenceNumber("3").build(); .continuationSequenceNumber("3").childShards(Collections.emptyList()).build();
captor.getValue().onNext(batchEvent); captor.getValue().onNext(batchEvent);
captor.getValue().onComplete(); captor.getValue().onComplete();
@ -1126,7 +1233,7 @@ public class FanOutRecordsPublisherTest {
.collect(Collectors.toList()); .collect(Collectors.toList());
batchEvent = SubscribeToShardEvent.builder().millisBehindLatest(100L).records(nextRecords) batchEvent = SubscribeToShardEvent.builder().millisBehindLatest(100L).records(nextRecords)
.continuationSequenceNumber("6").build(); .continuationSequenceNumber("6").childShards(Collections.emptyList()).build();
nextSubscribeCaptor.getValue().onNext(batchEvent); nextSubscribeCaptor.getValue().onNext(batchEvent);
verify(subscription, times(4)).request(1); verify(subscription, times(4)).request(1);

View file

@ -30,6 +30,7 @@ import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.Date; import java.util.Date;
@ -53,6 +54,7 @@ import org.mockito.runners.MockitoJUnitRunner;
import software.amazon.awssdk.core.exception.SdkException; import software.amazon.awssdk.core.exception.SdkException;
import software.amazon.awssdk.services.kinesis.KinesisAsyncClient; import software.amazon.awssdk.services.kinesis.KinesisAsyncClient;
import software.amazon.awssdk.services.kinesis.model.ChildShard;
import software.amazon.awssdk.services.kinesis.model.GetRecordsRequest; import software.amazon.awssdk.services.kinesis.model.GetRecordsRequest;
import software.amazon.awssdk.services.kinesis.model.GetRecordsResponse; import software.amazon.awssdk.services.kinesis.model.GetRecordsResponse;
import software.amazon.awssdk.services.kinesis.model.GetShardIteratorRequest; import software.amazon.awssdk.services.kinesis.model.GetShardIteratorRequest;
@ -65,6 +67,7 @@ import software.amazon.kinesis.checkpoint.SentinelCheckpoint;
import software.amazon.kinesis.common.InitialPositionInStream; import software.amazon.kinesis.common.InitialPositionInStream;
import software.amazon.kinesis.common.InitialPositionInStreamExtended; import software.amazon.kinesis.common.InitialPositionInStreamExtended;
import software.amazon.kinesis.exceptions.KinesisClientLibException; import software.amazon.kinesis.exceptions.KinesisClientLibException;
import software.amazon.kinesis.leases.ShardObjectHelper;
import software.amazon.kinesis.metrics.MetricsFactory; import software.amazon.kinesis.metrics.MetricsFactory;
import software.amazon.kinesis.metrics.NullMetricsFactory; import software.amazon.kinesis.metrics.NullMetricsFactory;
import software.amazon.kinesis.processor.Checkpointer; import software.amazon.kinesis.processor.Checkpointer;
@ -330,8 +333,31 @@ public class KinesisDataFetcherTest {
private CompletableFuture<GetRecordsResponse> makeGetRecordsResponse(String nextIterator, List<Record> records) private CompletableFuture<GetRecordsResponse> makeGetRecordsResponse(String nextIterator, List<Record> records)
throws InterruptedException, ExecutionException { throws InterruptedException, ExecutionException {
List<ChildShard> childShards = new ArrayList<>();
if(nextIterator == null) {
childShards = createChildShards();
}
return CompletableFuture.completedFuture(GetRecordsResponse.builder().nextShardIterator(nextIterator) return CompletableFuture.completedFuture(GetRecordsResponse.builder().nextShardIterator(nextIterator)
.records(CollectionUtils.isNullOrEmpty(records) ? Collections.emptyList() : records).build()); .records(CollectionUtils.isNullOrEmpty(records) ? Collections.emptyList() : records).childShards(childShards).build());
}
private List<ChildShard> createChildShards() {
List<ChildShard> childShards = new ArrayList<>();
List<String> parentShards = new ArrayList<>();
parentShards.add(SHARD_ID);
ChildShard leftChild = ChildShard.builder()
.shardId("Shard-2")
.parentShards(parentShards)
.hashKeyRange(ShardObjectHelper.newHashKeyRange("0", "49"))
.build();
ChildShard rightChild = ChildShard.builder()
.shardId("Shard-3")
.parentShards(parentShards)
.hashKeyRange(ShardObjectHelper.newHashKeyRange("50", "99"))
.build();
childShards.add(leftChild);
childShards.add(rightChild);
return childShards;
} }
@Test @Test
@ -342,6 +368,7 @@ public class KinesisDataFetcherTest {
final String initialIterator = "InitialIterator"; final String initialIterator = "InitialIterator";
final String nextIterator1 = "NextIteratorOne"; final String nextIterator1 = "NextIteratorOne";
final String nextIterator2 = "NextIteratorTwo"; final String nextIterator2 = "NextIteratorTwo";
final String nextIterator3 = "NextIteratorThree";
final CompletableFuture<GetRecordsResponse> nonAdvancingResult1 = makeGetRecordsResponse(initialIterator, null); final CompletableFuture<GetRecordsResponse> nonAdvancingResult1 = makeGetRecordsResponse(initialIterator, null);
final CompletableFuture<GetRecordsResponse> nonAdvancingResult2 = makeGetRecordsResponse(nextIterator1, null); final CompletableFuture<GetRecordsResponse> nonAdvancingResult2 = makeGetRecordsResponse(nextIterator1, null);
final CompletableFuture<GetRecordsResponse> finalNonAdvancingResult = makeGetRecordsResponse(nextIterator2, final CompletableFuture<GetRecordsResponse> finalNonAdvancingResult = makeGetRecordsResponse(nextIterator2,

View file

@ -22,6 +22,7 @@ import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any; import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyLong; import static org.mockito.Matchers.anyLong;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy; import static org.mockito.Mockito.spy;
@ -48,6 +49,7 @@ import org.mockito.runners.MockitoJUnitRunner;
import org.mockito.stubbing.Answer; import org.mockito.stubbing.Answer;
import software.amazon.awssdk.core.SdkBytes; import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.kinesis.model.GetRecordsRequest;
import software.amazon.kinesis.common.InitialPositionInStreamExtended; import software.amazon.kinesis.common.InitialPositionInStreamExtended;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
@ -86,6 +88,7 @@ public class PrefetchRecordsPublisherIntegrationTest {
private String operation = "ProcessTask"; private String operation = "ProcessTask";
private String streamName = "streamName"; private String streamName = "streamName";
private String shardId = "shardId-000000000000"; private String shardId = "shardId-000000000000";
private String nextShardIterator = "testNextShardIterator";
@Mock @Mock
private KinesisAsyncClient kinesisClient; private KinesisAsyncClient kinesisClient;
@ -249,7 +252,7 @@ public class PrefetchRecordsPublisherIntegrationTest {
@Override @Override
public DataFetcherResult getRecords() { public DataFetcherResult getRecords() {
GetRecordsResponse getRecordsResult = GetRecordsResponse.builder().records(new ArrayList<>(records)).millisBehindLatest(1000L).build(); GetRecordsResponse getRecordsResult = GetRecordsResponse.builder().records(new ArrayList<>(records)).nextShardIterator(nextShardIterator).millisBehindLatest(1000L).build();
return new AdvancingResult(getRecordsResult); return new AdvancingResult(getRecordsResult);
} }

View file

@ -21,6 +21,7 @@ import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail; import static org.junit.Assert.fail;
import static org.mockito.Matchers.any; import static org.mockito.Matchers.any;
@ -71,11 +72,13 @@ import io.reactivex.Flowable;
import io.reactivex.schedulers.Schedulers; import io.reactivex.schedulers.Schedulers;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import software.amazon.awssdk.core.SdkBytes; import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.kinesis.model.ChildShard;
import software.amazon.awssdk.services.kinesis.model.ExpiredIteratorException; import software.amazon.awssdk.services.kinesis.model.ExpiredIteratorException;
import software.amazon.awssdk.services.kinesis.model.GetRecordsResponse; import software.amazon.awssdk.services.kinesis.model.GetRecordsResponse;
import software.amazon.awssdk.services.kinesis.model.Record; import software.amazon.awssdk.services.kinesis.model.Record;
import software.amazon.kinesis.common.InitialPositionInStreamExtended; import software.amazon.kinesis.common.InitialPositionInStreamExtended;
import software.amazon.kinesis.common.RequestDetails; import software.amazon.kinesis.common.RequestDetails;
import software.amazon.kinesis.leases.ShardObjectHelper;
import software.amazon.kinesis.common.StreamIdentifier; import software.amazon.kinesis.common.StreamIdentifier;
import software.amazon.kinesis.lifecycle.ShardConsumerNotifyingSubscriber; import software.amazon.kinesis.lifecycle.ShardConsumerNotifyingSubscriber;
import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput; import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput;
@ -99,6 +102,7 @@ public class PrefetchRecordsPublisherTest {
private static final int MAX_SIZE = 5; private static final int MAX_SIZE = 5;
private static final int MAX_RECORDS_COUNT = 15000; private static final int MAX_RECORDS_COUNT = 15000;
private static final long IDLE_MILLIS_BETWEEN_CALLS = 0L; private static final long IDLE_MILLIS_BETWEEN_CALLS = 0L;
private static final String NEXT_SHARD_ITERATOR = "testNextShardIterator";
@Mock @Mock
private GetRecordsRetrievalStrategy getRecordsRetrievalStrategy; private GetRecordsRetrievalStrategy getRecordsRetrievalStrategy;
@ -136,7 +140,7 @@ public class PrefetchRecordsPublisherTest {
"shardId"); "shardId");
spyQueue = spy(getRecordsCache.getPublisherSession().prefetchRecordsQueue()); spyQueue = spy(getRecordsCache.getPublisherSession().prefetchRecordsQueue());
records = spy(new ArrayList<>()); records = spy(new ArrayList<>());
getRecordsResponse = GetRecordsResponse.builder().records(records).build(); getRecordsResponse = GetRecordsResponse.builder().records(records).nextShardIterator(NEXT_SHARD_ITERATOR).childShards(new ArrayList<>()).build();
when(getRecordsRetrievalStrategy.getRecords(eq(MAX_RECORDS_PER_CALL))).thenReturn(getRecordsResponse); when(getRecordsRetrievalStrategy.getRecords(eq(MAX_RECORDS_PER_CALL))).thenReturn(getRecordsResponse);
} }
@ -155,11 +159,67 @@ public class PrefetchRecordsPublisherTest {
.processRecordsInput(); .processRecordsInput();
assertEquals(expectedRecords, result.records()); assertEquals(expectedRecords, result.records());
assertEquals(new ArrayList<>(), result.childShards());
verify(executorService).execute(any()); verify(executorService).execute(any());
verify(getRecordsRetrievalStrategy, atLeast(1)).getRecords(eq(MAX_RECORDS_PER_CALL)); verify(getRecordsRetrievalStrategy, atLeast(1)).getRecords(eq(MAX_RECORDS_PER_CALL));
} }
@Test
public void testGetRecordsWithInvalidResponse() {
record = Record.builder().data(createByteBufferWithSize(SIZE_512_KB)).build();
when(records.size()).thenReturn(1000);
GetRecordsResponse response = GetRecordsResponse.builder().records(records).build();
when(getRecordsRetrievalStrategy.getRecords(eq(MAX_RECORDS_PER_CALL))).thenReturn(response);
when(dataFetcher.isShardEndReached()).thenReturn(false);
getRecordsCache.start(sequenceNumber, initialPosition);
try {
ProcessRecordsInput result = blockUntilRecordsAvailable(() -> evictPublishedEvent(getRecordsCache, "shardId"), 1000L)
.processRecordsInput();
} catch (Exception e) {
assertEquals("No records found", e.getMessage());
}
}
@Test
public void testGetRecordsWithShardEnd() {
records = new ArrayList<>();
final List<KinesisClientRecord> expectedRecords = new ArrayList<>();
List<ChildShard> childShards = new ArrayList<>();
List<String> parentShards = new ArrayList<>();
parentShards.add("shardId");
ChildShard leftChild = ChildShard.builder()
.shardId("shardId-000000000001")
.parentShards(parentShards)
.hashKeyRange(ShardObjectHelper.newHashKeyRange("0", "49"))
.build();
ChildShard rightChild = ChildShard.builder()
.shardId("shardId-000000000002")
.parentShards(parentShards)
.hashKeyRange(ShardObjectHelper.newHashKeyRange("50", "99"))
.build();
childShards.add(leftChild);
childShards.add(rightChild);
GetRecordsResponse response = GetRecordsResponse.builder().records(records).childShards(childShards).build();
when(getRecordsRetrievalStrategy.getRecords(eq(MAX_RECORDS_PER_CALL))).thenReturn(response);
when(dataFetcher.isShardEndReached()).thenReturn(true);
getRecordsCache.start(sequenceNumber, initialPosition);
ProcessRecordsInput result = blockUntilRecordsAvailable(() -> evictPublishedEvent(getRecordsCache, "shardId"), 1000L)
.processRecordsInput();
assertEquals(expectedRecords, result.records());
assertEquals(childShards, result.childShards());
assertTrue(result.isAtShardEnd());
}
// TODO: Broken test // TODO: Broken test
@Test @Test
@Ignore @Ignore
@ -270,7 +330,7 @@ public class PrefetchRecordsPublisherTest {
@Test @Test
public void testRetryableRetrievalExceptionContinues() { public void testRetryableRetrievalExceptionContinues() {
GetRecordsResponse response = GetRecordsResponse.builder().millisBehindLatest(100L).records(Collections.emptyList()).build(); GetRecordsResponse response = GetRecordsResponse.builder().millisBehindLatest(100L).records(Collections.emptyList()).nextShardIterator(NEXT_SHARD_ITERATOR).build();
when(getRecordsRetrievalStrategy.getRecords(anyInt())).thenThrow(new RetryableRetrievalException("Timeout", new TimeoutException("Timeout"))).thenReturn(response); when(getRecordsRetrievalStrategy.getRecords(anyInt())).thenThrow(new RetryableRetrievalException("Timeout", new TimeoutException("Timeout"))).thenReturn(response);
getRecordsCache.start(sequenceNumber, initialPosition); getRecordsCache.start(sequenceNumber, initialPosition);
@ -293,7 +353,7 @@ public class PrefetchRecordsPublisherTest {
when(getRecordsRetrievalStrategy.getRecords(anyInt())).thenAnswer( i -> GetRecordsResponse.builder().records( when(getRecordsRetrievalStrategy.getRecords(anyInt())).thenAnswer( i -> GetRecordsResponse.builder().records(
Record.builder().data(SdkBytes.fromByteArray(new byte[] { 1, 2, 3 })).sequenceNumber(++sequenceNumberInResponse[0] + "").build()) Record.builder().data(SdkBytes.fromByteArray(new byte[] { 1, 2, 3 })).sequenceNumber(++sequenceNumberInResponse[0] + "").build())
.build()); .nextShardIterator(NEXT_SHARD_ITERATOR).build());
getRecordsCache.start(sequenceNumber, initialPosition); getRecordsCache.start(sequenceNumber, initialPosition);
@ -384,7 +444,7 @@ public class PrefetchRecordsPublisherTest {
// to the subscriber. // to the subscriber.
GetRecordsResponse response = GetRecordsResponse.builder().records( GetRecordsResponse response = GetRecordsResponse.builder().records(
Record.builder().data(SdkBytes.fromByteArray(new byte[] { 1, 2, 3 })).sequenceNumber("123").build()) Record.builder().data(SdkBytes.fromByteArray(new byte[] { 1, 2, 3 })).sequenceNumber("123").build())
.build(); .nextShardIterator(NEXT_SHARD_ITERATOR).build();
when(getRecordsRetrievalStrategy.getRecords(anyInt())).thenReturn(response); when(getRecordsRetrievalStrategy.getRecords(anyInt())).thenReturn(response);
getRecordsCache.start(sequenceNumber, initialPosition); getRecordsCache.start(sequenceNumber, initialPosition);

View file

@ -45,6 +45,7 @@ public class ProcessRecordsInputMatcher extends TypeSafeDiagnosingMatcher<Proces
matchers.put("millisBehindLatest", matchers.put("millisBehindLatest",
nullOrEquals(template.millisBehindLatest(), ProcessRecordsInput::millisBehindLatest)); nullOrEquals(template.millisBehindLatest(), ProcessRecordsInput::millisBehindLatest));
matchers.put("records", nullOrEquals(template.records(), ProcessRecordsInput::records)); matchers.put("records", nullOrEquals(template.records(), ProcessRecordsInput::records));
matchers.put("childShards", nullOrEquals(template.childShards(), ProcessRecordsInput::childShards));
this.template = template; this.template = template;
} }