From 7b026f8a193b37123784e555a2096890aa5c45b6 Mon Sep 17 00:00:00 2001 From: "Pfifer, Justin" Date: Mon, 26 Mar 2018 11:47:44 -0700 Subject: [PATCH] Broke apart the process task, now diving into the shard consumer more. --- .../kinesis/lifecycle/ConsumerStates.java | 5 +- .../lifecycle/ProcessRecordsInput.java | 4 + .../amazon/kinesis/lifecycle/ProcessTask.java | 213 ++++------ .../lifecycle/RecordProcessorLifecycle.java | 32 ++ .../lifecycle/RecordProcessorShim.java | 54 +++ .../kinesis/lifecycle/ShardConsumer.java | 167 +++++--- .../kinesis/lifecycle/events/LeaseLost.java | 18 + .../lifecycle/events/RecordsReceived.java | 22 ++ .../lifecycle/events/ShardCompleted.java | 18 + .../lifecycle/events/ShutdownRequested.java | 18 + .../kinesis/lifecycle/events/Started.java | 33 ++ .../kinesis/lifecycle/ProcessTaskTest.java | 366 +++++++++--------- 12 files changed, 580 insertions(+), 370 deletions(-) create mode 100644 amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/RecordProcessorLifecycle.java create mode 100644 amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/RecordProcessorShim.java create mode 100644 amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/events/LeaseLost.java create mode 100644 amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/events/RecordsReceived.java create mode 100644 amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/events/ShardCompleted.java create mode 100644 amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/events/ShutdownRequested.java create mode 100644 amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/events/Started.java diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ConsumerStates.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ConsumerStates.java index e192a505..ab941938 100644 --- 
a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ConsumerStates.java +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ConsumerStates.java @@ -312,14 +312,13 @@ class ConsumerStates { @Override public ITask createTask(ShardConsumer consumer) { + ProcessTask.RecordsFetcher recordsFetcher = new ProcessTask.RecordsFetcher(consumer.getGetRecordsCache()); return new ProcessTask(consumer.getShardInfo(), consumer.getStreamConfig(), consumer.getRecordProcessor(), consumer.getRecordProcessorCheckpointer(), - consumer.getDataFetcher(), consumer.getTaskBackoffTimeMillis(), - consumer.isSkipShardSyncAtWorkerInitializationIfLeasesExist(), - consumer.getGetRecordsCache()); + consumer.isSkipShardSyncAtWorkerInitializationIfLeasesExist(), recordsFetcher.getRecords()); } @Override diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ProcessRecordsInput.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ProcessRecordsInput.java index 5bb47cd1..96008359 100644 --- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ProcessRecordsInput.java +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ProcessRecordsInput.java @@ -18,6 +18,7 @@ import java.time.Duration; import java.time.Instant; import java.util.List; +import lombok.AllArgsConstructor; import software.amazon.kinesis.processor.IRecordProcessorCheckpointer; import com.amazonaws.services.kinesis.model.Record; @@ -29,11 +30,14 @@ import software.amazon.kinesis.processor.IRecordProcessor; * {@link IRecordProcessor#processRecords( * ProcessRecordsInput processRecordsInput) processRecords} method. 
*/ +@AllArgsConstructor public class ProcessRecordsInput { @Getter private Instant cacheEntryTime; @Getter private Instant cacheExitTime; + @Getter + private boolean isAtShardEnd; private List records; private IRecordProcessorCheckpointer checkpointer; private Long millisBehindLatest; diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ProcessTask.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ProcessTask.java index 5076dc6f..04f56fcc 100644 --- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ProcessTask.java +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ProcessTask.java @@ -1,16 +1,9 @@ /* - * Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Amazon Software License (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://aws.amazon.com/asl/ - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. + * Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. Licensed under the Amazon Software License + * (the "License"). You may not use this file except in compliance with the License. A copy of the License is located at + * http://aws.amazon.com/asl/ or in the "license" file accompanying this file. This file is distributed on an "AS IS" + * BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific + * language governing permissions and limitations under the License. 
*/ package software.amazon.kinesis.lifecycle; @@ -19,26 +12,24 @@ import java.util.List; import java.util.ListIterator; import com.amazonaws.services.cloudwatch.model.StandardUnit; +import com.amazonaws.services.kinesis.model.Record; +import com.amazonaws.services.kinesis.model.Shard; + +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; import software.amazon.kinesis.coordinator.RecordProcessorCheckpointer; -import software.amazon.kinesis.leases.ShardInfo; import software.amazon.kinesis.coordinator.StreamConfig; -import software.amazon.kinesis.retrieval.ThrottlingReporter; +import software.amazon.kinesis.leases.ShardInfo; +import software.amazon.kinesis.metrics.IMetricsScope; +import software.amazon.kinesis.metrics.MetricsHelper; +import software.amazon.kinesis.metrics.MetricsLevel; import software.amazon.kinesis.processor.IRecordProcessor; import software.amazon.kinesis.retrieval.GetRecordsCache; import software.amazon.kinesis.retrieval.IKinesisProxy; import software.amazon.kinesis.retrieval.IKinesisProxyExtended; -import software.amazon.kinesis.retrieval.KinesisDataFetcher; +import software.amazon.kinesis.retrieval.ThrottlingReporter; import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber; import software.amazon.kinesis.retrieval.kpl.UserRecord; -import software.amazon.kinesis.metrics.MetricsHelper; -import software.amazon.kinesis.metrics.IMetricsScope; -import software.amazon.kinesis.metrics.MetricsLevel; -import com.amazonaws.services.kinesis.model.ExpiredIteratorException; -import com.amazonaws.services.kinesis.model.ProvisionedThroughputExceededException; -import com.amazonaws.services.kinesis.model.Record; -import com.amazonaws.services.kinesis.model.Shard; - -import lombok.extern.slf4j.Slf4j; /** * Task for fetching data records and invoking processRecords() on the record processor instance. 
@@ -55,39 +46,30 @@ public class ProcessTask implements ITask { private final ShardInfo shardInfo; private final IRecordProcessor recordProcessor; private final RecordProcessorCheckpointer recordProcessorCheckpointer; - private final KinesisDataFetcher dataFetcher; private final TaskType taskType = TaskType.PROCESS; private final StreamConfig streamConfig; private final long backoffTimeMillis; private final Shard shard; private final ThrottlingReporter throttlingReporter; - private final GetRecordsCache getRecordsCache; + private final ProcessRecordsInput processRecordsInput; + + @RequiredArgsConstructor + public static class RecordsFetcher { + + private final GetRecordsCache getRecordsCache; + + public ProcessRecordsInput getRecords() { + ProcessRecordsInput processRecordsInput = getRecordsCache.getNextResult(); + + if (processRecordsInput.getMillisBehindLatest() != null) { + MetricsHelper.getMetricsScope().addData(MILLIS_BEHIND_LATEST_METRIC, + processRecordsInput.getMillisBehindLatest(), StandardUnit.Milliseconds, MetricsLevel.SUMMARY); + } + + return processRecordsInput; + } - /** - * @param shardInfo - * contains information about the shard - * @param streamConfig - * Stream configuration - * @param recordProcessor - * Record processor used to process the data records for the shard - * @param recordProcessorCheckpointer - * Passed to the RecordProcessor so it can checkpoint progress - * @param dataFetcher - * Kinesis data fetcher (used to fetch records from Kinesis) - * @param backoffTimeMillis - * backoff time when catching exceptions - * @param getRecordsCache - * The retrieval strategy for fetching records from kinesis - */ - public ProcessTask(ShardInfo shardInfo, StreamConfig streamConfig, IRecordProcessor recordProcessor, - RecordProcessorCheckpointer recordProcessorCheckpointer, KinesisDataFetcher dataFetcher, - long backoffTimeMillis, boolean skipShardSyncAtWorkerInitializationIfLeasesExist, - GetRecordsCache getRecordsCache) { - this(shardInfo, 
streamConfig, recordProcessor, recordProcessorCheckpointer, dataFetcher, backoffTimeMillis, - skipShardSyncAtWorkerInitializationIfLeasesExist, - new ThrottlingReporter(MAX_CONSECUTIVE_THROTTLES, shardInfo.getShardId()), - getRecordsCache); } /** @@ -99,27 +81,44 @@ public class ProcessTask implements ITask { * Record processor used to process the data records for the shard * @param recordProcessorCheckpointer * Passed to the RecordProcessor so it can checkpoint progress - * @param dataFetcher - * Kinesis data fetcher (used to fetch records from Kinesis) + * @param backoffTimeMillis + * backoff time when catching exceptions + */ + public ProcessTask(ShardInfo shardInfo, StreamConfig streamConfig, IRecordProcessor recordProcessor, + RecordProcessorCheckpointer recordProcessorCheckpointer, long backoffTimeMillis, + boolean skipShardSyncAtWorkerInitializationIfLeasesExist, ProcessRecordsInput processRecordsInput) { + this(shardInfo, streamConfig, recordProcessor, recordProcessorCheckpointer, backoffTimeMillis, + skipShardSyncAtWorkerInitializationIfLeasesExist, + new ThrottlingReporter(MAX_CONSECUTIVE_THROTTLES, shardInfo.getShardId()), processRecordsInput); + } + + /** + * @param shardInfo + * contains information about the shard + * @param streamConfig + * Stream configuration + * @param recordProcessor + * Record processor used to process the data records for the shard + * @param recordProcessorCheckpointer + * Passed to the RecordProcessor so it can checkpoint progress * @param backoffTimeMillis * backoff time when catching exceptions * @param throttlingReporter * determines how throttling events should be reported in the log. 
*/ public ProcessTask(ShardInfo shardInfo, StreamConfig streamConfig, IRecordProcessor recordProcessor, - RecordProcessorCheckpointer recordProcessorCheckpointer, KinesisDataFetcher dataFetcher, - long backoffTimeMillis, boolean skipShardSyncAtWorkerInitializationIfLeasesExist, - ThrottlingReporter throttlingReporter, GetRecordsCache getRecordsCache) { + RecordProcessorCheckpointer recordProcessorCheckpointer, long backoffTimeMillis, + boolean skipShardSyncAtWorkerInitializationIfLeasesExist, ThrottlingReporter throttlingReporter, + ProcessRecordsInput processRecordsInput) { super(); this.shardInfo = shardInfo; this.recordProcessor = recordProcessor; this.recordProcessorCheckpointer = recordProcessorCheckpointer; - this.dataFetcher = dataFetcher; this.streamConfig = streamConfig; this.backoffTimeMillis = backoffTimeMillis; this.throttlingReporter = throttlingReporter; IKinesisProxy kinesisProxy = this.streamConfig.getStreamProxy(); - this.getRecordsCache = getRecordsCache; + this.processRecordsInput = processRecordsInput; // If skipShardSyncAtWorkerInitializationIfLeasesExist is set, we will not get the shard for // this ProcessTask. In this case, duplicate KPL user records in the event of resharding will // not be dropped during deaggregation of Amazon Kinesis records. 
This is only applicable if @@ -138,7 +137,6 @@ public class ProcessTask implements ITask { /* * (non-Javadoc) - * * @see com.amazonaws.services.kinesis.clientlibrary.lib.worker.ITask#call() */ @Override @@ -151,12 +149,11 @@ public class ProcessTask implements ITask { Exception exception = null; try { - if (dataFetcher.isShardEndReached()) { + if (processRecordsInput.isAtShardEnd()) { log.info("Reached end of shard {}", shardInfo.getShardId()); return new TaskResult(null, true); } - final ProcessRecordsInput processRecordsInput = getRecordsResult(); throttlingReporter.success(); List records = processRecordsInput.getRecords(); @@ -167,19 +164,13 @@ public class ProcessTask implements ITask { } records = deaggregateRecords(records); - recordProcessorCheckpointer.setLargestPermittedCheckpointValue( - filterAndGetMaxExtendedSequenceNumber(scope, records, - recordProcessorCheckpointer.getLastCheckpointValue(), - recordProcessorCheckpointer.getLargestPermittedCheckpointValue())); + recordProcessorCheckpointer.setLargestPermittedCheckpointValue(filterAndGetMaxExtendedSequenceNumber(scope, + records, recordProcessorCheckpointer.getLastCheckpointValue(), + recordProcessorCheckpointer.getLargestPermittedCheckpointValue())); if (shouldCallProcessRecords(records)) { callProcessRecords(processRecordsInput, records); } - } catch (ProvisionedThroughputExceededException pte) { - throttlingReporter.throttled(); - exception = pte; - backoff(); - } catch (RuntimeException e) { log.error("ShardId {}: Caught exception: ", shardInfo.getShardId(), e); exception = e; @@ -213,8 +204,7 @@ public class ProcessTask implements ITask { log.debug("Calling application processRecords() with {} records from {}", records.size(), shardInfo.getShardId()); final ProcessRecordsInput processRecordsInput = new ProcessRecordsInput().withRecords(records) - .withCheckpointer(recordProcessorCheckpointer) - .withMillisBehindLatest(input.getMillisBehindLatest()); + 
.withCheckpointer(recordProcessorCheckpointer).withMillisBehindLatest(input.getMillisBehindLatest()); final long recordProcessorStartTimeMillis = System.currentTimeMillis(); try { @@ -292,27 +282,28 @@ public class ProcessTask implements ITask { } /** - * Scans a list of records to filter out records up to and including the most recent checkpoint value and to get - * the greatest extended sequence number from the retained records. Also emits metrics about the records. + * Scans a list of records to filter out records up to and including the most recent checkpoint value and to get the + * greatest extended sequence number from the retained records. Also emits metrics about the records. * - * @param scope metrics scope to emit metrics into - * @param records list of records to scan and change in-place as needed - * @param lastCheckpointValue the most recent checkpoint value - * @param lastLargestPermittedCheckpointValue previous largest permitted checkpoint value + * @param scope + * metrics scope to emit metrics into + * @param records + * list of records to scan and change in-place as needed + * @param lastCheckpointValue + * the most recent checkpoint value + * @param lastLargestPermittedCheckpointValue + * previous largest permitted checkpoint value * @return the largest extended sequence number among the retained records */ private ExtendedSequenceNumber filterAndGetMaxExtendedSequenceNumber(IMetricsScope scope, List records, - final ExtendedSequenceNumber lastCheckpointValue, - final ExtendedSequenceNumber lastLargestPermittedCheckpointValue) { + final ExtendedSequenceNumber lastCheckpointValue, + final ExtendedSequenceNumber lastLargestPermittedCheckpointValue) { ExtendedSequenceNumber largestExtendedSequenceNumber = lastLargestPermittedCheckpointValue; ListIterator recordIterator = records.listIterator(); while (recordIterator.hasNext()) { Record record = recordIterator.next(); - ExtendedSequenceNumber extendedSequenceNumber = new ExtendedSequenceNumber( - 
record.getSequenceNumber(), - record instanceof UserRecord - ? ((UserRecord) record).getSubSequenceNumber() - : null); + ExtendedSequenceNumber extendedSequenceNumber = new ExtendedSequenceNumber(record.getSequenceNumber(), + record instanceof UserRecord ? ((UserRecord) record).getSubSequenceNumber() : null); if (extendedSequenceNumber.compareTo(lastCheckpointValue) <= 0) { recordIterator.remove(); @@ -332,58 +323,4 @@ public class ProcessTask implements ITask { return largestExtendedSequenceNumber; } - /** - * Gets records from Kinesis and retries once in the event of an ExpiredIteratorException. - * - * @return list of data records from Kinesis - */ - private ProcessRecordsInput getRecordsResult() { - try { - return getRecordsResultAndRecordMillisBehindLatest(); - } catch (ExpiredIteratorException e) { - // If we see a ExpiredIteratorException, try once to restart from the greatest remembered sequence number - log.info("ShardId {}" - + ": getRecords threw ExpiredIteratorException - restarting after greatest seqNum " - + "passed to customer", shardInfo.getShardId(), e); - MetricsHelper.getMetricsScope().addData(EXPIRED_ITERATOR_METRIC, 1, StandardUnit.Count, - MetricsLevel.SUMMARY); - - /* - * Advance the iterator to after the greatest processed sequence number (remembered by - * recordProcessorCheckpointer). - */ - dataFetcher.advanceIteratorTo(recordProcessorCheckpointer.getLargestPermittedCheckpointValue() - .getSequenceNumber(), streamConfig.getInitialPositionInStream()); - - // Try a second time - if we fail this time, expose the failure. - try { - return getRecordsResultAndRecordMillisBehindLatest(); - } catch (ExpiredIteratorException ex) { - String msg = - "Shard " + shardInfo.getShardId() - + ": getRecords threw ExpiredIteratorException with a fresh iterator."; - log.error(msg, ex); - throw ex; - } - } - } - - /** - * Gets records from Kinesis and records the MillisBehindLatest metric if present. 
- * - * @return list of data records from Kinesis - */ - private ProcessRecordsInput getRecordsResultAndRecordMillisBehindLatest() { - final ProcessRecordsInput processRecordsInput = getRecordsCache.getNextResult(); - - if (processRecordsInput.getMillisBehindLatest() != null) { - MetricsHelper.getMetricsScope().addData(MILLIS_BEHIND_LATEST_METRIC, - processRecordsInput.getMillisBehindLatest(), - StandardUnit.Milliseconds, - MetricsLevel.SUMMARY); - } - - return processRecordsInput; - } - } \ No newline at end of file diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/RecordProcessorLifecycle.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/RecordProcessorLifecycle.java new file mode 100644 index 00000000..db63f88b --- /dev/null +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/RecordProcessorLifecycle.java @@ -0,0 +1,32 @@ +/* + * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Amazon Software License (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/asl/ + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. 
+ */ +package software.amazon.kinesis.lifecycle; + +import software.amazon.kinesis.lifecycle.events.LeaseLost; +import software.amazon.kinesis.lifecycle.events.RecordsReceived; +import software.amazon.kinesis.lifecycle.events.ShardCompleted; +import software.amazon.kinesis.lifecycle.events.ShutdownRequested; +import software.amazon.kinesis.lifecycle.events.Started; + +public interface RecordProcessorLifecycle { + + void started(Started started); + void recordsReceived(RecordsReceived records); + void leaseLost(LeaseLost leaseLost); + void shardCompleted(ShardCompleted shardCompletedInput); + void shutdownRequested(ShutdownRequested shutdownRequested); + + +} diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/RecordProcessorShim.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/RecordProcessorShim.java new file mode 100644 index 00000000..7d906991 --- /dev/null +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/RecordProcessorShim.java @@ -0,0 +1,54 @@ +/* + * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Amazon Software License (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/asl/ + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. 
+ */ +package software.amazon.kinesis.lifecycle; + +import lombok.AllArgsConstructor; +import software.amazon.kinesis.lifecycle.events.LeaseLost; +import software.amazon.kinesis.lifecycle.events.ShardCompleted; +import software.amazon.kinesis.lifecycle.events.ShutdownRequested; +import software.amazon.kinesis.lifecycle.events.Started; +import software.amazon.kinesis.processor.IRecordProcessor; + +@AllArgsConstructor +public class RecordProcessorShim implements RecordProcessorLifecycle { + + private final IRecordProcessor delegate; + + @Override + public void started(Started started) { + InitializationInput initializationInput = started.toInitializationInput(); + delegate.initialize(initializationInput); + } + + @Override + public void recordsReceived(software.amazon.kinesis.lifecycle.events.RecordsReceived records) { + + } + + @Override + public void leaseLost(LeaseLost leaseLost) { + + } + + @Override + public void shardCompleted(ShardCompleted shardCompletedInput) { + + } + + @Override + public void shutdownRequested(ShutdownRequested shutdownRequested) { + + } +} diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ShardConsumer.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ShardConsumer.java index d5e30b76..6fcd2a82 100644 --- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ShardConsumer.java +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ShardConsumer.java @@ -16,13 +16,20 @@ package software.amazon.kinesis.lifecycle; import java.util.Optional; +import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import java.util.concurrent.Future; import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.ThreadPoolExecutor; import com.amazonaws.services.kinesis.clientlibrary.exceptions.internal.BlockedOnParentShardException; import 
software.amazon.kinesis.coordinator.KinesisClientLibConfiguration; import software.amazon.kinesis.checkpoint.Checkpoint; +import software.amazon.kinesis.lifecycle.events.LeaseLost; +import software.amazon.kinesis.lifecycle.events.ShardCompleted; +import software.amazon.kinesis.lifecycle.events.ShutdownRequested; +import software.amazon.kinesis.lifecycle.events.Started; import software.amazon.kinesis.metrics.MetricsCollectingTaskDecorator; import software.amazon.kinesis.coordinator.RecordProcessorCheckpointer; import software.amazon.kinesis.leases.ShardInfo; @@ -48,9 +55,11 @@ import software.amazon.kinesis.retrieval.SynchronousGetRecordsRetrievalStrategy; * A new instance should be created if the primary responsibility is reassigned back to this process. */ @Slf4j -public class ShardConsumer { +public class ShardConsumer implements RecordProcessorLifecycle { + // private final StreamConfig streamConfig; private final IRecordProcessor recordProcessor; + private RecordProcessorLifecycle recordProcessorLifecycle; private final KinesisClientLibConfiguration config; private final RecordProcessorCheckpointer recordProcessorCheckpointer; private final ExecutorService executorService; @@ -68,7 +77,9 @@ public class ShardConsumer { private ITask currentTask; private long currentTaskSubmitTime; private Future future; - + // + + // @Getter private final GetRecordsCache getRecordsCache; @@ -82,6 +93,7 @@ public class ShardConsumer { return getRecordsRetrievalStrategy.orElse(new SynchronousGetRecordsRetrievalStrategy(dataFetcher)); } + // /* * Tracks current state. It is only updated via the consumeStream/shutdown APIs. 
Therefore we don't do @@ -95,6 +107,7 @@ public class ShardConsumer { private volatile ShutdownReason shutdownReason; private volatile ShutdownNotification shutdownNotification; + // /** * @param shardInfo Shard information * @param streamConfig Stream configuration to use @@ -245,7 +258,9 @@ public class ShardConsumer { makeStrategy(this.dataFetcher, retryGetRecordsInSeconds, maxGetRecordsThreadPool, this.shardInfo), this.getShardInfo().getShardId(), this.metricsFactory, this.config.getMaxRecords()); } + // + // /** * No-op if current task is pending, otherwise submits next task for this shard. * This method should NOT be called if the ShardConsumer is already in SHUTDOWN_COMPLETED state. @@ -345,57 +360,9 @@ public class ShardConsumer { } } - /** - * Requests the shutdown of the this ShardConsumer. This should give the record processor a chance to checkpoint - * before being shutdown. - * - * @param shutdownNotification used to signal that the record processor has been given the chance to shutdown. - */ - public void notifyShutdownRequested(ShutdownNotification shutdownNotification) { - this.shutdownNotification = shutdownNotification; - markForShutdown(ShutdownReason.REQUESTED); - } - - /** - * Shutdown this ShardConsumer (including invoking the RecordProcessor shutdown API). - * This is called by Worker when it loses responsibility for a shard. 
- * - * @return true if shutdown is complete (false if shutdown is still in progress) - */ - public synchronized boolean beginShutdown() { - markForShutdown(ShutdownReason.ZOMBIE); - checkAndSubmitNextTask(); - - return isShutdown(); - } - - synchronized void markForShutdown(ShutdownReason reason) { - // ShutdownReason.ZOMBIE takes precedence over TERMINATE (we won't be able to save checkpoint at end of shard) - if (shutdownReason == null || shutdownReason.canTransitionTo(reason)) { - shutdownReason = reason; - } - } - - /** - * Used (by Worker) to check if this ShardConsumer instance has been shutdown - * RecordProcessor shutdown() has been invoked, as appropriate. - * - * @return true if shutdown is complete - */ - public boolean isShutdown() { - return currentState.isTerminal(); - } - - /** - * @return the shutdownReason - */ - public ShutdownReason getShutdownReason() { - return shutdownReason; - } - /** * Figure out next task to run based on current state, task, and shutdown context. - * + * * @return Return next task to run */ private ITask getNextTask() { @@ -411,7 +378,7 @@ public class ShardConsumer { /** * Note: This is a private/internal method with package level access solely for testing purposes. * Update state based on information about: task success, current state, and shutdown info. - * + * * @param taskOutcome The outcome of the last task */ void updateState(TaskOutcome taskOutcome) { @@ -435,11 +402,65 @@ public class ShardConsumer { } + // + + // + /** + * Requests the shutdown of the this ShardConsumer. This should give the record processor a chance to checkpoint + * before being shutdown. + * + * @param shutdownNotification used to signal that the record processor has been given the chance to shutdown. 
+ */ + public void notifyShutdownRequested(ShutdownNotification shutdownNotification) { + this.shutdownNotification = shutdownNotification; + markForShutdown(ShutdownReason.REQUESTED); + } + + /** + * Shutdown this ShardConsumer (including invoking the RecordProcessor shutdown API). + * This is called by Worker when it loses responsibility for a shard. + * + * @return true if shutdown is complete (false if shutdown is still in progress) + */ + public synchronized boolean beginShutdown() { + markForShutdown(ShutdownReason.ZOMBIE); + checkAndSubmitNextTask(); + + return isShutdown(); + } + + synchronized void markForShutdown(ShutdownReason reason) { + // ShutdownReason.ZOMBIE takes precedence over TERMINATE (we won't be able to save checkpoint at end of shard) + if (shutdownReason == null || shutdownReason.canTransitionTo(reason)) { + shutdownReason = reason; + } + } + + /** + * Used (by Worker) to check if this ShardConsumer instance has been shutdown + * RecordProcessor shutdown() has been invoked, as appropriate. + * + * @return true if shutdown is complete + */ + public boolean isShutdown() { + return currentState.isTerminal(); + } + + /** + * @return the shutdownReason + */ + public ShutdownReason getShutdownReason() { + return shutdownReason; + } + @VisibleForTesting public boolean isShutdownRequested() { return shutdownReason != null; } + // + + // /** * Private/Internal method - has package level access solely for testing purposes. 
 * @@ -504,4 +525,46 @@ ShutdownNotification getShutdownNotification() { return shutdownNotification; } + // + + + ExecutorService executor = Executors.newSingleThreadExecutor(); + Future taskResult = null; + + // + @Override + public void started(Started started) { + if (taskResult != null) { + try { + taskResult.get(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } catch (ExecutionException e) { + log.error("Previous lifecycle task failed.", e); + } + } + + taskResult = executor.submit(() -> recordProcessorLifecycle.started(started)); + } + + @Override + public void recordsReceived(software.amazon.kinesis.lifecycle.events.RecordsReceived records) { + + } + + @Override + public void leaseLost(LeaseLost leaseLost) { + + } + + @Override + public void shardCompleted(ShardCompleted shardCompletedInput) { + + } + + @Override + public void shutdownRequested(ShutdownRequested shutdownRequested) { + + } + // } diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/events/LeaseLost.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/events/LeaseLost.java new file mode 100644 index 00000000..912f2966 --- /dev/null +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/events/LeaseLost.java @@ -0,0 +1,18 @@ +/* + * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Amazon Software License (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/asl/ + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. 
+ */ +package software.amazon.kinesis.lifecycle.events; + +public class LeaseLost { +} diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/events/RecordsReceived.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/events/RecordsReceived.java new file mode 100644 index 00000000..15dc0cc6 --- /dev/null +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/events/RecordsReceived.java @@ -0,0 +1,22 @@ +/* + * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Amazon Software License (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/asl/ + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ +package software.amazon.kinesis.lifecycle.events; + +import lombok.Data; + +@Data +public class RecordsReceived { + +} diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/events/ShardCompleted.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/events/ShardCompleted.java new file mode 100644 index 00000000..1df45a56 --- /dev/null +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/events/ShardCompleted.java @@ -0,0 +1,18 @@ +/* + * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Amazon Software License (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/asl/ + * + * or in the "license" file accompanying this file. 
This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ +package software.amazon.kinesis.lifecycle.events; + +public class ShardCompleted { +} diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/events/ShutdownRequested.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/events/ShutdownRequested.java new file mode 100644 index 00000000..aa9074bd --- /dev/null +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/events/ShutdownRequested.java @@ -0,0 +1,18 @@ +/* + * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Amazon Software License (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/asl/ + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ +package software.amazon.kinesis.lifecycle.events; + +public class ShutdownRequested { +} diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/events/Started.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/events/Started.java new file mode 100644 index 00000000..80943ad4 --- /dev/null +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/events/Started.java @@ -0,0 +1,33 @@ +/* + * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Amazon Software License (the "License"). + * You may not use this file except in compliance with the License. 
+ * A copy of the License is located at + * + * http://aws.amazon.com/asl/ + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ +package software.amazon.kinesis.lifecycle.events; + +import lombok.Data; +import software.amazon.kinesis.lifecycle.InitializationInput; +import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber; + +@Data +public class Started { + + private final String shardId; + private final ExtendedSequenceNumber sequenceNumber; + private final ExtendedSequenceNumber pendingSequenceNumber; + + public InitializationInput toInitializationInput() { + return new InitializationInput().withShardId(shardId).withExtendedSequenceNumber(sequenceNumber) + .withExtendedSequenceNumber(sequenceNumber); + } + +} diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ProcessTaskTest.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ProcessTaskTest.java index 4a97d347..4d6233a8 100644 --- a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ProcessTaskTest.java +++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ProcessTaskTest.java @@ -1,26 +1,22 @@ /* - * Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Amazon Software License (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://aws.amazon.com/asl/ - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. 
+ * Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. Licensed under the Amazon Software License + * (the "License"). You may not use this file except in compliance with the License. A copy of the License is located at + * http://aws.amazon.com/asl/ or in the "license" file accompanying this file. This file is distributed on an "AS IS" + * BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific + * language governing permissions and limitations under the License. */ package software.amazon.kinesis.lifecycle; +import static org.hamcrest.CoreMatchers.allOf; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.instanceOf; +import static org.hamcrest.CoreMatchers.notNullValue; +import static org.hamcrest.CoreMatchers.nullValue; +import static org.hamcrest.beans.HasPropertyWithValue.hasProperty; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; -import static org.junit.Assert.assertTrue; -import static org.mockito.Matchers.any; -import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.doThrow; +import static org.junit.Assert.assertThat; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -29,7 +25,6 @@ import java.math.BigInteger; import java.nio.ByteBuffer; import java.security.MessageDigest; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collections; import java.util.Date; import java.util.List; @@ -37,34 +32,47 @@ import java.util.Random; import java.util.UUID; import java.util.concurrent.TimeUnit; -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStreamExtended; -import software.amazon.kinesis.coordinator.KinesisClientLibConfiguration; 
-import software.amazon.kinesis.coordinator.RecordProcessorCheckpointer; -import software.amazon.kinesis.leases.ShardInfo; -import software.amazon.kinesis.coordinator.StreamConfig; -import software.amazon.kinesis.retrieval.ThrottlingReporter; +import org.hamcrest.Description; +import org.hamcrest.Matcher; +import org.hamcrest.TypeSafeDiagnosingMatcher; import org.junit.Before; import org.junit.Test; +import org.junit.runner.RunWith; import org.mockito.ArgumentCaptor; import org.mockito.Mock; -import org.mockito.MockitoAnnotations; +import org.mockito.runners.MockitoJUnitRunner; +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStreamExtended; +import com.amazonaws.services.kinesis.model.Record; +import com.google.protobuf.ByteString; + +import lombok.Data; +import software.amazon.kinesis.coordinator.KinesisClientLibConfiguration; +import software.amazon.kinesis.coordinator.RecordProcessorCheckpointer; +import software.amazon.kinesis.coordinator.StreamConfig; +import software.amazon.kinesis.leases.ShardInfo; import software.amazon.kinesis.processor.IRecordProcessor; import software.amazon.kinesis.retrieval.GetRecordsCache; import software.amazon.kinesis.retrieval.KinesisDataFetcher; +import software.amazon.kinesis.retrieval.ThrottlingReporter; import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber; import software.amazon.kinesis.retrieval.kpl.Messages; import software.amazon.kinesis.retrieval.kpl.Messages.AggregatedRecord; import software.amazon.kinesis.retrieval.kpl.UserRecord; -import com.amazonaws.services.kinesis.model.ProvisionedThroughputExceededException; -import com.amazonaws.services.kinesis.model.Record; -import com.google.protobuf.ByteString; +@RunWith(MockitoJUnitRunner.class) public class ProcessTaskTest { + private StreamConfig config; + private ShardInfo shardInfo; + + @Mock + private ProcessRecordsInput 
processRecordsInput; + @SuppressWarnings("serial") - private static class RecordSubclass extends Record {} + private static class RecordSubclass extends Record { + } private static final byte[] TEST_DATA = new byte[] { 1, 2, 3, 4 }; @@ -75,78 +83,45 @@ public class ProcessTaskTest { private final boolean callProcessRecordsForEmptyRecordList = true; // We don't want any of these tests to run checkpoint validation private final boolean skipCheckpointValidationValue = false; - private static final InitialPositionInStreamExtended INITIAL_POSITION_LATEST = - InitialPositionInStreamExtended.newInitialPosition(InitialPositionInStream.LATEST); + private static final InitialPositionInStreamExtended INITIAL_POSITION_LATEST = InitialPositionInStreamExtended + .newInitialPosition(InitialPositionInStream.LATEST); - private @Mock - KinesisDataFetcher mockDataFetcher; - private @Mock IRecordProcessor mockRecordProcessor; - private @Mock - RecordProcessorCheckpointer mockCheckpointer; + @Mock + private KinesisDataFetcher mockDataFetcher; + @Mock + private IRecordProcessor mockRecordProcessor; + @Mock + private RecordProcessorCheckpointer mockCheckpointer; @Mock private ThrottlingReporter throttlingReporter; @Mock private GetRecordsCache getRecordsCache; - private List processedRecords; - private ExtendedSequenceNumber newLargestPermittedCheckpointValue; - private ProcessTask processTask; @Before public void setUpProcessTask() { - // Initialize the annotation - MockitoAnnotations.initMocks(this); // Set up process task - final StreamConfig config = - new StreamConfig(null, maxRecords, idleTimeMillis, callProcessRecordsForEmptyRecordList, - skipCheckpointValidationValue, - INITIAL_POSITION_LATEST); - final ShardInfo shardInfo = new ShardInfo(shardId, null, null, null); - processTask = new ProcessTask( - shardInfo, - config, - mockRecordProcessor, - mockCheckpointer, - mockDataFetcher, - taskBackoffTimeMillis, - 
KinesisClientLibConfiguration.DEFAULT_SKIP_SHARD_SYNC_AT_STARTUP_IF_LEASES_EXIST, - throttlingReporter, - getRecordsCache); + config = new StreamConfig(null, maxRecords, idleTimeMillis, callProcessRecordsForEmptyRecordList, + skipCheckpointValidationValue, INITIAL_POSITION_LATEST); + shardInfo = new ShardInfo(shardId, null, null, null); + } - @Test - public void testProcessTaskWithProvisionedThroughputExceededException() { - // Set data fetcher to throw exception - doReturn(false).when(mockDataFetcher).isShardEndReached(); - doThrow(new ProvisionedThroughputExceededException("Test Exception")).when(getRecordsCache) - .getNextResult(); - - TaskResult result = processTask.call(); - verify(throttlingReporter).throttled(); - verify(throttlingReporter, never()).success(); - verify(getRecordsCache).getNextResult(); - assertTrue("Result should contain ProvisionedThroughputExceededException", - result.getException() instanceof ProvisionedThroughputExceededException); - } - - @Test - public void testProcessTaskWithNonExistentStream() { - // Data fetcher returns a null Result ` the stream does not exist - doReturn(new ProcessRecordsInput().withRecords(Collections.emptyList()).withMillisBehindLatest((long) 0)).when(getRecordsCache).getNextResult(); - - TaskResult result = processTask.call(); - verify(getRecordsCache).getNextResult(); - assertNull("Task should not throw an exception", result.getException()); + private ProcessTask makeProcessTask(ProcessRecordsInput processRecordsInput) { + return new ProcessTask(shardInfo, config, mockRecordProcessor, mockCheckpointer, taskBackoffTimeMillis, + KinesisClientLibConfiguration.DEFAULT_SKIP_SHARD_SYNC_AT_STARTUP_IF_LEASES_EXIST, throttlingReporter, + processRecordsInput); } @Test public void testProcessTaskWithShardEndReached() { - // Set data fetcher to return true for shard end reached - doReturn(true).when(mockDataFetcher).isShardEndReached(); + + processTask = makeProcessTask(processRecordsInput); + 
when(processRecordsInput.isAtShardEnd()).thenReturn(true); TaskResult result = processTask.call(); - assertTrue("Result should contain shardEndReached true", result.isShardEndReached()); + assertThat(result, shardEndTaskResult(true)); } @Test @@ -154,41 +129,42 @@ public class ProcessTaskTest { final String sqn = new BigInteger(128, new Random()).toString(); final String pk = UUID.randomUUID().toString(); final Date ts = new Date(System.currentTimeMillis() - TimeUnit.MILLISECONDS.convert(4, TimeUnit.HOURS)); - final Record r = new Record() - .withPartitionKey(pk) - .withData(ByteBuffer.wrap(TEST_DATA)) - .withSequenceNumber(sqn) + final Record r = new Record().withPartitionKey(pk).withData(ByteBuffer.wrap(TEST_DATA)).withSequenceNumber(sqn) .withApproximateArrivalTimestamp(ts); - testWithRecord(r); + RecordProcessorOutcome outcome = testWithRecord(r); - assertEquals(1, processedRecords.size()); + assertEquals(1, outcome.getProcessRecordsCall().getRecords().size()); - Record pr = processedRecords.get(0); + Record pr = outcome.getProcessRecordsCall().getRecords().get(0); assertEquals(pk, pr.getPartitionKey()); assertEquals(ts, pr.getApproximateArrivalTimestamp()); - byte[] b = new byte[pr.getData().remaining()]; - pr.getData().get(b); - assertTrue(Arrays.equals(TEST_DATA, b)); + byte[] b = pr.getData().array(); + assertThat(b, equalTo(TEST_DATA)); - assertEquals(sqn, newLargestPermittedCheckpointValue.getSequenceNumber()); - assertEquals(0, newLargestPermittedCheckpointValue.getSubSequenceNumber()); + assertEquals(sqn, outcome.getCheckpointCall().getSequenceNumber()); + assertEquals(0, outcome.getCheckpointCall().getSubSequenceNumber()); + } + + @Data + static class RecordProcessorOutcome { + final ProcessRecordsInput processRecordsCall; + final ExtendedSequenceNumber checkpointCall; } @Test public void testDoesNotDeaggregateSubclassOfRecord() { final String sqn = new BigInteger(128, new Random()).toString(); - final Record r = new RecordSubclass() - 
.withSequenceNumber(sqn) - .withData(ByteBuffer.wrap(new byte[0])); + final Record r = new RecordSubclass().withSequenceNumber(sqn).withData(ByteBuffer.wrap(new byte[0])); - testWithRecord(r); + processTask = makeProcessTask(processRecordsInput); + RecordProcessorOutcome outcome = testWithRecord(r); - assertEquals(1, processedRecords.size(), 1); - assertSame(r, processedRecords.get(0)); + assertEquals(1, outcome.getProcessRecordsCall().getRecords().size(), 1); + assertSame(r, outcome.getProcessRecordsCall().getRecords().get(0)); - assertEquals(sqn, newLargestPermittedCheckpointValue.getSequenceNumber()); - assertEquals(0, newLargestPermittedCheckpointValue.getSubSequenceNumber()); + assertEquals(sqn, outcome.getCheckpointCall().getSequenceNumber()); + assertEquals(0, outcome.getCheckpointCall().getSubSequenceNumber()); } @Test @@ -196,44 +172,44 @@ public class ProcessTaskTest { final String sqn = new BigInteger(128, new Random()).toString(); final String pk = UUID.randomUUID().toString(); final Date ts = new Date(System.currentTimeMillis() - TimeUnit.MILLISECONDS.convert(4, TimeUnit.HOURS)); - final Record r = new Record() - .withPartitionKey("-") - .withData(generateAggregatedRecord(pk)) - .withSequenceNumber(sqn) - .withApproximateArrivalTimestamp(ts); + final Record r = new Record().withPartitionKey("-").withData(generateAggregatedRecord(pk)) + .withSequenceNumber(sqn).withApproximateArrivalTimestamp(ts); - testWithRecord(r); + processTask = makeProcessTask(processRecordsInput); + RecordProcessorOutcome outcome = testWithRecord(r); - assertEquals(3, processedRecords.size()); - for (Record pr : processedRecords) { - assertTrue(pr instanceof UserRecord); + List actualRecords = outcome.getProcessRecordsCall().getRecords(); + + assertEquals(3, actualRecords.size()); + for (Record pr : actualRecords) { + assertThat(pr, instanceOf(UserRecord.class)); assertEquals(pk, pr.getPartitionKey()); assertEquals(ts, pr.getApproximateArrivalTimestamp()); - byte[] b = new 
byte[pr.getData().remaining()]; - pr.getData().get(b); - assertTrue(Arrays.equals(TEST_DATA, b)); + byte[] b = pr.getData().array(); + assertThat(b, equalTo(TEST_DATA)); } - assertEquals(sqn, newLargestPermittedCheckpointValue.getSequenceNumber()); - assertEquals(processedRecords.size() - 1, newLargestPermittedCheckpointValue.getSubSequenceNumber()); + assertEquals(sqn, outcome.getCheckpointCall().getSequenceNumber()); + assertEquals(actualRecords.size() - 1, outcome.getCheckpointCall().getSubSequenceNumber()); } @Test public void testDeaggregatesRecordWithNoArrivalTimestamp() { final String sqn = new BigInteger(128, new Random()).toString(); final String pk = UUID.randomUUID().toString(); - final Record r = new Record() - .withPartitionKey("-") - .withData(generateAggregatedRecord(pk)) + final Record r = new Record().withPartitionKey("-").withData(generateAggregatedRecord(pk)) .withSequenceNumber(sqn); - testWithRecord(r); + processTask = makeProcessTask(processRecordsInput); + RecordProcessorOutcome outcome = testWithRecord(r); - assertEquals(3, processedRecords.size()); - for (Record pr : processedRecords) { - assertTrue(pr instanceof UserRecord); + List actualRecords = outcome.getProcessRecordsCall().getRecords(); + + assertEquals(3, actualRecords.size()); + for (Record pr : actualRecords) { + assertThat(pr, instanceOf(UserRecord.class)); assertEquals(pk, pr.getPartitionKey()); - assertNull(pr.getApproximateArrivalTimestamp()); + assertThat(pr.getApproximateArrivalTimestamp(), nullValue()); } } @@ -246,15 +222,17 @@ public class ProcessTaskTest { final int numberOfRecords = 104; // Start these batch of records's sequence number that is greater than previous checkpoint value. 
final BigInteger startingSqn = previousCheckpointSqn.add(BigInteger.valueOf(10)); - final List records = generateConsecutiveRecords( - numberOfRecords, "-", ByteBuffer.wrap(TEST_DATA), new Date(), startingSqn); + final List records = generateConsecutiveRecords(numberOfRecords, "-", ByteBuffer.wrap(TEST_DATA), + new Date(), startingSqn); - testWithRecords(records, new ExtendedSequenceNumber(previousCheckpointSqn.toString()), + processTask = makeProcessTask(processRecordsInput); + RecordProcessorOutcome outcome = testWithRecords(records, + new ExtendedSequenceNumber(previousCheckpointSqn.toString()), new ExtendedSequenceNumber(previousCheckpointSqn.toString())); final ExtendedSequenceNumber expectedLargestPermittedEsqn = new ExtendedSequenceNumber( startingSqn.add(BigInteger.valueOf(numberOfRecords - 1)).toString()); - assertEquals(expectedLargestPermittedEsqn, newLargestPermittedCheckpointValue); + assertEquals(expectedLargestPermittedEsqn, outcome.getCheckpointCall()); } @Test @@ -265,17 +243,19 @@ public class ProcessTaskTest { final ExtendedSequenceNumber largestPermittedEsqn = new ExtendedSequenceNumber( baseSqn.add(BigInteger.valueOf(100)).toString()); - testWithRecords(Collections.emptyList(), lastCheckpointEspn, largestPermittedEsqn); + processTask = makeProcessTask(processRecordsInput); + RecordProcessorOutcome outcome = testWithRecords(Collections.emptyList(), lastCheckpointEspn, + largestPermittedEsqn); // Make sure that even with empty records, largest permitted sequence number does not change. - assertEquals(largestPermittedEsqn, newLargestPermittedCheckpointValue); + assertEquals(largestPermittedEsqn, outcome.getCheckpointCall()); } @Test public void testFilterBasedOnLastCheckpointValue() { // Explanation of setup: // * Assume in previous processRecord call, user got 3 sub-records that all belonged to one - // Kinesis record. So sequence number was X, and sub-sequence numbers were 0, 1, 2. + // Kinesis record. 
So sequence number was X, and sub-sequence numbers were 0, 1, 2. // * 2nd sub-record was checkpointed (extended sequnce number X.1). // * Worker crashed and restarted. So now DDB has checkpoint value of X.1. // Test: @@ -286,21 +266,22 @@ public class ProcessTaskTest { // Values for this processRecords call. final String startingSqn = previousCheckpointSqn.toString(); final String pk = UUID.randomUUID().toString(); - final Record r = new Record() - .withPartitionKey("-") - .withData(generateAggregatedRecord(pk)) + final Record r = new Record().withPartitionKey("-").withData(generateAggregatedRecord(pk)) .withSequenceNumber(startingSqn); - testWithRecords(Collections.singletonList(r), + processTask = makeProcessTask(processRecordsInput); + RecordProcessorOutcome outcome = testWithRecords(Collections.singletonList(r), new ExtendedSequenceNumber(previousCheckpointSqn.toString(), previousCheckpointSsqn), new ExtendedSequenceNumber(previousCheckpointSqn.toString(), previousCheckpointSsqn)); + List actualRecords = outcome.getProcessRecordsCall().getRecords(); + // First two records should be dropped - and only 1 remaining records should be there. - assertEquals(1, processedRecords.size()); - assertTrue(processedRecords.get(0) instanceof UserRecord); + assertEquals(1, actualRecords.size()); + assertThat(actualRecords.get(0), instanceOf(UserRecord.class)); // Verify user record's extended sequence number and other fields. - final UserRecord pr = (UserRecord)processedRecords.get(0); + final UserRecord pr = (UserRecord) actualRecords.get(0); assertEquals(pk, pr.getPartitionKey()); assertEquals(startingSqn, pr.getSequenceNumber()); assertEquals(previousCheckpointSsqn + 1, pr.getSubSequenceNumber()); @@ -309,60 +290,50 @@ public class ProcessTaskTest { // Expected largest permitted sequence number will be last sub-record sequence number. 
final ExtendedSequenceNumber expectedLargestPermittedEsqn = new ExtendedSequenceNumber( previousCheckpointSqn.toString(), 2L); - assertEquals(expectedLargestPermittedEsqn, newLargestPermittedCheckpointValue); + assertEquals(expectedLargestPermittedEsqn, outcome.getCheckpointCall()); } - private void testWithRecord(Record record) { - testWithRecords(Collections.singletonList(record), - ExtendedSequenceNumber.TRIM_HORIZON, ExtendedSequenceNumber.TRIM_HORIZON); + private RecordProcessorOutcome testWithRecord(Record record) { + return testWithRecords(Collections.singletonList(record), ExtendedSequenceNumber.TRIM_HORIZON, + ExtendedSequenceNumber.TRIM_HORIZON); } - private void testWithRecords(List records, - ExtendedSequenceNumber lastCheckpointValue, + private RecordProcessorOutcome testWithRecords(List records, ExtendedSequenceNumber lastCheckpointValue, ExtendedSequenceNumber largestPermittedCheckpointValue) { - when(getRecordsCache.getNextResult()).thenReturn(new ProcessRecordsInput().withRecords(records).withMillisBehindLatest((long) 1000 * 50)); when(mockCheckpointer.getLastCheckpointValue()).thenReturn(lastCheckpointValue); when(mockCheckpointer.getLargestPermittedCheckpointValue()).thenReturn(largestPermittedCheckpointValue); + when(processRecordsInput.getRecords()).thenReturn(records); + processTask = makeProcessTask(processRecordsInput); processTask.call(); verify(throttlingReporter).success(); verify(throttlingReporter, never()).throttled(); - verify(getRecordsCache).getNextResult(); - ArgumentCaptor priCaptor = ArgumentCaptor.forClass(ProcessRecordsInput.class); - verify(mockRecordProcessor).processRecords(priCaptor.capture()); - processedRecords = priCaptor.getValue().getRecords(); + ArgumentCaptor recordsCaptor = ArgumentCaptor.forClass(ProcessRecordsInput.class); + verify(mockRecordProcessor).processRecords(recordsCaptor.capture()); ArgumentCaptor esnCaptor = ArgumentCaptor.forClass(ExtendedSequenceNumber.class); 
verify(mockCheckpointer).setLargestPermittedCheckpointValue(esnCaptor.capture()); - newLargestPermittedCheckpointValue = esnCaptor.getValue(); + + return new RecordProcessorOutcome(recordsCaptor.getValue(), esnCaptor.getValue()); + } /** - * See the KPL documentation on GitHub for more details about the binary - * format. + * See the KPL documentation on GitHub for more details about the binary format. * * @param pk - * Partition key to use. All the records will have the same - * partition key. - * @return ByteBuffer containing the serialized form of the aggregated - * record, along with the necessary header and footer. + * Partition key to use. All the records will have the same partition key. + * @return ByteBuffer containing the serialized form of the aggregated record, along with the necessary header and + * footer. */ private static ByteBuffer generateAggregatedRecord(String pk) { ByteBuffer bb = ByteBuffer.allocate(1024); - bb.put(new byte[] {-13, -119, -102, -62 }); + bb.put(new byte[] { -13, -119, -102, -62 }); - Messages.Record r = - Messages.Record.newBuilder() - .setData(ByteString.copyFrom(TEST_DATA)) - .setPartitionKeyIndex(0) - .build(); + Messages.Record r = Messages.Record.newBuilder().setData(ByteString.copyFrom(TEST_DATA)).setPartitionKeyIndex(0) + .build(); - byte[] payload = AggregatedRecord.newBuilder() - .addPartitionKeyTable(pk) - .addRecords(r) - .addRecords(r) - .addRecords(r) - .build() - .toByteArray(); + byte[] payload = AggregatedRecord.newBuilder().addPartitionKeyTable(pk).addRecords(r).addRecords(r) + .addRecords(r).build().toByteArray(); bb.put(payload); bb.put(md5(payload)); @@ -371,16 +342,13 @@ public class ProcessTaskTest { return bb; } - private static List generateConsecutiveRecords( - int numberOfRecords, String partitionKey, ByteBuffer data, + private static List generateConsecutiveRecords(int numberOfRecords, String partitionKey, ByteBuffer data, Date arrivalTimestamp, BigInteger startSequenceNumber) { List records = new 
ArrayList<>(); - for (int i = 0 ; i < numberOfRecords ; ++i) { - records.add(new Record() - .withPartitionKey(partitionKey) - .withData(data) - .withSequenceNumber(startSequenceNumber.add(BigInteger.valueOf(i)).toString()) - .withApproximateArrivalTimestamp(arrivalTimestamp)); + for (int i = 0; i < numberOfRecords; ++i) { + records.add(new Record().withPartitionKey(partitionKey).withData(data) + .withSequenceNumber(startSequenceNumber.add(BigInteger.valueOf(i)).toString()) + .withApproximateArrivalTimestamp(arrivalTimestamp)); } return records; } @@ -393,4 +361,48 @@ public class ProcessTaskTest { throw new RuntimeException(e); } } + + private static TaskResultMatcher shardEndTaskResult(boolean isAtShardEnd) { + TaskResult expected = new TaskResult(null, isAtShardEnd); + return taskResult(expected); + } + + private static TaskResultMatcher exceptionTaskResult(Exception ex) { + TaskResult expected = new TaskResult(ex, false); + return taskResult(expected); + } + + private static TaskResultMatcher taskResult(TaskResult expected) { + return new TaskResultMatcher(expected); + } + + private static class TaskResultMatcher extends TypeSafeDiagnosingMatcher { + + Matcher matchers; + + TaskResultMatcher(TaskResult expected) { + if (expected == null) { + matchers = nullValue(TaskResult.class); + } else { + matchers = allOf(notNullValue(TaskResult.class), + hasProperty("shardEndReached", equalTo(expected.isShardEndReached())), + hasProperty("exception", equalTo(expected.getException()))); + } + + } + + @Override + protected boolean matchesSafely(TaskResult item, Description mismatchDescription) { + if (!matchers.matches(item)) { + matchers.describeMismatch(item, mismatchDescription); + return false; + } + return true; + } + + @Override + public void describeTo(Description description) { + description.appendDescriptionOf(matchers); + } + } }