From 2f9ce0ce4f7c83025f8e33376ef1b326789f1a83 Mon Sep 17 00:00:00 2001 From: Chunxue Yang Date: Thu, 5 Mar 2020 16:44:37 -0800 Subject: [PATCH] ShardEnd Shard Sync --- .../leases/HierarchicalShardSyncer.java | 13 ++ .../kinesis/lifecycle/ConsumerStates.java | 3 +- .../kinesis/lifecycle/ShardConsumer.java | 7 +- .../kinesis/lifecycle/ShutdownTask.java | 91 +++++++------ .../lifecycle/events/ProcessRecordsInput.java | 6 + .../fanout/FanOutRecordsPublisher.java | 38 ++++-- .../polling/BlockingRecordsPublisher.java | 1 + .../retrieval/polling/KinesisDataFetcher.java | 22 +++- .../polling/PrefetchRecordsPublisher.java | 3 +- .../kinesis/lifecycle/ConsumerStatesTest.java | 29 ++++- .../kinesis/lifecycle/ShutdownTaskTest.java | 122 +++++++----------- .../fanout/FanOutRecordsPublisherTest.java | 119 ++++++++++++++++- .../polling/KinesisDataFetcherTest.java | 29 ++++- ...efetchRecordsPublisherIntegrationTest.java | 5 +- .../polling/PrefetchRecordsPublisherTest.java | 68 +++++++++- .../utils/ProcessRecordsInputMatcher.java | 1 + 16 files changed, 405 insertions(+), 152 deletions(-) diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/leases/HierarchicalShardSyncer.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/leases/HierarchicalShardSyncer.java index ecd64952..34c17bdf 100644 --- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/leases/HierarchicalShardSyncer.java +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/leases/HierarchicalShardSyncer.java @@ -39,6 +39,7 @@ import org.apache.commons.lang3.StringUtils; import lombok.NonNull; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; +import software.amazon.awssdk.services.kinesis.model.ChildShard; import software.amazon.awssdk.services.kinesis.model.Shard; import software.amazon.awssdk.services.kinesis.model.ShardFilter; import software.amazon.awssdk.services.kinesis.model.ShardFilterType; @@ -164,6 +165,18 @@ public class 
HierarchicalShardSyncer { } } + public synchronized Lease createLeaseForChildShard(ChildShard childShard) throws InvalidStateException { + Lease newLease = new Lease(); + newLease.leaseKey(childShard.shardId()); + if (!CollectionUtils.isNullOrEmpty(childShard.parentShards())) { + newLease.parentShardIds(childShard.parentShards()); + } else { + throw new InvalidStateException("Unable to populate new lease for child shard " + childShard.shardId() + "because parent shards cannot be found."); + } + newLease.ownerSwitchesSinceCheckpoint(0L); + return newLease; + } + // CHECKSTYLE:ON CyclomaticComplexity /** Note: This method has package level access solely for testing purposes. diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ConsumerStates.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ConsumerStates.java index bb1788b2..58e31985 100644 --- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ConsumerStates.java +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ConsumerStates.java @@ -496,7 +496,8 @@ class ConsumerStates { argument.taskBackoffTimeMillis(), argument.recordsPublisher(), argument.hierarchicalShardSyncer(), - argument.metricsFactory()); + argument.metricsFactory(), + input == null ? 
null : input.childShards()); } @Override diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ShardConsumer.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ShardConsumer.java index e34f2ea4..b6e7c068 100644 --- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ShardConsumer.java +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ShardConsumer.java @@ -16,6 +16,7 @@ package software.amazon.kinesis.lifecycle; import java.time.Duration; import java.time.Instant; +import java.util.List; import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; @@ -32,6 +33,7 @@ import lombok.Getter; import lombok.NonNull; import lombok.experimental.Accessors; import lombok.extern.slf4j.Slf4j; +import software.amazon.awssdk.services.kinesis.model.ChildShard; import software.amazon.kinesis.annotations.KinesisClientInternalApi; import software.amazon.kinesis.exceptions.internal.BlockedOnParentShardException; import software.amazon.kinesis.leases.ShardInfo; @@ -86,6 +88,8 @@ public class ShardConsumer { private final ShardConsumerSubscriber subscriber; + private ProcessRecordsInput shardEndProcessRecordsInput; + @Deprecated public ShardConsumer(RecordsPublisher recordsPublisher, ExecutorService executorService, ShardInfo shardInfo, Optional logWarningForTaskAfterMillis, ShardConsumerArgument shardConsumerArgument, @@ -148,6 +152,7 @@ public class ShardConsumer { processData(input); if (taskOutcome == TaskOutcome.END_OF_SHARD) { markForShutdown(ShutdownReason.SHARD_END); + shardEndProcessRecordsInput = input; subscription.cancel(); return; } @@ -305,7 +310,7 @@ public class ShardConsumer { return true; } - executeTask(null); + executeTask(shardEndProcessRecordsInput); return false; } }, executorService); diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ShutdownTask.java 
b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ShutdownTask.java index 18e2be76..18a0af63 100644 --- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ShutdownTask.java +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ShutdownTask.java @@ -20,6 +20,7 @@ import lombok.NonNull; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; +import software.amazon.awssdk.services.kinesis.model.ChildShard; import software.amazon.awssdk.services.kinesis.model.Shard; import software.amazon.awssdk.utils.CollectionUtils; import software.amazon.kinesis.annotations.KinesisClientInternalApi; @@ -30,7 +31,11 @@ import software.amazon.kinesis.leases.LeaseCoordinator; import software.amazon.kinesis.leases.ShardDetector; import software.amazon.kinesis.leases.ShardInfo; import software.amazon.kinesis.leases.HierarchicalShardSyncer; +import software.amazon.kinesis.leases.exceptions.DependencyException; +import software.amazon.kinesis.leases.exceptions.InvalidStateException; +import software.amazon.kinesis.leases.exceptions.ProvisionedThroughputException; import software.amazon.kinesis.lifecycle.events.LeaseLostInput; +import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput; import software.amazon.kinesis.lifecycle.events.ShardEndedInput; import software.amazon.kinesis.metrics.MetricsFactory; import software.amazon.kinesis.metrics.MetricsScope; @@ -81,6 +86,8 @@ public class ShutdownTask implements ConsumerTask { private final TaskType taskType = TaskType.SHUTDOWN; + private final List childShards; + private static final Function shardInfoIdProvider = shardInfo -> shardInfo .streamIdentifierSerOpt().map(s -> s + ":" + shardInfo.shardId()).orElse(shardInfo.shardId()); /* @@ -99,67 +106,48 @@ public class ShutdownTask implements ConsumerTask { try { try { - ShutdownReason localReason = reason; - List latestShards = null; - /* - * Revalidate if 
the current shard is closed before shutting down the shard consumer with reason SHARD_END - * If current shard is not closed, shut down the shard consumer with reason LEASE_LOST that allows - * active workers to contend for the lease of this shard. - */ - if (localReason == ShutdownReason.SHARD_END) { - latestShards = shardDetector.listShards(); + log.debug("Invoking shutdown() for shard {}, concurrencyToken {}. Shutdown reason: {}", + shardInfoIdProvider.apply(shardInfo), shardInfo.concurrencyToken(), reason); - //If latestShards is empty, should also shutdown the ShardConsumer without checkpoint with SHARD_END - if (CollectionUtils.isNullOrEmpty(latestShards) || !isShardInContextParentOfAny(latestShards)) { - localReason = ShutdownReason.LEASE_LOST; - dropLease(); - log.info("Forcing the lease to be lost before shutting down the consumer for Shard: " + shardInfoIdProvider.apply(shardInfo)); + final long startTime = System.currentTimeMillis(); + if (reason == ShutdownReason.SHARD_END) { + // Create new lease for the child shards if they don't exist. + if (!CollectionUtils.isNullOrEmpty(childShards)) { + createLeasesForChildShardsIfNotExist(); } - } - // If we reached end of the shard, set sequence number to SHARD_END. - if (localReason == ShutdownReason.SHARD_END) { recordProcessorCheckpointer .sequenceNumberAtShardEnd(recordProcessorCheckpointer.largestPermittedCheckpointValue()); recordProcessorCheckpointer.largestPermittedCheckpointValue(ExtendedSequenceNumber.SHARD_END); - } - - log.debug("Invoking shutdown() for shard {}, concurrencyToken {}. 
Shutdown reason: {}", shardInfoIdProvider.apply(shardInfo), shardInfo.concurrencyToken(), localReason); - final ShutdownInput shutdownInput = ShutdownInput.builder().shutdownReason(localReason) - .checkpointer(recordProcessorCheckpointer).build(); - final long startTime = System.currentTimeMillis(); - try { - if (localReason == ShutdownReason.SHARD_END) { + // Call the shardRecordProcessor to checkpoint with SHARD_END sequence number. + // The shardEnded method is implemented by the customer. We should validate that the SHARD_END checkpointing is successful after calling shardEnded. + try { shardRecordProcessor.shardEnded(ShardEndedInput.builder().checkpointer(recordProcessorCheckpointer).build()); - ExtendedSequenceNumber lastCheckpointValue = recordProcessorCheckpointer.lastCheckpointValue(); + final ExtendedSequenceNumber lastCheckpointValue = recordProcessorCheckpointer.lastCheckpointValue(); if (lastCheckpointValue == null || !lastCheckpointValue.equals(ExtendedSequenceNumber.SHARD_END)) { throw new IllegalArgumentException("Application didn't checkpoint at end of shard " - + shardInfoIdProvider.apply(shardInfo) + ". Application must checkpoint upon shard end. 
" + + "See ShardRecordProcessor.shardEnded javadocs for more information."); } - } else { - shardRecordProcessor.leaseLost(LeaseLostInput.builder().build()); + } catch (Exception e) { + applicationException = true; + throw e; + } finally { + MetricsUtil.addLatency(scope, RECORD_PROCESSOR_SHUTDOWN_METRIC, startTime, MetricsLevel.SUMMARY); + } + } else { + try { + shardRecordProcessor.leaseLost(LeaseLostInput.builder().build()); + } catch (Exception e) { + applicationException = true; + throw e; } - log.debug("Shutting down retrieval strategy."); - recordsPublisher.shutdown(); - log.debug("Record processor completed shutdown() for shard {}", shardInfoIdProvider.apply(shardInfo)); - } catch (Exception e) { - applicationException = true; - throw e; - } finally { - MetricsUtil.addLatency(scope, RECORD_PROCESSOR_SHUTDOWN_METRIC, startTime, MetricsLevel.SUMMARY); } - if (localReason == ShutdownReason.SHARD_END) { - log.debug("Looking for child shards of shard {}", shardInfoIdProvider.apply(shardInfo)); - // create leases for the child shards - hierarchicalShardSyncer.checkAndCreateLeaseForNewShards(shardDetector, leaseCoordinator.leaseRefresher(), - initialPositionInStream, latestShards, cleanupLeasesOfCompletedShards, ignoreUnexpectedChildShards, scope, garbageCollectLeases, - isLeaseTableEmpty); - log.debug("Finished checking for child shards of shard {}", shardInfoIdProvider.apply(shardInfo)); - } + log.debug("Shutting down retrieval strategy."); + recordsPublisher.shutdown(); + log.debug("Record processor completed shutdown() for shard {}", shardInfoIdProvider.apply(shardInfo)); return new TaskResult(null); } catch (Exception e) { @@ -181,7 +169,16 @@ public class ShutdownTask implements ConsumerTask { } return new TaskResult(exception); + } + private void createLeasesForChildShardsIfNotExist() + throws DependencyException, InvalidStateException, ProvisionedThroughputException { + for(ChildShard childShard : childShards) { + 
if(leaseCoordinator.getCurrentlyHeldLease(shardInfo.shardId()) == null) { + final Lease leaseToCreate = hierarchicalShardSyncer.createLeaseForChildShard(childShard); + leaseCoordinator.leaseRefresher().createLeaseIfNotExists(leaseToCreate); + } + } } /* diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/events/ProcessRecordsInput.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/events/ProcessRecordsInput.java index b7dd4e05..3bfcd514 100644 --- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/events/ProcessRecordsInput.java +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/events/ProcessRecordsInput.java @@ -23,6 +23,7 @@ import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.ToString; import lombok.experimental.Accessors; +import software.amazon.awssdk.services.kinesis.model.ChildShard; import software.amazon.kinesis.processor.ShardRecordProcessor; import software.amazon.kinesis.processor.RecordProcessorCheckpointer; import software.amazon.kinesis.retrieval.KinesisClientRecord; @@ -66,6 +67,11 @@ public class ProcessRecordsInput { * This value does not include the {@link #timeSpentInCache()}. */ private Long millisBehindLatest; + /** + * A list of child shards if the current GetRecords request reached the shard end. + * If not at the shard end, this should be an empty list. 
+ */ + private List childShards; /** * How long the records spent waiting to be dispatched to the {@link ShardRecordProcessor} diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisher.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisher.java index c24a3803..38075890 100644 --- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisher.java +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisher.java @@ -33,11 +33,13 @@ import software.amazon.awssdk.services.kinesis.model.SubscribeToShardEventStream import software.amazon.awssdk.services.kinesis.model.SubscribeToShardRequest; import software.amazon.awssdk.services.kinesis.model.SubscribeToShardResponse; import software.amazon.awssdk.services.kinesis.model.SubscribeToShardResponseHandler; +import software.amazon.awssdk.utils.CollectionUtils; import software.amazon.awssdk.utils.Either; import software.amazon.kinesis.annotations.KinesisClientInternalApi; import software.amazon.kinesis.common.InitialPositionInStreamExtended; import software.amazon.kinesis.common.KinesisRequestsBuilder; import software.amazon.kinesis.common.RequestDetails; +import software.amazon.kinesis.leases.exceptions.InvalidStateException; import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput; import software.amazon.kinesis.retrieval.BatchUniqueIdentifier; import software.amazon.kinesis.retrieval.IteratorBuilder; @@ -398,7 +400,7 @@ public class FanOutRecordsPublisher implements RecordsPublisher { // The ack received for this onNext event will be ignored by the publisher as the global flow object should // be either null or renewed when the ack's flow identifier is evaluated. 
FanoutRecordsRetrieved response = new FanoutRecordsRetrieved( - ProcessRecordsInput.builder().records(Collections.emptyList()).isAtShardEnd(true).build(), null, + ProcessRecordsInput.builder().records(Collections.emptyList()).isAtShardEnd(true).childShards(Collections.emptyList()).build(), null, triggeringFlow != null ? triggeringFlow.getSubscribeToShardId() : shardId + "-no-flow-found"); subscriber.onNext(response); subscriber.onComplete(); @@ -477,15 +479,28 @@ public class FanOutRecordsPublisher implements RecordsPublisher { return; } - List records = recordBatchEvent.records().stream().map(KinesisClientRecord::fromRecord) - .collect(Collectors.toList()); - ProcessRecordsInput input = ProcessRecordsInput.builder().cacheEntryTime(Instant.now()) - .millisBehindLatest(recordBatchEvent.millisBehindLatest()) - .isAtShardEnd(recordBatchEvent.continuationSequenceNumber() == null).records(records).build(); - FanoutRecordsRetrieved recordsRetrieved = new FanoutRecordsRetrieved(input, - recordBatchEvent.continuationSequenceNumber(), triggeringFlow.subscribeToShardId); - try { + // If recordBatchEvent is not valid event, RuntimeException will be thrown here and trigger the errorOccurred call. + // Since the triggeringFlow is active flow, it will then trigger the handleFlowError call. + // Since the exception is not ResourceNotFoundException, it will trigger onError in the ShardConsumerSubscriber. + // The ShardConsumerSubscriber will finally cancel the subscription. + if (!isValidEvent(recordBatchEvent)) { + throw new InvalidStateException("RecordBatchEvent for flow " + triggeringFlow.toString() + " is invalid." + + " event.continuationSequenceNumber: " + recordBatchEvent.continuationSequenceNumber() + + ". 
event.childShards: " + recordBatchEvent.childShards()); + } + + List records = recordBatchEvent.records().stream().map(KinesisClientRecord::fromRecord) + .collect(Collectors.toList()); + ProcessRecordsInput input = ProcessRecordsInput.builder() + .cacheEntryTime(Instant.now()) + .millisBehindLatest(recordBatchEvent.millisBehindLatest()) + .isAtShardEnd(recordBatchEvent.continuationSequenceNumber() == null) + .records(records) + .childShards(recordBatchEvent.childShards()) + .build(); + FanoutRecordsRetrieved recordsRetrieved = new FanoutRecordsRetrieved(input, + recordBatchEvent.continuationSequenceNumber(), triggeringFlow.subscribeToShardId); bufferCurrentEventAndScheduleIfRequired(recordsRetrieved, triggeringFlow); } catch (Throwable t) { log.warn("{}: Unable to buffer or schedule onNext for subscriber. Failing publisher." + @@ -495,6 +510,11 @@ public class FanOutRecordsPublisher implements RecordsPublisher { } } + private boolean isValidEvent(SubscribeToShardEvent event) { + return event.continuationSequenceNumber() == null ? 
!CollectionUtils.isNullOrEmpty(event.childShards()) + : event.childShards() != null && event.childShards().isEmpty(); + } + private void updateAvailableQueueSpaceAndRequestUpstream(RecordFlow triggeringFlow) { if (availableQueueSpace <= 0) { log.debug( diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/polling/BlockingRecordsPublisher.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/polling/BlockingRecordsPublisher.java index 1e6462f5..33be11d4 100644 --- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/polling/BlockingRecordsPublisher.java +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/polling/BlockingRecordsPublisher.java @@ -67,6 +67,7 @@ public class BlockingRecordsPublisher implements RecordsPublisher { return ProcessRecordsInput.builder() .records(records) .millisBehindLatest(getRecordsResult.millisBehindLatest()) + .childShards(getRecordsResult.childShards()) .build(); } diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/polling/KinesisDataFetcher.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/polling/KinesisDataFetcher.java index 1605f941..dc25b20c 100644 --- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/polling/KinesisDataFetcher.java +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/polling/KinesisDataFetcher.java @@ -36,6 +36,7 @@ import software.amazon.awssdk.services.kinesis.model.GetShardIteratorRequest; import software.amazon.awssdk.services.kinesis.model.GetShardIteratorResponse; import software.amazon.awssdk.services.kinesis.model.KinesisException; import software.amazon.awssdk.services.kinesis.model.ResourceNotFoundException; +import software.amazon.awssdk.utils.CollectionUtils; import software.amazon.kinesis.annotations.KinesisClientInternalApi; import software.amazon.kinesis.common.FutureUtils; import 
software.amazon.kinesis.common.InitialPositionInStreamExtended; @@ -120,7 +121,13 @@ public class KinesisDataFetcher { if (nextIterator != null) { try { - return new AdvancingResult(getRecords(nextIterator)); + GetRecordsResponse getRecordsResponse = getRecords(nextIterator); + while (!isValidResponse(getRecordsResponse)) { + log.error("{} : GetRecords response is not valid. nextShardIterator: {}. childShards: {}. Will retry GetRecords with the same nextIterator.", + shardId, getRecordsResponse.nextShardIterator(), getRecordsResponse.childShards()); + getRecordsResponse = getRecords(nextIterator); + } + return new AdvancingResult(getRecordsResponse); } catch (ResourceNotFoundException e) { log.info("Caught ResourceNotFoundException when fetching records for shard {}", streamAndShardId); return TERMINAL_RESULT; @@ -133,8 +140,12 @@ public class KinesisDataFetcher { final DataFetcherResult TERMINAL_RESULT = new DataFetcherResult() { @Override public GetRecordsResponse getResult() { - return GetRecordsResponse.builder().millisBehindLatest(null).records(Collections.emptyList()) - .nextShardIterator(null).build(); + return GetRecordsResponse.builder() + .millisBehindLatest(null) + .records(Collections.emptyList()) + .nextShardIterator(null) + .childShards(Collections.emptyList()) + .build(); } @Override @@ -177,6 +188,11 @@ public class KinesisDataFetcher { } } + private boolean isValidResponse(GetRecordsResponse response) { + return response.nextShardIterator() == null ? !CollectionUtils.isNullOrEmpty(response.childShards()) + : response.childShards() != null && response.childShards().isEmpty(); + } + /** * Initializes this KinesisDataFetcher's iterator based on the checkpointed sequence number. * @param initialCheckpoint Current checkpoint sequence number for this shard. 
diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/polling/PrefetchRecordsPublisher.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/polling/PrefetchRecordsPublisher.java index ba8aa117..6e172f08 100644 --- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/polling/PrefetchRecordsPublisher.java +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/polling/PrefetchRecordsPublisher.java @@ -162,7 +162,7 @@ public class PrefetchRecordsPublisher implements RecordsPublisher { } else { log.info( "{}: No record batch found while evicting from the prefetch queue. This indicates the prefetch buffer" - + "was reset.", streamAndShardId); + + " was reset.", streamAndShardId); } return result; } @@ -437,6 +437,7 @@ public class PrefetchRecordsPublisher implements RecordsPublisher { .millisBehindLatest(getRecordsResult.millisBehindLatest()) .cacheEntryTime(lastSuccessfulCall) .isAtShardEnd(getRecordsRetrievalStrategy.getDataFetcher().isShardEndReached()) + .childShards(getRecordsResult.childShards()) .build(); PrefetchRecordsRetrieved recordsRetrieved = new PrefetchRecordsRetrieved(processRecordsInput, diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ConsumerStatesTest.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ConsumerStatesTest.java index 5d8e302f..b4164f90 100644 --- a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ConsumerStatesTest.java +++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ConsumerStatesTest.java @@ -25,6 +25,8 @@ import static org.mockito.Mockito.when; import static software.amazon.kinesis.lifecycle.ConsumerStates.ShardConsumerState; import java.lang.reflect.Field; +import java.util.ArrayList; +import java.util.List; import java.util.Optional; import java.util.concurrent.ExecutorService; @@ -40,6 +42,7 @@ import org.mockito.Mock; import 
org.mockito.runners.MockitoJUnitRunner; import software.amazon.awssdk.services.kinesis.KinesisAsyncClient; +import software.amazon.awssdk.services.kinesis.model.ChildShard; import software.amazon.kinesis.checkpoint.ShardRecordProcessorCheckpointer; import software.amazon.kinesis.common.InitialPositionInStream; import software.amazon.kinesis.common.InitialPositionInStreamExtended; @@ -49,6 +52,7 @@ import software.amazon.kinesis.leases.LeaseRefresher; import software.amazon.kinesis.leases.ShardDetector; import software.amazon.kinesis.leases.ShardInfo; import software.amazon.kinesis.leases.HierarchicalShardSyncer; +import software.amazon.kinesis.leases.ShardObjectHelper; import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput; import software.amazon.kinesis.metrics.MetricsFactory; import software.amazon.kinesis.processor.Checkpointer; @@ -57,6 +61,9 @@ import software.amazon.kinesis.processor.ShardRecordProcessor; import software.amazon.kinesis.retrieval.AggregatorUtil; import software.amazon.kinesis.retrieval.RecordsPublisher; +import javax.swing.*; +import javax.swing.text.AsyncBoxView; + @RunWith(MockitoJUnitRunner.class) public class ConsumerStatesTest { private static final String STREAM_NAME = "TestStream"; @@ -300,13 +307,27 @@ public class ConsumerStatesTest { } - // TODO: Fix this test - @Ignore @Test public void shuttingDownStateTest() { consumer.markForShutdown(ShutdownReason.SHARD_END); ConsumerState state = ShardConsumerState.SHUTTING_DOWN.consumerState(); - ConsumerTask task = state.createTask(argument, consumer, null); + List childShards = new ArrayList<>(); + List parentShards = new ArrayList<>(); + parentShards.add("shardId-000000000000"); + ChildShard leftChild = ChildShard.builder() + .shardId("shardId-000000000001") + .parentShards(parentShards) + .hashKeyRange(ShardObjectHelper.newHashKeyRange("0", "49")) + .build(); + ChildShard rightChild = ChildShard.builder() + .shardId("shardId-000000000002") + .parentShards(parentShards) + 
.hashKeyRange(ShardObjectHelper.newHashKeyRange("50", "99")) + .build(); + childShards.add(leftChild); + childShards.add(rightChild); + when(processRecordsInput.childShards()).thenReturn(childShards); + ConsumerTask task = state.createTask(argument, consumer, processRecordsInput); assertThat(task, shutdownTask(ShardInfo.class, "shardInfo", equalTo(shardInfo))); assertThat(task, @@ -315,8 +336,6 @@ public class ConsumerStatesTest { equalTo(recordProcessorCheckpointer))); assertThat(task, shutdownTask(ShutdownReason.class, "reason", equalTo(reason))); assertThat(task, shutdownTask(LeaseCoordinator.class, "leaseCoordinator", equalTo(leaseCoordinator))); - assertThat(task, shutdownTask(InitialPositionInStreamExtended.class, "initialPositionInStream", - equalTo(initialPositionInStream))); assertThat(task, shutdownTask(Boolean.class, "cleanupLeasesOfCompletedShards", equalTo(cleanupLeasesOfCompletedShards))); assertThat(task, shutdownTask(Long.class, "backoffTimeMillis", equalTo(taskBackoffTimeMillis))); diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ShutdownTaskTest.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ShutdownTaskTest.java index 220fe4a5..8a9024f6 100644 --- a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ShutdownTaskTest.java +++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ShutdownTaskTest.java @@ -26,6 +26,7 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -33,9 +34,11 @@ import java.util.List; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; +import org.mockito.Matchers; import org.mockito.Mock; import org.mockito.runners.MockitoJUnitRunner; +import software.amazon.awssdk.services.kinesis.model.ChildShard; import 
software.amazon.awssdk.services.kinesis.model.SequenceNumberRange; import software.amazon.awssdk.services.kinesis.model.Shard; import software.amazon.kinesis.checkpoint.ShardRecordProcessorCheckpointer; @@ -43,11 +46,15 @@ import software.amazon.kinesis.common.InitialPositionInStream; import software.amazon.kinesis.common.InitialPositionInStreamExtended; import software.amazon.kinesis.exceptions.internal.KinesisClientLibIOException; import software.amazon.kinesis.leases.HierarchicalShardSyncer; +import software.amazon.kinesis.leases.Lease; import software.amazon.kinesis.leases.LeaseCoordinator; import software.amazon.kinesis.leases.LeaseRefresher; import software.amazon.kinesis.leases.ShardDetector; import software.amazon.kinesis.leases.ShardInfo; import software.amazon.kinesis.leases.ShardObjectHelper; +import software.amazon.kinesis.leases.exceptions.DependencyException; +import software.amazon.kinesis.leases.exceptions.InvalidStateException; +import software.amazon.kinesis.leases.exceptions.ProvisionedThroughputException; import software.amazon.kinesis.lifecycle.events.LeaseLostInput; import software.amazon.kinesis.lifecycle.events.ShardEndedInput; import software.amazon.kinesis.metrics.MetricsFactory; @@ -104,7 +111,7 @@ public class ShutdownTaskTest { task = new ShutdownTask(shardInfo, shardDetector, shardRecordProcessor, recordProcessorCheckpointer, SHARD_END_SHUTDOWN_REASON, INITIAL_POSITION_TRIM_HORIZON, cleanupLeasesOfCompletedShards, ignoreUnexpectedChildShards, leaseCoordinator, TASK_BACKOFF_TIME_MILLIS, recordsPublisher, - hierarchicalShardSyncer, NULL_METRICS_FACTORY); + hierarchicalShardSyncer, NULL_METRICS_FACTORY, constructChildShards()); } /** @@ -113,8 +120,8 @@ public class ShutdownTaskTest { */ @Test public final void testCallWhenApplicationDoesNotCheckpoint() { - when(shardDetector.listShards()).thenReturn(constructShardListGraphA()); when(recordProcessorCheckpointer.lastCheckpointValue()).thenReturn(new ExtendedSequenceNumber("3298")); + 
when(leaseCoordinator.leaseRefresher()).thenReturn(leaseRefresher); final TaskResult result = task.call(); assertNotNull(result.getException()); @@ -126,28 +133,18 @@ public class ShutdownTaskTest { * This test is for the scenario that checkAndCreateLeaseForNewShards throws an exception. */ @Test - public final void testCallWhenSyncingShardsThrows() throws Exception { - final boolean garbageCollectLeases = false; - final boolean isLeaseTableEmpty = false; - - List latestShards = constructShardListGraphA(); - when(shardDetector.listShards()).thenReturn(latestShards); + public final void testCallWhenCreatingNewLeasesThrows() throws Exception { when(recordProcessorCheckpointer.lastCheckpointValue()).thenReturn(ExtendedSequenceNumber.SHARD_END); when(leaseCoordinator.leaseRefresher()).thenReturn(leaseRefresher); - - doAnswer((invocation) -> { - throw new KinesisClientLibIOException("KinesisClientLibIOException"); - }).when(hierarchicalShardSyncer) - .checkAndCreateLeaseForNewShards(shardDetector, leaseRefresher, INITIAL_POSITION_TRIM_HORIZON, - latestShards, cleanupLeasesOfCompletedShards, ignoreUnexpectedChildShards, - NULL_METRICS_FACTORY.createMetrics(), garbageCollectLeases, isLeaseTableEmpty); + when(leaseRefresher.createLeaseIfNotExists(Matchers.any(Lease.class))).thenThrow(new KinesisClientLibIOException("KinesisClientLibIOException")); final TaskResult result = task.call(); assertNotNull(result.getException()); assertTrue(result.getException() instanceof KinesisClientLibIOException); - verify(recordsPublisher).shutdown(); - verify(shardRecordProcessor).shardEnded(ShardEndedInput.builder().checkpointer(recordProcessorCheckpointer).build()); + verify(recordsPublisher, never()).shutdown(); + verify(shardRecordProcessor, never()).shardEnded(ShardEndedInput.builder().checkpointer(recordProcessorCheckpointer).build()); verify(shardRecordProcessor, never()).leaseLost(LeaseLostInput.builder().build()); + verify(leaseCoordinator, 
never()).dropLease(Matchers.any(Lease.class)); } /** @@ -155,24 +152,24 @@ public class ShutdownTaskTest { * This test is for the scenario that ShutdownTask is created for ShardConsumer reaching the Shard End. */ @Test - public final void testCallWhenTrueShardEnd() { + public final void testCallWhenTrueShardEnd() throws DependencyException, InvalidStateException, ProvisionedThroughputException { shardInfo = new ShardInfo("shardId-0", concurrencyToken, Collections.emptySet(), ExtendedSequenceNumber.LATEST); task = new ShutdownTask(shardInfo, shardDetector, shardRecordProcessor, recordProcessorCheckpointer, SHARD_END_SHUTDOWN_REASON, INITIAL_POSITION_TRIM_HORIZON, cleanupLeasesOfCompletedShards, ignoreUnexpectedChildShards, leaseCoordinator, TASK_BACKOFF_TIME_MILLIS, recordsPublisher, - hierarchicalShardSyncer, NULL_METRICS_FACTORY); + hierarchicalShardSyncer, NULL_METRICS_FACTORY, constructChildShards()); - when(shardDetector.listShards()).thenReturn(constructShardListGraphA()); when(recordProcessorCheckpointer.lastCheckpointValue()).thenReturn(ExtendedSequenceNumber.SHARD_END); + when(leaseCoordinator.leaseRefresher()).thenReturn(leaseRefresher); final TaskResult result = task.call(); assertNull(result.getException()); verify(recordsPublisher).shutdown(); verify(shardRecordProcessor).shardEnded(ShardEndedInput.builder().checkpointer(recordProcessorCheckpointer).build()); verify(shardRecordProcessor, never()).leaseLost(LeaseLostInput.builder().build()); - verify(shardDetector, times(1)).listShards(); - verify(leaseCoordinator, never()).getAssignments(); + verify(leaseRefresher, times(2)).createLeaseIfNotExists(Matchers.any(Lease.class)); + verify(leaseCoordinator, never()).dropLease(Matchers.any(Lease.class)); } /** @@ -180,23 +177,25 @@ public class ShutdownTaskTest { * This test is for the scenario that a ShutdownTask is created for detecting a false Shard End. 
*/ @Test - public final void testCallWhenFalseShardEnd() { + public final void testCallWhenShardNotFound() throws DependencyException, InvalidStateException, ProvisionedThroughputException { shardInfo = new ShardInfo("shardId-4", concurrencyToken, Collections.emptySet(), ExtendedSequenceNumber.LATEST); task = new ShutdownTask(shardInfo, shardDetector, shardRecordProcessor, recordProcessorCheckpointer, SHARD_END_SHUTDOWN_REASON, INITIAL_POSITION_TRIM_HORIZON, cleanupLeasesOfCompletedShards, ignoreUnexpectedChildShards, leaseCoordinator, TASK_BACKOFF_TIME_MILLIS, recordsPublisher, - hierarchicalShardSyncer, NULL_METRICS_FACTORY); + hierarchicalShardSyncer, NULL_METRICS_FACTORY, new ArrayList<>()); - when(shardDetector.listShards()).thenReturn(constructShardListGraphA()); + when(recordProcessorCheckpointer.lastCheckpointValue()).thenReturn(ExtendedSequenceNumber.SHARD_END); + when(leaseCoordinator.leaseRefresher()).thenReturn(leaseRefresher); final TaskResult result = task.call(); assertNull(result.getException()); verify(recordsPublisher).shutdown(); - verify(shardRecordProcessor, never()).shardEnded(ShardEndedInput.builder().checkpointer(recordProcessorCheckpointer).build()); - verify(shardRecordProcessor).leaseLost(LeaseLostInput.builder().build()); - verify(shardDetector, times(1)).listShards(); - verify(leaseCoordinator).getCurrentlyHeldLease(shardInfo.shardId()); + verify(shardRecordProcessor).shardEnded(ShardEndedInput.builder().checkpointer(recordProcessorCheckpointer).build()); + verify(shardRecordProcessor, never()).leaseLost(LeaseLostInput.builder().build()); + verify(leaseCoordinator, never()).getCurrentlyHeldLease(shardInfo.shardId()); + verify(leaseRefresher, never()).createLeaseIfNotExists(Matchers.any(Lease.class)); + verify(leaseCoordinator, never()).dropLease(Matchers.any(Lease.class)); } /** @@ -204,23 +203,22 @@ public class ShutdownTaskTest { * This test is for the scenario that a ShutdownTask is created for the ShardConsumer losing the lease. 
*/ @Test - public final void testCallWhenLeaseLost() { + public final void testCallWhenLeaseLost() throws DependencyException, InvalidStateException, ProvisionedThroughputException { shardInfo = new ShardInfo("shardId-4", concurrencyToken, Collections.emptySet(), ExtendedSequenceNumber.LATEST); task = new ShutdownTask(shardInfo, shardDetector, shardRecordProcessor, recordProcessorCheckpointer, LEASE_LOST_SHUTDOWN_REASON, INITIAL_POSITION_TRIM_HORIZON, cleanupLeasesOfCompletedShards, ignoreUnexpectedChildShards, leaseCoordinator, TASK_BACKOFF_TIME_MILLIS, recordsPublisher, - hierarchicalShardSyncer, NULL_METRICS_FACTORY); - - when(shardDetector.listShards()).thenReturn(constructShardListGraphA()); + hierarchicalShardSyncer, NULL_METRICS_FACTORY, new ArrayList<>()); final TaskResult result = task.call(); assertNull(result.getException()); verify(recordsPublisher).shutdown(); verify(shardRecordProcessor, never()).shardEnded(ShardEndedInput.builder().checkpointer(recordProcessorCheckpointer).build()); verify(shardRecordProcessor).leaseLost(LeaseLostInput.builder().build()); - verify(shardDetector, never()).listShards(); verify(leaseCoordinator, never()).getAssignments(); + verify(leaseRefresher, never()).createLeaseIfNotExists(Matchers.any(Lease.class)); + verify(leaseCoordinator, never()).dropLease(Matchers.any(Lease.class)); } /** @@ -231,45 +229,23 @@ public class ShutdownTaskTest { assertEquals(TaskType.SHUTDOWN, task.taskType()); } - - /* - * Helper method to construct a shard list for graph A. Graph A is defined below. 
Shard structure (y-axis is - * epochs): 0 1 2 3 4 5 - shards till - * \ / \ / | | - * 6 7 4 5 - shards from epoch 103 - 205 - * \ / | /\ - * 8 4 9 10 - shards from epoch 206 (open - no ending sequenceNumber) - */ - private List constructShardListGraphA() { - final SequenceNumberRange range0 = ShardObjectHelper.newSequenceNumberRange("11", "102"); - final SequenceNumberRange range1 = ShardObjectHelper.newSequenceNumberRange("11", null); - final SequenceNumberRange range2 = ShardObjectHelper.newSequenceNumberRange("11", "205"); - final SequenceNumberRange range3 = ShardObjectHelper.newSequenceNumberRange("103", "205"); - final SequenceNumberRange range4 = ShardObjectHelper.newSequenceNumberRange("206", null); - - return Arrays.asList( - ShardObjectHelper.newShard("shardId-0", null, null, range0, - ShardObjectHelper.newHashKeyRange("0", "99")), - ShardObjectHelper.newShard("shardId-1", null, null, range0, - ShardObjectHelper.newHashKeyRange("100", "199")), - ShardObjectHelper.newShard("shardId-2", null, null, range0, - ShardObjectHelper.newHashKeyRange("200", "299")), - ShardObjectHelper.newShard("shardId-3", null, null, range0, - ShardObjectHelper.newHashKeyRange("300", "399")), - ShardObjectHelper.newShard("shardId-4", null, null, range1, - ShardObjectHelper.newHashKeyRange("400", "499")), - ShardObjectHelper.newShard("shardId-5", null, null, range2, - ShardObjectHelper.newHashKeyRange("500", ShardObjectHelper.MAX_HASH_KEY)), - ShardObjectHelper.newShard("shardId-6", "shardId-0", "shardId-1", range3, - ShardObjectHelper.newHashKeyRange("0", "199")), - ShardObjectHelper.newShard("shardId-7", "shardId-2", "shardId-3", range3, - ShardObjectHelper.newHashKeyRange("200", "399")), - ShardObjectHelper.newShard("shardId-8", "shardId-6", "shardId-7", range4, - ShardObjectHelper.newHashKeyRange("0", "399")), - ShardObjectHelper.newShard("shardId-9", "shardId-5", null, range4, - ShardObjectHelper.newHashKeyRange("500", "799")), - ShardObjectHelper.newShard("shardId-10", null, 
"shardId-5", range4, - ShardObjectHelper.newHashKeyRange("800", ShardObjectHelper.MAX_HASH_KEY))); + private List constructChildShards() { + List childShards = new ArrayList<>(); + List parentShards = new ArrayList<>(); + parentShards.add(shardId); + ChildShard leftChild = ChildShard.builder() + .shardId("ShardId-1") + .parentShards(parentShards) + .hashKeyRange(ShardObjectHelper.newHashKeyRange("0", "49")) + .build(); + ChildShard rightChild = ChildShard.builder() + .shardId("ShardId-2") + .parentShards(parentShards) + .hashKeyRange(ShardObjectHelper.newHashKeyRange("50", "99")) + .build(); + childShards.add(leftChild); + childShards.add(rightChild); + return childShards; } } diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisherTest.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisherTest.java index fe6489b9..43881122 100644 --- a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisherTest.java +++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisherTest.java @@ -26,6 +26,7 @@ import software.amazon.awssdk.core.SdkBytes; import software.amazon.awssdk.core.async.SdkPublisher; import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.services.kinesis.KinesisAsyncClient; +import software.amazon.awssdk.services.kinesis.model.ChildShard; import software.amazon.awssdk.services.kinesis.model.Record; import software.amazon.awssdk.services.kinesis.model.ResourceNotFoundException; import software.amazon.awssdk.services.kinesis.model.ShardIteratorType; @@ -35,6 +36,7 @@ import software.amazon.awssdk.services.kinesis.model.SubscribeToShardEventStream import software.amazon.awssdk.services.kinesis.model.SubscribeToShardRequest; import software.amazon.kinesis.common.InitialPositionInStream; import 
software.amazon.kinesis.common.InitialPositionInStreamExtended; +import software.amazon.kinesis.leases.ShardObjectHelper; import software.amazon.kinesis.lifecycle.ShardConsumerNotifyingSubscriber; import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput; import software.amazon.kinesis.retrieval.BatchUniqueIdentifier; @@ -47,6 +49,7 @@ import software.amazon.kinesis.utils.SubscribeToShardRequestMatcher; import java.nio.ByteBuffer; import java.time.Instant; import java.util.ArrayList; +import java.util.Collections; import java.util.LinkedList; import java.util.List; import java.util.Optional; @@ -89,6 +92,7 @@ public class FanOutRecordsPublisherTest { private static final String SHARD_ID = "Shard-001"; private static final String CONSUMER_ARN = "arn:consumer"; + private static final String CONTINUATION_SEQUENCE_NUMBER = "continuationSequenceNumber"; @Mock private KinesisAsyncClient kinesisClient; @@ -148,7 +152,12 @@ public class FanOutRecordsPublisherTest { List matchers = records.stream().map(KinesisClientRecordMatcher::new) .collect(Collectors.toList()); - batchEvent = SubscribeToShardEvent.builder().millisBehindLatest(100L).records(records).build(); + batchEvent = SubscribeToShardEvent.builder() + .millisBehindLatest(100L) + .records(records) + .continuationSequenceNumber("test") + .childShards(Collections.emptyList()) + .build(); captor.getValue().onNext(batchEvent); captor.getValue().onNext(batchEvent); @@ -166,6 +175,73 @@ public class FanOutRecordsPublisherTest { } + @Test + public void InvalidEventTest() throws Exception { + FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN); + + ArgumentCaptor captor = ArgumentCaptor + .forClass(FanOutRecordsPublisher.RecordSubscription.class); + ArgumentCaptor flowCaptor = ArgumentCaptor + .forClass(FanOutRecordsPublisher.RecordFlow.class); + + doNothing().when(publisher).subscribe(captor.capture()); + + source.start(ExtendedSequenceNumber.LATEST, + 
InitialPositionInStreamExtended.newInitialPosition(InitialPositionInStream.LATEST)); + + List receivedInput = new ArrayList<>(); + + source.subscribe(new ShardConsumerNotifyingSubscriber(new Subscriber() { + Subscription subscription; + + @Override public void onSubscribe(Subscription s) { + subscription = s; + subscription.request(1); + } + + @Override public void onNext(RecordsRetrieved input) { + receivedInput.add(input.processRecordsInput()); + subscription.request(1); + } + + @Override public void onError(Throwable t) { + log.error("Caught throwable in subscriber", t); + fail("Caught throwable in subscriber"); + } + + @Override public void onComplete() { + fail("OnComplete called when not expected"); + } + }, source)); + + verify(kinesisClient).subscribeToShard(any(SubscribeToShardRequest.class), flowCaptor.capture()); + flowCaptor.getValue().onEventStream(publisher); + captor.getValue().onSubscribe(subscription); + + List records = Stream.of(1, 2, 3).map(this::makeRecord).collect(Collectors.toList()); + List matchers = records.stream().map(KinesisClientRecordMatcher::new) + .collect(Collectors.toList()); + + batchEvent = SubscribeToShardEvent.builder().millisBehindLatest(100L).records(records).continuationSequenceNumber(CONTINUATION_SEQUENCE_NUMBER).build(); + SubscribeToShardEvent invalidEvent = SubscribeToShardEvent.builder().millisBehindLatest(100L).records(records).childShards(Collections.emptyList()).build(); + + captor.getValue().onNext(batchEvent); + captor.getValue().onNext(invalidEvent); + captor.getValue().onNext(batchEvent); + + // When the second request failed with invalid event, it should stop sending requests and cancel the flow. 
+ verify(subscription, times(2)).request(1); + assertThat(receivedInput.size(), equalTo(1)); + + receivedInput.stream().map(ProcessRecordsInput::records).forEach(clientRecordsList -> { + assertThat(clientRecordsList.size(), equalTo(matchers.size())); + for (int i = 0; i < clientRecordsList.size(); ++i) { + assertThat(clientRecordsList.get(i), matchers.get(i)); + } + }); + + } + @Test public void testIfAllEventsReceivedWhenNoTasksRejectedByExecutor() throws Exception { FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN); @@ -225,7 +301,9 @@ public class FanOutRecordsPublisherTest { SubscribeToShardEvent.builder() .millisBehindLatest(100L) .continuationSequenceNumber(contSeqNum) - .records(records).build()) + .records(records) + .childShards(Collections.emptyList()) + .build()) .forEach(batchEvent -> captor.getValue().onNext(batchEvent)); verify(subscription, times(4)).request(1); @@ -301,7 +379,9 @@ public class FanOutRecordsPublisherTest { SubscribeToShardEvent.builder() .millisBehindLatest(100L) .continuationSequenceNumber(contSeqNum) - .records(records).build()) + .records(records) + .childShards(Collections.emptyList()) + .build()) .forEach(batchEvent -> captor.getValue().onNext(batchEvent)); verify(subscription, times(2)).request(1); @@ -334,6 +414,7 @@ public class FanOutRecordsPublisherTest { .millisBehindLatest(100L) .continuationSequenceNumber(contSeqNum + "") .records(records) + .childShards(Collections.emptyList()) .build()); CountDownLatch servicePublisherTaskCompletionLatch = new CountDownLatch(2); @@ -436,6 +517,7 @@ public class FanOutRecordsPublisherTest { .millisBehindLatest(100L) .continuationSequenceNumber(contSeqNum + "") .records(records) + .childShards(Collections.emptyList()) .build()); CountDownLatch servicePublisherTaskCompletionLatch = new CountDownLatch(2); @@ -536,13 +618,30 @@ public class FanOutRecordsPublisherTest { .millisBehindLatest(100L) .continuationSequenceNumber(contSeqNum + "") 
.records(records) + .childShards(Collections.emptyList()) .build()); + List childShards = new ArrayList<>(); + List parentShards = new ArrayList<>(); + parentShards.add(SHARD_ID); + ChildShard leftChild = ChildShard.builder() + .shardId("Shard-002") + .parentShards(parentShards) + .hashKeyRange(ShardObjectHelper.newHashKeyRange("0", "49")) + .build(); + ChildShard rightChild = ChildShard.builder() + .shardId("Shard-003") + .parentShards(parentShards) + .hashKeyRange(ShardObjectHelper.newHashKeyRange("50", "99")) + .build(); + childShards.add(leftChild); + childShards.add(rightChild); Consumer servicePublisherShardEndAction = contSeqNum -> captor.getValue().onNext( SubscribeToShardEvent.builder() .millisBehindLatest(100L) .continuationSequenceNumber(null) .records(records) + .childShards(childShards) .build()); CountDownLatch servicePublisherTaskCompletionLatch = new CountDownLatch(2); @@ -648,6 +747,7 @@ public class FanOutRecordsPublisherTest { .millisBehindLatest(100L) .continuationSequenceNumber(contSeqNum + "") .records(records) + .childShards(Collections.emptyList()) .build()); CountDownLatch servicePublisherTaskCompletionLatch = new CountDownLatch(2); @@ -750,6 +850,7 @@ public class FanOutRecordsPublisherTest { .millisBehindLatest(100L) .continuationSequenceNumber(contSeqNum + "") .records(records) + .childShards(Collections.emptyList()) .build()); CountDownLatch servicePublisherTaskCompletionLatch = new CountDownLatch(2); @@ -842,6 +943,7 @@ public class FanOutRecordsPublisherTest { .millisBehindLatest(100L) .continuationSequenceNumber(contSeqNum + "") .records(records) + .childShards(Collections.emptyList()) .build()); CountDownLatch servicePublisherTaskCompletionLatch = new CountDownLatch(1); @@ -1004,7 +1106,12 @@ public class FanOutRecordsPublisherTest { List matchers = records.stream().map(KinesisClientRecordMatcher::new) .collect(Collectors.toList()); - batchEvent = SubscribeToShardEvent.builder().millisBehindLatest(100L).records(records).build(); + 
batchEvent = SubscribeToShardEvent.builder() + .millisBehindLatest(100L) + .records(records) + .continuationSequenceNumber(CONTINUATION_SEQUENCE_NUMBER) + .childShards(Collections.emptyList()) + .build(); captor.getValue().onNext(batchEvent); captor.getValue().onNext(batchEvent); @@ -1098,7 +1205,7 @@ public class FanOutRecordsPublisherTest { .collect(Collectors.toList()); batchEvent = SubscribeToShardEvent.builder().millisBehindLatest(100L).records(records) - .continuationSequenceNumber("3").build(); + .continuationSequenceNumber("3").childShards(Collections.emptyList()).build(); captor.getValue().onNext(batchEvent); captor.getValue().onComplete(); @@ -1126,7 +1233,7 @@ public class FanOutRecordsPublisherTest { .collect(Collectors.toList()); batchEvent = SubscribeToShardEvent.builder().millisBehindLatest(100L).records(nextRecords) - .continuationSequenceNumber("6").build(); + .continuationSequenceNumber("6").childShards(Collections.emptyList()).build(); nextSubscribeCaptor.getValue().onNext(batchEvent); verify(subscription, times(4)).request(1); diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/polling/KinesisDataFetcherTest.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/polling/KinesisDataFetcherTest.java index a88f3c3b..74b0c125 100644 --- a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/polling/KinesisDataFetcherTest.java +++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/polling/KinesisDataFetcherTest.java @@ -30,6 +30,7 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.Date; @@ -53,6 +54,7 @@ import org.mockito.runners.MockitoJUnitRunner; import software.amazon.awssdk.core.exception.SdkException; import software.amazon.awssdk.services.kinesis.KinesisAsyncClient; +import 
software.amazon.awssdk.services.kinesis.model.ChildShard; import software.amazon.awssdk.services.kinesis.model.GetRecordsRequest; import software.amazon.awssdk.services.kinesis.model.GetRecordsResponse; import software.amazon.awssdk.services.kinesis.model.GetShardIteratorRequest; @@ -65,6 +67,7 @@ import software.amazon.kinesis.checkpoint.SentinelCheckpoint; import software.amazon.kinesis.common.InitialPositionInStream; import software.amazon.kinesis.common.InitialPositionInStreamExtended; import software.amazon.kinesis.exceptions.KinesisClientLibException; +import software.amazon.kinesis.leases.ShardObjectHelper; import software.amazon.kinesis.metrics.MetricsFactory; import software.amazon.kinesis.metrics.NullMetricsFactory; import software.amazon.kinesis.processor.Checkpointer; @@ -330,8 +333,31 @@ public class KinesisDataFetcherTest { private CompletableFuture makeGetRecordsResponse(String nextIterator, List records) throws InterruptedException, ExecutionException { + List childShards = new ArrayList<>(); + if(nextIterator == null) { + childShards = createChildShards(); + } return CompletableFuture.completedFuture(GetRecordsResponse.builder().nextShardIterator(nextIterator) - .records(CollectionUtils.isNullOrEmpty(records) ? Collections.emptyList() : records).build()); + .records(CollectionUtils.isNullOrEmpty(records) ? 
Collections.emptyList() : records).childShards(childShards).build()); + } + + private List createChildShards() { + List childShards = new ArrayList<>(); + List parentShards = new ArrayList<>(); + parentShards.add(SHARD_ID); + ChildShard leftChild = ChildShard.builder() + .shardId("Shard-2") + .parentShards(parentShards) + .hashKeyRange(ShardObjectHelper.newHashKeyRange("0", "49")) + .build(); + ChildShard rightChild = ChildShard.builder() + .shardId("Shard-3") + .parentShards(parentShards) + .hashKeyRange(ShardObjectHelper.newHashKeyRange("50", "99")) + .build(); + childShards.add(leftChild); + childShards.add(rightChild); + return childShards; } @Test @@ -342,6 +368,7 @@ public class KinesisDataFetcherTest { final String initialIterator = "InitialIterator"; final String nextIterator1 = "NextIteratorOne"; final String nextIterator2 = "NextIteratorTwo"; + final String nextIterator3 = "NextIteratorThree"; final CompletableFuture nonAdvancingResult1 = makeGetRecordsResponse(initialIterator, null); final CompletableFuture nonAdvancingResult2 = makeGetRecordsResponse(nextIterator1, null); final CompletableFuture finalNonAdvancingResult = makeGetRecordsResponse(nextIterator2, diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/polling/PrefetchRecordsPublisherIntegrationTest.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/polling/PrefetchRecordsPublisherIntegrationTest.java index f940faf2..461fce71 100644 --- a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/polling/PrefetchRecordsPublisherIntegrationTest.java +++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/polling/PrefetchRecordsPublisherIntegrationTest.java @@ -22,6 +22,7 @@ import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyLong; +import static org.mockito.Matchers.eq; import static 
org.mockito.Mockito.doNothing; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; @@ -48,6 +49,7 @@ import org.mockito.runners.MockitoJUnitRunner; import org.mockito.stubbing.Answer; import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.services.kinesis.model.GetRecordsRequest; import software.amazon.kinesis.common.InitialPositionInStreamExtended; import lombok.extern.slf4j.Slf4j; @@ -86,6 +88,7 @@ public class PrefetchRecordsPublisherIntegrationTest { private String operation = "ProcessTask"; private String streamName = "streamName"; private String shardId = "shardId-000000000000"; + private String nextShardIterator = "testNextShardIterator"; @Mock private KinesisAsyncClient kinesisClient; @@ -249,7 +252,7 @@ public class PrefetchRecordsPublisherIntegrationTest { @Override public DataFetcherResult getRecords() { - GetRecordsResponse getRecordsResult = GetRecordsResponse.builder().records(new ArrayList<>(records)).millisBehindLatest(1000L).build(); + GetRecordsResponse getRecordsResult = GetRecordsResponse.builder().records(new ArrayList<>(records)).nextShardIterator(nextShardIterator).millisBehindLatest(1000L).build(); return new AdvancingResult(getRecordsResult); } diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/polling/PrefetchRecordsPublisherTest.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/polling/PrefetchRecordsPublisherTest.java index a28ded63..f7051ec4 100644 --- a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/polling/PrefetchRecordsPublisherTest.java +++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/polling/PrefetchRecordsPublisherTest.java @@ -21,6 +21,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; import 
static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.Matchers.any; @@ -71,11 +72,13 @@ import io.reactivex.Flowable; import io.reactivex.schedulers.Schedulers; import lombok.extern.slf4j.Slf4j; import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.services.kinesis.model.ChildShard; import software.amazon.awssdk.services.kinesis.model.ExpiredIteratorException; import software.amazon.awssdk.services.kinesis.model.GetRecordsResponse; import software.amazon.awssdk.services.kinesis.model.Record; import software.amazon.kinesis.common.InitialPositionInStreamExtended; import software.amazon.kinesis.common.RequestDetails; +import software.amazon.kinesis.leases.ShardObjectHelper; import software.amazon.kinesis.common.StreamIdentifier; import software.amazon.kinesis.lifecycle.ShardConsumerNotifyingSubscriber; import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput; @@ -99,6 +102,7 @@ public class PrefetchRecordsPublisherTest { private static final int MAX_SIZE = 5; private static final int MAX_RECORDS_COUNT = 15000; private static final long IDLE_MILLIS_BETWEEN_CALLS = 0L; + private static final String NEXT_SHARD_ITERATOR = "testNextShardIterator"; @Mock private GetRecordsRetrievalStrategy getRecordsRetrievalStrategy; @@ -136,7 +140,7 @@ public class PrefetchRecordsPublisherTest { "shardId"); spyQueue = spy(getRecordsCache.getPublisherSession().prefetchRecordsQueue()); records = spy(new ArrayList<>()); - getRecordsResponse = GetRecordsResponse.builder().records(records).build(); + getRecordsResponse = GetRecordsResponse.builder().records(records).nextShardIterator(NEXT_SHARD_ITERATOR).childShards(new ArrayList<>()).build(); when(getRecordsRetrievalStrategy.getRecords(eq(MAX_RECORDS_PER_CALL))).thenReturn(getRecordsResponse); } @@ -155,11 +159,67 @@ public class PrefetchRecordsPublisherTest { .processRecordsInput(); assertEquals(expectedRecords, result.records()); + assertEquals(new ArrayList<>(), 
result.childShards()); verify(executorService).execute(any()); verify(getRecordsRetrievalStrategy, atLeast(1)).getRecords(eq(MAX_RECORDS_PER_CALL)); } + @Test + public void testGetRecordsWithInvalidResponse() { + record = Record.builder().data(createByteBufferWithSize(SIZE_512_KB)).build(); + + when(records.size()).thenReturn(1000); + + GetRecordsResponse response = GetRecordsResponse.builder().records(records).build(); + when(getRecordsRetrievalStrategy.getRecords(eq(MAX_RECORDS_PER_CALL))).thenReturn(response); + when(dataFetcher.isShardEndReached()).thenReturn(false); + + getRecordsCache.start(sequenceNumber, initialPosition); + + try { + ProcessRecordsInput result = blockUntilRecordsAvailable(() -> evictPublishedEvent(getRecordsCache, "shardId"), 1000L) + .processRecordsInput(); + } catch (Exception e) { + assertEquals("No records found", e.getMessage()); + } + } + + @Test + public void testGetRecordsWithShardEnd() { + records = new ArrayList<>(); + + final List expectedRecords = new ArrayList<>(); + + List childShards = new ArrayList<>(); + List parentShards = new ArrayList<>(); + parentShards.add("shardId"); + ChildShard leftChild = ChildShard.builder() + .shardId("shardId-000000000001") + .parentShards(parentShards) + .hashKeyRange(ShardObjectHelper.newHashKeyRange("0", "49")) + .build(); + ChildShard rightChild = ChildShard.builder() + .shardId("shardId-000000000002") + .parentShards(parentShards) + .hashKeyRange(ShardObjectHelper.newHashKeyRange("50", "99")) + .build(); + childShards.add(leftChild); + childShards.add(rightChild); + + GetRecordsResponse response = GetRecordsResponse.builder().records(records).childShards(childShards).build(); + when(getRecordsRetrievalStrategy.getRecords(eq(MAX_RECORDS_PER_CALL))).thenReturn(response); + when(dataFetcher.isShardEndReached()).thenReturn(true); + + getRecordsCache.start(sequenceNumber, initialPosition); + ProcessRecordsInput result = blockUntilRecordsAvailable(() -> evictPublishedEvent(getRecordsCache, 
"shardId"), 1000L) + .processRecordsInput(); + + assertEquals(expectedRecords, result.records()); + assertEquals(childShards, result.childShards()); + assertTrue(result.isAtShardEnd()); + } + // TODO: Broken test @Test @Ignore @@ -270,7 +330,7 @@ public class PrefetchRecordsPublisherTest { @Test public void testRetryableRetrievalExceptionContinues() { - GetRecordsResponse response = GetRecordsResponse.builder().millisBehindLatest(100L).records(Collections.emptyList()).build(); + GetRecordsResponse response = GetRecordsResponse.builder().millisBehindLatest(100L).records(Collections.emptyList()).nextShardIterator(NEXT_SHARD_ITERATOR).build(); when(getRecordsRetrievalStrategy.getRecords(anyInt())).thenThrow(new RetryableRetrievalException("Timeout", new TimeoutException("Timeout"))).thenReturn(response); getRecordsCache.start(sequenceNumber, initialPosition); @@ -293,7 +353,7 @@ public class PrefetchRecordsPublisherTest { when(getRecordsRetrievalStrategy.getRecords(anyInt())).thenAnswer( i -> GetRecordsResponse.builder().records( Record.builder().data(SdkBytes.fromByteArray(new byte[] { 1, 2, 3 })).sequenceNumber(++sequenceNumberInResponse[0] + "").build()) - .build()); + .nextShardIterator(NEXT_SHARD_ITERATOR).build()); getRecordsCache.start(sequenceNumber, initialPosition); @@ -384,7 +444,7 @@ public class PrefetchRecordsPublisherTest { // to the subscriber. 
GetRecordsResponse response = GetRecordsResponse.builder().records( Record.builder().data(SdkBytes.fromByteArray(new byte[] { 1, 2, 3 })).sequenceNumber("123").build()) - .build(); + .nextShardIterator(NEXT_SHARD_ITERATOR).build(); when(getRecordsRetrievalStrategy.getRecords(anyInt())).thenReturn(response); getRecordsCache.start(sequenceNumber, initialPosition); diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/utils/ProcessRecordsInputMatcher.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/utils/ProcessRecordsInputMatcher.java index a89ebef6..1aeddc60 100644 --- a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/utils/ProcessRecordsInputMatcher.java +++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/utils/ProcessRecordsInputMatcher.java @@ -45,6 +45,7 @@ public class ProcessRecordsInputMatcher extends TypeSafeDiagnosingMatcher