diff --git a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/AsynchronousGetRecordsRetrievalStrategy.java b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/AsynchronousGetRecordsRetrievalStrategy.java new file mode 100644 index 00000000..6290dd4f --- /dev/null +++ b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/AsynchronousGetRecordsRetrievalStrategy.java @@ -0,0 +1,126 @@ +package com.amazonaws.services.kinesis.clientlibrary.lib.worker; + +import java.util.HashSet; +import java.util.Set; +import java.util.concurrent.Callable; +import java.util.concurrent.CompletionService; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorCompletionService; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; + +import com.amazonaws.services.kinesis.metrics.impl.MetricsHelper; +import com.amazonaws.services.kinesis.metrics.impl.ThreadSafeMetricsDelegatingScope; +import com.amazonaws.services.kinesis.model.GetRecordsResult; +import com.google.common.util.concurrent.ThreadFactoryBuilder; + +import lombok.NonNull; +import lombok.extern.apachecommons.CommonsLog; + +/** + * + */ +@CommonsLog +public class AsynchronousGetRecordsRetrievalStrategy implements GetRecordsRetrievalStrategy { + private static final int TIME_TO_KEEP_ALIVE = 5; + private static final int CORE_THREAD_POOL_COUNT = 1; + + private final KinesisDataFetcher dataFetcher; + private final ExecutorService executorService; + private final int retryGetRecordsInSeconds; + private final String shardId; + final CompletionService completionService; + + public AsynchronousGetRecordsRetrievalStrategy(@NonNull final KinesisDataFetcher dataFetcher, + final int retryGetRecordsInSeconds, final int maxGetRecordsThreadPool, String shardId) { + this(dataFetcher, buildExector(maxGetRecordsThreadPool, shardId), retryGetRecordsInSeconds, shardId); + } + + public AsynchronousGetRecordsRetrievalStrategy(final KinesisDataFetcher dataFetcher, + final ExecutorService executorService, final int retryGetRecordsInSeconds, String shardId) { + this(dataFetcher, executorService, retryGetRecordsInSeconds, new ExecutorCompletionService<>(executorService), + shardId); + } + + AsynchronousGetRecordsRetrievalStrategy(KinesisDataFetcher dataFetcher, ExecutorService executorService, + int retryGetRecordsInSeconds, CompletionService completionService, String shardId) { + this.dataFetcher = dataFetcher; + this.executorService = executorService; + this.retryGetRecordsInSeconds = retryGetRecordsInSeconds; + this.completionService = completionService; + this.shardId = shardId; + } + + @Override + public GetRecordsResult getRecords(final int maxRecords) { + if (executorService.isShutdown()) { + throw new IllegalStateException("Strategy has been shutdown"); + } + GetRecordsResult result = null; + Set> futures = new HashSet<>(); + Callable retrieverCall = createRetrieverCallable(maxRecords); + while (true) { + try { + futures.add(completionService.submit(retrieverCall)); + } catch (RejectedExecutionException e) { + log.warn("Out of resources, unable to start additional requests."); + } + + try { + Future resultFuture = completionService.poll(retryGetRecordsInSeconds, + TimeUnit.SECONDS); + if (resultFuture != null) { + result = resultFuture.get(); + break; + } + } catch (ExecutionException e) { + log.error("ExecutionException thrown while trying to get records", e); + } catch (InterruptedException e) { + log.error("Thread was interrupted", e); + break; + } + } + futures.stream().peek(f -> f.cancel(true)).filter(Future::isCancelled).forEach(f -> { + try { + completionService.take(); + } catch (InterruptedException e) { + log.error("Exception thrown while trying to empty the threadpool."); + } + }); + return result; + } + + private Callable createRetrieverCallable(int maxRecords) { + ThreadSafeMetricsDelegatingScope metricsScope = new ThreadSafeMetricsDelegatingScope(MetricsHelper.getMetricsScope()); + return () -> { + try { + MetricsHelper.setMetricsScope(metricsScope); + return dataFetcher.getRecords(maxRecords); + } finally { + MetricsHelper.unsetMetricsScope(); + } + }; + } + + @Override + public void shutdown() { + executorService.shutdownNow(); + } + + @Override + public boolean isShutdown() { + return executorService.isShutdown(); + } + + private static ExecutorService buildExector(int maxGetRecordsThreadPool, String shardId) { + String threadNameFormat = "get-records-worker-" + shardId + "-%d"; + return new ThreadPoolExecutor(CORE_THREAD_POOL_COUNT, maxGetRecordsThreadPool, TIME_TO_KEEP_ALIVE, + TimeUnit.SECONDS, new LinkedBlockingQueue<>(1), + new ThreadFactoryBuilder().setDaemon(true).setNameFormat(threadNameFormat).build(), + new ThreadPoolExecutor.AbortPolicy()); + } +} diff --git a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ConsumerStates.java b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ConsumerStates.java index d967b2c3..f6d96b4d 100644 --- a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ConsumerStates.java +++ b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ConsumerStates.java @@ -14,6 +14,8 @@ */ package com.amazonaws.services.kinesis.clientlibrary.lib.worker; +import java.util.Optional; + /** * Top level container for all the possible states a {@link ShardConsumer} can be in. The logic for creation of tasks, * and state transitions is contained within the {@link ConsumerState} objects. @@ -309,7 +311,8 @@ class ConsumerStates { public ITask createTask(ShardConsumer consumer) { return new ProcessTask(consumer.getShardInfo(), consumer.getStreamConfig(), consumer.getRecordProcessor(), consumer.getRecordProcessorCheckpointer(), consumer.getDataFetcher(), - consumer.getTaskBackoffTimeMillis(), consumer.isSkipShardSyncAtWorkerInitializationIfLeasesExist()); + consumer.getTaskBackoffTimeMillis(), consumer.isSkipShardSyncAtWorkerInitializationIfLeasesExist(), + consumer.getRetryGetRecordsInSeconds(), consumer.getMaxGetRecordsThreadPool()); } @Override diff --git a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/GetRecordsRetrievalStrategy.java b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/GetRecordsRetrievalStrategy.java new file mode 100644 index 00000000..a391ac59 --- /dev/null +++ b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/GetRecordsRetrievalStrategy.java @@ -0,0 +1,33 @@ +package com.amazonaws.services.kinesis.clientlibrary.lib.worker; + +import com.amazonaws.services.kinesis.model.GetRecordsResult; + +/** + * Represents a strategy to retrieve records from Kinesis. Allows for variations on how records are retrieved from + * Kinesis. + */ +public interface GetRecordsRetrievalStrategy { + /** + * Gets a set of records from Kinesis. + * + * @param maxRecords + * passed to Kinesis, and can be used to restrict the number of records returned from Kinesis. + * @return the resulting records. + * @throws IllegalStateException + * if the strategy has been shutdown. + */ + GetRecordsResult getRecords(int maxRecords); + + /** + * Releases any resources used by the strategy. Once the strategy is shutdown it is no longer safe to call + * {@link #getRecords(int)}. + */ + void shutdown(); + + /** + * Returns whether this strategy has been shutdown. + * + * @return true if the strategy has been shutdown, false otherwise. + */ + boolean isShutdown(); +} diff --git a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/KinesisClientLibConfiguration.java b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/KinesisClientLibConfiguration.java index e9673414..62d87f30 100644 --- a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/KinesisClientLibConfiguration.java +++ b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/KinesisClientLibConfiguration.java @@ -217,6 +217,12 @@ public class KinesisClientLibConfiguration { @Getter private Optional timeoutInSeconds = Optional.empty(); + @Getter + private Optional retryGetRecordsInSeconds = Optional.empty(); + + @Getter + private Optional maxGetRecordsThreadPool = Optional.empty(); + @Getter private int maxLeaseRenewalThreads = DEFAULT_MAX_LEASE_RENEWAL_THREADS; @@ -1111,6 +1117,27 @@ public class KinesisClientLibConfiguration { return this; } + + /** + * @param retryGetRecordsInSeconds the time in seconds to wait before the worker retries to get a record. + * @return this configuration object. + */ + public KinesisClientLibConfiguration withRetryGetRecordsInSeconds(final int retryGetRecordsInSeconds) { + checkIsValuePositive("retryGetRecordsInSeconds", retryGetRecordsInSeconds); + this.retryGetRecordsInSeconds = Optional.of(retryGetRecordsInSeconds); + return this; + } + + /** + *@param maxGetRecordsThreadPool the max number of threads in the getRecords thread pool. + *@return this configuration object + */ + public KinesisClientLibConfiguration withMaxGetRecordsThreadPool(final int maxGetRecordsThreadPool) { + checkIsValuePositive("maxGetRecordsThreadPool", maxGetRecordsThreadPool); + this.maxGetRecordsThreadPool = Optional.of(maxGetRecordsThreadPool); + return this; + } + /** * @param timeoutInSeconds The timeout in seconds to wait for the MultiLangProtocol to wait for */ diff --git a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ProcessTask.java b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ProcessTask.java index c419c693..223236f6 100644 --- a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ProcessTask.java +++ b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ProcessTask.java @@ -18,6 +18,7 @@ import java.math.BigInteger; import java.util.Collections; import java.util.List; import java.util.ListIterator; +import java.util.Optional; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -62,6 +63,19 @@ class ProcessTask implements ITask { private final Shard shard; private final ThrottlingReporter throttlingReporter; + private final GetRecordsRetrievalStrategy getRecordsRetrievalStrategy; + + private static final GetRecordsRetrievalStrategy makeStrategy(KinesisDataFetcher dataFetcher, + Optional retryGetRecordsInSeconds, + Optional maxGetRecordsThreadPool, + ShardInfo shardInfo) { + Optional getRecordsRetrievalStrategy = retryGetRecordsInSeconds.flatMap(retry -> + maxGetRecordsThreadPool.map(max -> + new AsynchronousGetRecordsRetrievalStrategy(dataFetcher, retry, max, shardInfo.getShardId()))); + + return getRecordsRetrievalStrategy.orElse(new SynchronousGetRecordsRetrievalStrategy(dataFetcher)); + } + /** * @param shardInfo * contains information about the shard @@ -77,11 +91,38 @@ class ProcessTask implements ITask { * backoff time when catching exceptions */ public ProcessTask(ShardInfo shardInfo, StreamConfig streamConfig, IRecordProcessor recordProcessor, - RecordProcessorCheckpointer recordProcessorCheckpointer, KinesisDataFetcher dataFetcher, - long backoffTimeMillis, boolean skipShardSyncAtWorkerInitializationIfLeasesExist) { + RecordProcessorCheckpointer recordProcessorCheckpointer, KinesisDataFetcher dataFetcher, + long backoffTimeMillis, boolean skipShardSyncAtWorkerInitializationIfLeasesExist) { + this(shardInfo, streamConfig, recordProcessor, recordProcessorCheckpointer, dataFetcher, backoffTimeMillis, + skipShardSyncAtWorkerInitializationIfLeasesExist, Optional.empty(), Optional.empty()); + } + + /** + * @param shardInfo + * contains information about the shard + * @param streamConfig + * Stream configuration + * @param recordProcessor + * Record processor used to process the data records for the shard + * @param recordProcessorCheckpointer + * Passed to the RecordProcessor so it can checkpoint progress + * @param dataFetcher + * Kinesis data fetcher (used to fetch records from Kinesis) + * @param backoffTimeMillis + * backoff time when catching exceptions + * @param retryGetRecordsInSeconds + * time in seconds to wait before the worker retries to get a record. + * @param maxGetRecordsThreadPool + * max number of threads in the getRecords thread pool. + */ + public ProcessTask(ShardInfo shardInfo, StreamConfig streamConfig, IRecordProcessor recordProcessor, + RecordProcessorCheckpointer recordProcessorCheckpointer, KinesisDataFetcher dataFetcher, + long backoffTimeMillis, boolean skipShardSyncAtWorkerInitializationIfLeasesExist, + Optional retryGetRecordsInSeconds, Optional maxGetRecordsThreadPool) { this(shardInfo, streamConfig, recordProcessor, recordProcessorCheckpointer, dataFetcher, backoffTimeMillis, skipShardSyncAtWorkerInitializationIfLeasesExist, - new ThrottlingReporter(MAX_CONSECUTIVE_THROTTLES, shardInfo.getShardId())); + new ThrottlingReporter(MAX_CONSECUTIVE_THROTTLES, shardInfo.getShardId()), + makeStrategy(dataFetcher, retryGetRecordsInSeconds, maxGetRecordsThreadPool, shardInfo)); } /** @@ -103,7 +144,7 @@ class ProcessTask implements ITask { public ProcessTask(ShardInfo shardInfo, StreamConfig streamConfig, IRecordProcessor recordProcessor, RecordProcessorCheckpointer recordProcessorCheckpointer, KinesisDataFetcher dataFetcher, long backoffTimeMillis, boolean skipShardSyncAtWorkerInitializationIfLeasesExist, - ThrottlingReporter throttlingReporter) { + ThrottlingReporter throttlingReporter, GetRecordsRetrievalStrategy getRecordsRetrievalStrategy) { super(); this.shardInfo = shardInfo; this.recordProcessor = recordProcessor; @@ -113,6 +154,7 @@ class ProcessTask implements ITask { this.backoffTimeMillis = backoffTimeMillis; this.throttlingReporter = throttlingReporter; IKinesisProxy kinesisProxy = this.streamConfig.getStreamProxy(); + this.getRecordsRetrievalStrategy = getRecordsRetrievalStrategy; // If skipShardSyncAtWorkerInitializationIfLeasesExist is set, we will not get the shard for // this ProcessTask. In this case, duplicate KPL user records in the event of resharding will // not be dropped during deaggregation of Amazon Kinesis records. This is only applicable if @@ -368,7 +410,7 @@ class ProcessTask implements ITask { * @return list of data records from Kinesis */ private GetRecordsResult getRecordsResultAndRecordMillisBehindLatest() { - final GetRecordsResult getRecordsResult = dataFetcher.getRecords(streamConfig.getMaxRecords()); + final GetRecordsResult getRecordsResult = getRecordsRetrievalStrategy.getRecords(streamConfig.getMaxRecords()); if (getRecordsResult == null) { // Stream no longer exists diff --git a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShardConsumer.java b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShardConsumer.java index 63cce40d..70a81fbc 100644 --- a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShardConsumer.java +++ b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShardConsumer.java @@ -15,10 +15,12 @@ package com.amazonaws.services.kinesis.clientlibrary.lib.worker; +import java.util.Optional; import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; import java.util.concurrent.RejectedExecutionException; +import lombok.Getter; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -53,6 +55,10 @@ class ShardConsumer { private final boolean cleanupLeasesOfCompletedShards; private final long taskBackoffTimeMillis; private final boolean skipShardSyncAtWorkerInitializationIfLeasesExist; + @Getter + private final Optional retryGetRecordsInSeconds; + @Getter + private final Optional maxGetRecordsThreadPool; private ITask currentTask; private long currentTaskSubmitTime; @@ -93,6 +99,38 @@ class ShardConsumer { IMetricsFactory metricsFactory, long backoffTimeMillis, boolean skipShardSyncAtWorkerInitializationIfLeasesExist) { + this(shardInfo, streamConfig, checkpoint,recordProcessor, leaseManager, parentShardPollIntervalMillis, + cleanupLeasesOfCompletedShards, executorService, metricsFactory, backoffTimeMillis, + skipShardSyncAtWorkerInitializationIfLeasesExist, Optional.empty(), Optional.empty()); + } + + /** + * @param shardInfo Shard information + * @param streamConfig Stream configuration to use + * @param checkpoint Checkpoint tracker + * @param recordProcessor Record processor used to process the data records for the shard + * @param leaseManager Used to create leases for new shards + * @param parentShardPollIntervalMillis Wait for this long if parent shards are not done (or we get an exception) + * @param executorService ExecutorService used to execute process tasks for this shard + * @param metricsFactory IMetricsFactory used to construct IMetricsScopes for this shard + * @param backoffTimeMillis backoff interval when we encounter exceptions + * @param retryGetRecordsInSeconds time in seconds to wait before the worker retries to get a record. + * @param maxGetRecordsThreadPool max number of threads in the getRecords thread pool. + */ + // CHECKSTYLE:IGNORE ParameterNumber FOR NEXT 10 LINES + ShardConsumer(ShardInfo shardInfo, + StreamConfig streamConfig, + ICheckpoint checkpoint, + IRecordProcessor recordProcessor, + ILeaseManager leaseManager, + long parentShardPollIntervalMillis, + boolean cleanupLeasesOfCompletedShards, + ExecutorService executorService, + IMetricsFactory metricsFactory, + long backoffTimeMillis, + boolean skipShardSyncAtWorkerInitializationIfLeasesExist, + Optional retryGetRecordsInSeconds, + Optional maxGetRecordsThreadPool) { this.streamConfig = streamConfig; this.recordProcessor = recordProcessor; this.executorService = executorService; @@ -111,6 +149,8 @@ class ShardConsumer { this.cleanupLeasesOfCompletedShards = cleanupLeasesOfCompletedShards; this.taskBackoffTimeMillis = backoffTimeMillis; this.skipShardSyncAtWorkerInitializationIfLeasesExist = skipShardSyncAtWorkerInitializationIfLeasesExist; + this.retryGetRecordsInSeconds = retryGetRecordsInSeconds; + this.maxGetRecordsThreadPool = maxGetRecordsThreadPool; } /** diff --git a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/SynchronousGetRecordsRetrievalStrategy.java b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/SynchronousGetRecordsRetrievalStrategy.java new file mode 100644 index 00000000..77a60448 --- /dev/null +++ b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/SynchronousGetRecordsRetrievalStrategy.java @@ -0,0 +1,31 @@ +package com.amazonaws.services.kinesis.clientlibrary.lib.worker; + +import com.amazonaws.services.kinesis.model.GetRecordsResult; +import lombok.Data; +import lombok.NonNull; + +/** + * + */ +@Data +public class SynchronousGetRecordsRetrievalStrategy implements GetRecordsRetrievalStrategy { + @NonNull + private final KinesisDataFetcher dataFetcher; + + @Override + public GetRecordsResult getRecords(final int maxRecords) { + return dataFetcher.getRecords(maxRecords); + } + + @Override + public void shutdown() { + // + // Does nothing as this retriever doesn't manage any resources + // + } + + @Override + public boolean isShutdown() { + return false; + } +} diff --git a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/Worker.java b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/Worker.java index fd461e31..3cfb9f2f 100644 --- a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/Worker.java +++ b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/Worker.java @@ -17,6 +17,7 @@ package com.amazonaws.services.kinesis.clientlibrary.lib.worker; import java.util.Collection; import java.util.HashSet; import java.util.List; +import java.util.Optional; import java.util.Set; import java.util.concurrent.Callable; import java.util.concurrent.ConcurrentHashMap; @@ -85,6 +86,9 @@ public class Worker implements Runnable { private final long taskBackoffTimeMillis; private final long failoverTimeMillis; + private final Optional retryGetRecordsInSeconds; + private final Optional maxGetRecordsThreadPool; + // private final KinesisClientLeaseManager leaseManager; private final KinesisClientLibLeaseCoordinator leaseCoordinator; private final ShardSyncTaskManager controlServer; @@ -266,7 +270,9 @@ public class Worker implements Runnable { config.getTaskBackoffTimeMillis(), config.getFailoverTimeMillis(), config.getSkipShardSyncAtWorkerInitializationIfLeasesExist(), - config.getShardPrioritizationStrategy()); + config.getShardPrioritizationStrategy(), + config.getRetryGetRecordsInSeconds(), + config.getMaxGetRecordsThreadPool()); // If a region name was explicitly specified, use it as the region for Amazon Kinesis and Amazon DynamoDB. if (config.getRegionName() != null) { @@ -333,6 +339,56 @@ public class Worker implements Runnable { KinesisClientLibLeaseCoordinator leaseCoordinator, ExecutorService execService, IMetricsFactory metricsFactory, long taskBackoffTimeMillis, long failoverTimeMillis, boolean skipShardSyncAtWorkerInitializationIfLeasesExist, ShardPrioritization shardPrioritization) { + this(applicationName, recordProcessorFactory, streamConfig, initialPositionInStream, parentShardPollIntervalMillis, + shardSyncIdleTimeMillis, cleanupLeasesUponShardCompletion, checkpoint, leaseCoordinator, execService, + metricsFactory, taskBackoffTimeMillis, failoverTimeMillis, skipShardSyncAtWorkerInitializationIfLeasesExist, + shardPrioritization, Optional.empty(), Optional.empty()); + } + + + /** + * @param applicationName + * Name of the Kinesis application + * @param recordProcessorFactory + * Used to get record processor instances for processing data from shards + * @param streamConfig + * Stream configuration + * @param initialPositionInStream + * One of LATEST, TRIM_HORIZON, or AT_TIMESTAMP. The KinesisClientLibrary will start fetching data from + * this location in the stream when an application starts up for the first time and there are no + * checkpoints. If there are checkpoints, we start from the checkpoint position. + * @param parentShardPollIntervalMillis + * Wait for this long between polls to check if parent shards are done + * @param shardSyncIdleTimeMillis + * Time between tasks to sync leases and Kinesis shards + * @param cleanupLeasesUponShardCompletion + * Clean up shards we've finished processing (don't wait till they expire in Kinesis) + * @param checkpoint + * Used to get/set checkpoints + * @param leaseCoordinator + * Lease coordinator (coordinates currently owned leases) + * @param execService + * ExecutorService to use for processing records (support for multi-threaded consumption) + * @param metricsFactory + * Metrics factory used to emit metrics + * @param taskBackoffTimeMillis + * Backoff period when tasks encounter an exception + * @param shardPrioritization + * Provides prioritization logic to decide which available shards process first + * @param retryGetRecordsInSeconds + * Time in seconds to wait before the worker retries to get a record. + * @param maxGetRecordsThreadPool + * Max number of threads in the getRecords thread pool. + */ + // NOTE: This has package level access solely for testing + // CHECKSTYLE:IGNORE ParameterNumber FOR NEXT 10 LINES + Worker(String applicationName, IRecordProcessorFactory recordProcessorFactory, StreamConfig streamConfig, + InitialPositionInStreamExtended initialPositionInStream, long parentShardPollIntervalMillis, + long shardSyncIdleTimeMillis, boolean cleanupLeasesUponShardCompletion, ICheckpoint checkpoint, + KinesisClientLibLeaseCoordinator leaseCoordinator, ExecutorService execService, + IMetricsFactory metricsFactory, long taskBackoffTimeMillis, long failoverTimeMillis, + boolean skipShardSyncAtWorkerInitializationIfLeasesExist, ShardPrioritization shardPrioritization, + Optional retryGetRecordsInSeconds, Optional maxGetRecordsThreadPool) { this.applicationName = applicationName; this.recordProcessorFactory = recordProcessorFactory; this.streamConfig = streamConfig; @@ -351,8 +407,11 @@ public class Worker implements Runnable { this.failoverTimeMillis = failoverTimeMillis; this.skipShardSyncAtWorkerInitializationIfLeasesExist = skipShardSyncAtWorkerInitializationIfLeasesExist; this.shardPrioritization = shardPrioritization; + this.retryGetRecordsInSeconds = retryGetRecordsInSeconds; + this.maxGetRecordsThreadPool = maxGetRecordsThreadPool; } + /** * @return the applicationName */ @@ -786,7 +845,7 @@ public class Worker implements Runnable { return new ShardConsumer(shardInfo, streamConfig, checkpointTracker, recordProcessor, leaseCoordinator.getLeaseManager(), parentShardPollIntervalMillis, cleanupLeasesUponShardCompletion, executorService, metricsFactory, taskBackoffTimeMillis, - skipShardSyncAtWorkerInitializationIfLeasesExist); + skipShardSyncAtWorkerInitializationIfLeasesExist, retryGetRecordsInSeconds, maxGetRecordsThreadPool); } @@ -1213,7 +1272,9 @@ public class Worker implements Runnable { config.getTaskBackoffTimeMillis(), config.getFailoverTimeMillis(), config.getSkipShardSyncAtWorkerInitializationIfLeasesExist(), - shardPrioritization); + shardPrioritization, + config.getRetryGetRecordsInSeconds(), + config.getMaxGetRecordsThreadPool()); } diff --git a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/config/KinesisClientLibConfiguratorTest.java b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/config/KinesisClientLibConfiguratorTest.java index cbdd0a2d..d16be640 100644 --- a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/config/KinesisClientLibConfiguratorTest.java +++ b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/config/KinesisClientLibConfiguratorTest.java @@ -22,6 +22,7 @@ import static org.junit.Assert.fail; import java.io.ByteArrayInputStream; import java.io.InputStream; +import java.util.Optional; import java.util.Set; import org.apache.commons.lang.StringUtils; @@ -60,6 +61,8 @@ public class KinesisClientLibConfiguratorTest { assertEquals(config.getApplicationName(), "b"); assertEquals(config.getStreamName(), "a"); assertEquals(config.getWorkerIdentifier(), "123"); + assertEquals(config.getMaxGetRecordsThreadPool(), Optional.empty()); + assertEquals(config.getRetryGetRecordsInSeconds(), Optional.empty()); } @Test @@ -107,7 +110,9 @@ public class KinesisClientLibConfiguratorTest { "workerId = w123", "maxRecords = 10", "metricsMaxQueueSize = 20", - "applicationName = kinesis" + "applicationName = kinesis", + "retryGetRecordsInSeconds = 2", + "maxGetRecordsThreadPool = 1" }, '\n')); assertEquals(config.getApplicationName(), "kinesis"); @@ -115,6 +120,8 @@ public class KinesisClientLibConfiguratorTest { assertEquals(config.getWorkerIdentifier(), "w123"); assertEquals(config.getMaxRecords(), 10); assertEquals(config.getMetricsMaxQueueSize(), 20); + assertEquals(config.getRetryGetRecordsInSeconds(), Optional.of(2)); + assertEquals(config.getMaxGetRecordsThreadPool(), Optional.of(1)); } @Test @@ -202,6 +209,42 @@ public class KinesisClientLibConfiguratorTest { assertEquals(config.getInitialPositionInStream(), InitialPositionInStream.TRIM_HORIZON); } + @Test + public void testEmptyOptionalVariables() { + KinesisClientLibConfiguration config = + getConfiguration(StringUtils.join(new String[] { + "streamName = a", + "applicationName = b", + "AWSCredentialsProvider = ABCD," + credentialName1, + "workerId = 123", + "initialPositionInStream = TriM_Horizon", + "maxGetRecordsThreadPool = 1" + }, '\n')); + assertEquals(config.getMaxGetRecordsThreadPool(), Optional.of(1)); + assertEquals(config.getRetryGetRecordsInSeconds(), Optional.empty()); + } + + @Test + public void testWithZeroValue() { + String test = StringUtils.join(new String[]{ + "streamName = a", + "applicationName = b", + "AWSCredentialsProvider = ABCD," + credentialName1, + "workerId = 123", + "initialPositionInStream = TriM_Horizon", + "maxGetRecordsThreadPool = 0", + "retryGetRecordsInSeconds = 0" + }, '\n'); + InputStream input = new ByteArrayInputStream(test.getBytes()); + + try { + configurator.getConfiguration(input); + } catch (Exception e) { + fail("Don't expect to fail on invalid variable value"); + + } + } + @Test public void testWithInvalidIntValue() { String test = StringUtils.join(new String[] { diff --git a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/AsynchronousGetRecordsRetrievalStrategyIntegrationTest.java b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/AsynchronousGetRecordsRetrievalStrategyIntegrationTest.java new file mode 100644 index 00000000..8518c992 --- /dev/null +++ b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/AsynchronousGetRecordsRetrievalStrategyIntegrationTest.java @@ -0,0 +1,156 @@ +/* + * Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Amazon Software License (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/asl/ + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ +package com.amazonaws.services.kinesis.clientlibrary.lib.worker; + +import com.amazonaws.services.kinesis.clientlibrary.proxies.IKinesisProxy; +import com.amazonaws.services.kinesis.model.GetRecordsResult; +import com.google.common.util.concurrent.ThreadFactoryBuilder; +import org.junit.After; +import org.junit.Before; +import org.junit.Ignore; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.runners.MockitoJUnitRunner; + +import java.util.concurrent.CompletionService; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.RejectedExecutionHandler; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.atLeast; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +@RunWith(MockitoJUnitRunner.class) +public class AsynchronousGetRecordsRetrievalStrategyIntegrationTest { + + private static final int CORE_POOL_SIZE = 1; + private static final int MAX_POOL_SIZE = 2; + private static final int TIME_TO_LIVE = 5; + private static final int RETRY_GET_RECORDS_IN_SECONDS = 2; + private static final int SLEEP_GET_RECORDS_IN_SECONDS = 10; + + @Mock + private IKinesisProxy mockKinesisProxy; + + @Mock + private ShardInfo mockShardInfo; + + private AsynchronousGetRecordsRetrievalStrategy getRecordsRetrivalStrategy; + private KinesisDataFetcher dataFetcher; + private GetRecordsResult result; + private ExecutorService executorService; + private RejectedExecutionHandler rejectedExecutionHandler; + private int numberOfRecords = 10; + private CompletionService completionService; + + @Before + public void setup() { + dataFetcher = spy(new KinesisDataFetcherForTests(mockKinesisProxy, mockShardInfo)); + rejectedExecutionHandler = spy(new ThreadPoolExecutor.AbortPolicy()); + executorService = spy(new ThreadPoolExecutor( + CORE_POOL_SIZE, + MAX_POOL_SIZE, + TIME_TO_LIVE, + TimeUnit.SECONDS, + new LinkedBlockingQueue<>(1), + new ThreadFactoryBuilder().setDaemon(true).setNameFormat("getrecords-worker-%d").build(), + rejectedExecutionHandler)); + getRecordsRetrivalStrategy = new AsynchronousGetRecordsRetrievalStrategy(dataFetcher, executorService, RETRY_GET_RECORDS_IN_SECONDS, "shardId-0001"); + completionService = spy(getRecordsRetrivalStrategy.completionService); + result = null; + } + + @Test + public void oneRequestMultithreadTest() { + GetRecordsResult getRecordsResult = getRecordsRetrivalStrategy.getRecords(numberOfRecords); + verify(dataFetcher, atLeast(getLeastNumberOfCalls())).getRecords(eq(numberOfRecords)); + verify(executorService, atLeast(getLeastNumberOfCalls())).execute(any()); + assertNull(getRecordsResult); + } + + @Test + public void multiRequestTest() { + result = mock(GetRecordsResult.class); + + GetRecordsResult getRecordsResult = getRecordsRetrivalStrategy.getRecords(numberOfRecords); + verify(dataFetcher, atLeast(getLeastNumberOfCalls())).getRecords(numberOfRecords); + verify(executorService, atLeast(getLeastNumberOfCalls())).execute(any()); + assertEquals(result, getRecordsResult); + + result = null; + getRecordsResult = getRecordsRetrivalStrategy.getRecords(numberOfRecords); + assertNull(getRecordsResult); + } + + @Test + @Ignore + public void testInterrupted() throws InterruptedException, ExecutionException { + + Future mockFuture = mock(Future.class); + when(completionService.submit(any())).thenReturn(mockFuture); + when(completionService.poll()).thenReturn(mockFuture); + doThrow(InterruptedException.class).when(mockFuture).get(); + GetRecordsResult getRecordsResult = getRecordsRetrivalStrategy.getRecords(numberOfRecords); + verify(mockFuture).get(); + assertNull(getRecordsResult); + } + + private int getLeastNumberOfCalls() { + int leastNumberOfCalls = 0; + for (int i = MAX_POOL_SIZE; i > 0; i--) { + if (i * RETRY_GET_RECORDS_IN_SECONDS <= SLEEP_GET_RECORDS_IN_SECONDS) { + leastNumberOfCalls = i; + break; + } + } + return leastNumberOfCalls; + } + + @After + public void shutdown() { + getRecordsRetrivalStrategy.shutdown(); + verify(executorService).shutdownNow(); + } + + private class KinesisDataFetcherForTests extends KinesisDataFetcher { + public KinesisDataFetcherForTests(final IKinesisProxy kinesisProxy, final ShardInfo shardInfo) { + super(kinesisProxy, shardInfo); + } + + @Override + public GetRecordsResult getRecords(final int maxRecords) { + try { + Thread.sleep(SLEEP_GET_RECORDS_IN_SECONDS * 1000); + } catch (InterruptedException e) { + // Do nothing + } + return result; + } + } + +} diff --git a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/AsynchronousGetRecordsRetrievalStrategyTest.java b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/AsynchronousGetRecordsRetrievalStrategyTest.java new file mode 100644 index 00000000..dfba0351 --- /dev/null +++ b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/AsynchronousGetRecordsRetrievalStrategyTest.java @@ -0,0 +1,137 @@ +package com.amazonaws.services.kinesis.clientlibrary.lib.worker; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyBoolean; +import static org.mockito.Matchers.anyLong; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.util.concurrent.CompletionService; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; +import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.TimeUnit; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.runners.MockitoJUnitRunner; + +import com.amazonaws.services.kinesis.model.GetRecordsResult; + +/** + * + */ +@RunWith(MockitoJUnitRunner.class) +public class AsynchronousGetRecordsRetrievalStrategyTest { + + private static final long RETRY_GET_RECORDS_IN_SECONDS = 5; + private static final String SHARD_ID = "ShardId-0001"; + @Mock + private KinesisDataFetcher dataFetcher; + @Mock + private ExecutorService executorService; + @Mock + private CompletionService completionService; + @Mock + private Future successfulFuture; + @Mock + private Future blockedFuture; + @Mock + private GetRecordsResult expectedResults; + + @Test + public void testSingleSuccessfulRequestFuture() throws Exception { + AsynchronousGetRecordsRetrievalStrategy strategy = new AsynchronousGetRecordsRetrievalStrategy(dataFetcher, + executorService, (int) RETRY_GET_RECORDS_IN_SECONDS, completionService, SHARD_ID); + + when(executorService.isShutdown()).thenReturn(false); + when(completionService.submit(any())).thenReturn(successfulFuture); + when(completionService.poll(anyLong(), any())).thenReturn(successfulFuture); + when(successfulFuture.get()).thenReturn(expectedResults); + + GetRecordsResult result = strategy.getRecords(10); + + verify(executorService).isShutdown(); + verify(completionService).submit(any()); + verify(completionService).poll(eq(RETRY_GET_RECORDS_IN_SECONDS), eq(TimeUnit.SECONDS)); + verify(successfulFuture).get(); + verify(successfulFuture).cancel(eq(true)); + verify(successfulFuture).isCancelled(); + verify(completionService, never()).take(); + + assertThat(result, equalTo(expectedResults)); + } + + @Test + public void testBlockedAndSuccessfulFuture() throws Exception { + AsynchronousGetRecordsRetrievalStrategy strategy = new AsynchronousGetRecordsRetrievalStrategy(dataFetcher, + executorService, (int) RETRY_GET_RECORDS_IN_SECONDS, completionService, SHARD_ID); + + when(executorService.isShutdown()).thenReturn(false); + when(completionService.submit(any())).thenReturn(blockedFuture).thenReturn(successfulFuture); + when(completionService.poll(anyLong(), any())).thenReturn(null).thenReturn(successfulFuture); + when(successfulFuture.get()).thenReturn(expectedResults); + when(successfulFuture.cancel(anyBoolean())).thenReturn(false); + when(blockedFuture.cancel(anyBoolean())).thenReturn(true); + when(successfulFuture.isCancelled()).thenReturn(false); + when(blockedFuture.isCancelled()).thenReturn(true); + + GetRecordsResult actualResults = strategy.getRecords(10); + + verify(completionService, times(2)).submit(any()); + verify(completionService, times(2)).poll(eq(RETRY_GET_RECORDS_IN_SECONDS), eq(TimeUnit.SECONDS)); + verify(successfulFuture).get(); + verify(blockedFuture, never()).get(); + verify(successfulFuture).cancel(eq(true)); + verify(blockedFuture).cancel(eq(true)); + verify(successfulFuture).isCancelled(); + verify(blockedFuture).isCancelled(); + verify(completionService).take(); + + assertThat(actualResults, equalTo(expectedResults)); + } + + @Test(expected = IllegalStateException.class) + public void testStrategyIsShutdown() throws Exception { + AsynchronousGetRecordsRetrievalStrategy strategy = new AsynchronousGetRecordsRetrievalStrategy(dataFetcher, + executorService, (int) RETRY_GET_RECORDS_IN_SECONDS, completionService, SHARD_ID); + + when(executorService.isShutdown()).thenReturn(true); + + strategy.getRecords(10); + } + + @Test + public void testPoolOutOfResources() throws Exception { + AsynchronousGetRecordsRetrievalStrategy strategy = new AsynchronousGetRecordsRetrievalStrategy(dataFetcher, + executorService, (int) RETRY_GET_RECORDS_IN_SECONDS, completionService, SHARD_ID); + + when(executorService.isShutdown()).thenReturn(false); + when(completionService.submit(any())).thenReturn(blockedFuture).thenThrow(new RejectedExecutionException("Rejected!")).thenReturn(successfulFuture); + when(completionService.poll(anyLong(), any())).thenReturn(null).thenReturn(null).thenReturn(successfulFuture); + when(successfulFuture.get()).thenReturn(expectedResults); + when(successfulFuture.cancel(anyBoolean())).thenReturn(false); + when(blockedFuture.cancel(anyBoolean())).thenReturn(true); + when(successfulFuture.isCancelled()).thenReturn(false); + when(blockedFuture.isCancelled()).thenReturn(true); + + GetRecordsResult actualResult = strategy.getRecords(10); + + verify(completionService, times(3)).submit(any()); + verify(completionService, times(3)).poll(eq(RETRY_GET_RECORDS_IN_SECONDS), eq(TimeUnit.SECONDS)); + verify(successfulFuture).cancel(eq(true)); + verify(blockedFuture).cancel(eq(true)); + verify(successfulFuture).isCancelled(); + verify(blockedFuture).isCancelled(); + verify(completionService).take(); + + assertThat(actualResult, equalTo(expectedResults)); + } + +} diff --git a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ConsumerStatesTest.java b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ConsumerStatesTest.java index 31272379..307aa6b8 100644 --- a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ConsumerStatesTest.java +++ b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ConsumerStatesTest.java @@ -17,6 +17,7 @@ package com.amazonaws.services.kinesis.clientlibrary.lib.worker; import static com.amazonaws.services.kinesis.clientlibrary.lib.worker.ConsumerStates.ConsumerState; import static com.amazonaws.services.kinesis.clientlibrary.lib.worker.ConsumerStates.ShardConsumerState; import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.CoreMatchers.nullValue; import static org.hamcrest.MatcherAssert.assertThat; import static org.mockito.Mockito.never; @@ -25,6 +26,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import java.lang.reflect.Field; +import java.util.Optional; import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; @@ -152,7 +154,10 @@ public class ConsumerStatesTest { } @Test - public void processingStateTest() { + public void processingStateTestSynchronous() { + when(consumer.getMaxGetRecordsThreadPool()).thenReturn(Optional.empty()); + when(consumer.getRetryGetRecordsInSeconds()).thenReturn(Optional.empty()); + ConsumerState state = ShardConsumerState.PROCESSING.getConsumerState(); ITask task = state.createTask(consumer); @@ -163,6 +168,38 @@ public class ConsumerStatesTest { assertThat(task, procTask(KinesisDataFetcher.class, "dataFetcher", equalTo(dataFetcher))); assertThat(task, procTask(StreamConfig.class, "streamConfig", equalTo(streamConfig))); assertThat(task, procTask(Long.class, "backoffTimeMillis", equalTo(taskBackoffTimeMillis))); + assertThat(task, procTask(GetRecordsRetrievalStrategy.class, "getRecordsRetrievalStrategy", instanceOf(SynchronousGetRecordsRetrievalStrategy.class) )); + + assertThat(state.successTransition(), equalTo(ShardConsumerState.PROCESSING.getConsumerState())); + + assertThat(state.shutdownTransition(ShutdownReason.ZOMBIE), + equalTo(ShardConsumerState.SHUTTING_DOWN.getConsumerState())); + assertThat(state.shutdownTransition(ShutdownReason.TERMINATE), + equalTo(ShardConsumerState.SHUTTING_DOWN.getConsumerState())); + assertThat(state.shutdownTransition(ShutdownReason.REQUESTED), + equalTo(ShardConsumerState.SHUTDOWN_REQUESTED.getConsumerState())); + + assertThat(state.getState(), equalTo(ShardConsumerState.PROCESSING)); + assertThat(state.getTaskType(), equalTo(TaskType.PROCESS)); + + } + + @Test + public void processingStateTestAsynchronous() { + when(consumer.getMaxGetRecordsThreadPool()).thenReturn(Optional.of(1)); + when(consumer.getRetryGetRecordsInSeconds()).thenReturn(Optional.of(2)); + + ConsumerState state = ShardConsumerState.PROCESSING.getConsumerState(); + ITask task = state.createTask(consumer); + + assertThat(task, procTask(ShardInfo.class, "shardInfo", equalTo(shardInfo))); + assertThat(task, procTask(IRecordProcessor.class, "recordProcessor", equalTo(recordProcessor))); + assertThat(task, procTask(RecordProcessorCheckpointer.class, "recordProcessorCheckpointer", + equalTo(recordProcessorCheckpointer))); + assertThat(task, procTask(KinesisDataFetcher.class, "dataFetcher", equalTo(dataFetcher))); + assertThat(task, procTask(StreamConfig.class, "streamConfig", equalTo(streamConfig))); + assertThat(task, procTask(Long.class, "backoffTimeMillis", equalTo(taskBackoffTimeMillis))); + assertThat(task, procTask(GetRecordsRetrievalStrategy.class, "getRecordsRetrievalStrategy", instanceOf(AsynchronousGetRecordsRetrievalStrategy.class) )); assertThat(state.successTransition(), equalTo(ShardConsumerState.PROCESSING.getConsumerState())); diff --git a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/KinesisDataFetcherTest.java b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/KinesisDataFetcherTest.java index dd56a256..2597d76b 100644 --- a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/KinesisDataFetcherTest.java +++ b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/KinesisDataFetcherTest.java @@ -117,6 +117,7 @@ public class KinesisDataFetcherTest { ICheckpoint checkpoint = mock(ICheckpoint.class); KinesisDataFetcher fetcher = new KinesisDataFetcher(kinesis, SHARD_INFO); + GetRecordsRetrievalStrategy getRecordsRetrievalStrategy = new SynchronousGetRecordsRetrievalStrategy(fetcher); String iteratorA = "foo"; String iteratorB = "bar"; @@ -138,10 +139,10 @@ public class KinesisDataFetcherTest { fetcher.initialize(seqA, null); fetcher.advanceIteratorTo(seqA, null); - Assert.assertEquals(recordsA, fetcher.getRecords(MAX_RECORDS).getRecords()); + Assert.assertEquals(recordsA, getRecordsRetrievalStrategy.getRecords(MAX_RECORDS).getRecords()); fetcher.advanceIteratorTo(seqB, null); - Assert.assertEquals(recordsB, fetcher.getRecords(MAX_RECORDS).getRecords()); + Assert.assertEquals(recordsB, getRecordsRetrievalStrategy.getRecords(MAX_RECORDS).getRecords()); } @Test @@ -181,8 +182,9 @@ public class KinesisDataFetcherTest { // Create data fectcher and initialize it with latest type checkpoint KinesisDataFetcher dataFetcher = new KinesisDataFetcher(mockProxy, SHARD_INFO); dataFetcher.initialize(SentinelCheckpoint.LATEST.toString(), INITIAL_POSITION_LATEST); + GetRecordsRetrievalStrategy getRecordsRetrievalStrategy = new SynchronousGetRecordsRetrievalStrategy(dataFetcher); // Call getRecords of dataFetcher which will throw an exception - dataFetcher.getRecords(maxRecords); + getRecordsRetrievalStrategy.getRecords(maxRecords); // Test shard has reached the end Assert.assertTrue("Shard should reach the end", dataFetcher.isShardEndReached()); @@ -206,8 +208,9 @@ public class KinesisDataFetcherTest { when(checkpoint.getCheckpoint(SHARD_ID)).thenReturn(new ExtendedSequenceNumber(seqNo)); KinesisDataFetcher fetcher = new KinesisDataFetcher(kinesis, SHARD_INFO); + GetRecordsRetrievalStrategy getRecordsRetrievalStrategy = new SynchronousGetRecordsRetrievalStrategy(fetcher); fetcher.initialize(seqNo, initialPositionInStream); - List actualRecords = fetcher.getRecords(MAX_RECORDS).getRecords(); + List actualRecords = getRecordsRetrievalStrategy.getRecords(MAX_RECORDS).getRecords(); Assert.assertEquals(expectedRecords, actualRecords); } diff --git a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ProcessTaskTest.java b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ProcessTaskTest.java index e95aef50..0c47e9b9 100644 --- a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ProcessTaskTest.java +++ b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ProcessTaskTest.java @@ -19,7 +19,7 @@ import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; import static org.mockito.Matchers.anyInt; -import static org.mockito.Matchers.anyString; +import static org.mockito.Matchers.eq; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.never; @@ -76,6 +76,8 @@ public class ProcessTaskTest { private @Mock RecordProcessorCheckpointer mockCheckpointer; @Mock private ThrottlingReporter throttlingReporter; + @Mock + private GetRecordsRetrievalStrategy mockGetRecordsRetrievalStrategy; private List processedRecords; private ExtendedSequenceNumber newLargestPermittedCheckpointValue; @@ -94,19 +96,20 @@ public class ProcessTaskTest { final ShardInfo shardInfo = new ShardInfo(shardId, null, null, null); processTask = new ProcessTask( shardInfo, config, mockRecordProcessor, mockCheckpointer, mockDataFetcher, taskBackoffTimeMillis, - KinesisClientLibConfiguration.DEFAULT_SKIP_SHARD_SYNC_AT_STARTUP_IF_LEASES_EXIST, throttlingReporter); + KinesisClientLibConfiguration.DEFAULT_SKIP_SHARD_SYNC_AT_STARTUP_IF_LEASES_EXIST, throttlingReporter, mockGetRecordsRetrievalStrategy); } @Test public void testProcessTaskWithProvisionedThroughputExceededException() { // Set data fetcher to throw exception doReturn(false).when(mockDataFetcher).isShardEndReached(); - doThrow(new ProvisionedThroughputExceededException("Test Exception")).when(mockDataFetcher) + doThrow(new ProvisionedThroughputExceededException("Test Exception")).when(mockGetRecordsRetrievalStrategy) .getRecords(maxRecords); TaskResult result = processTask.call(); verify(throttlingReporter).throttled(); verify(throttlingReporter, never()).success(); + verify(mockGetRecordsRetrievalStrategy).getRecords(eq(maxRecords)); assertTrue("Result should contain ProvisionedThroughputExceededException", result.getException() instanceof ProvisionedThroughputExceededException); } @@ -114,9 +117,10 @@ public class ProcessTaskTest { @Test public void testProcessTaskWithNonExistentStream() { // Data fetcher returns a null Result when the stream does not exist - doReturn(null).when(mockDataFetcher).getRecords(maxRecords); + doReturn(null).when(mockGetRecordsRetrievalStrategy).getRecords(maxRecords); TaskResult result = processTask.call(); + verify(mockGetRecordsRetrievalStrategy).getRecords(eq(maxRecords)); assertNull("Task should not throw an exception", result.getException()); } @@ -300,14 +304,14 @@ public class ProcessTaskTest { private void testWithRecords(List records, ExtendedSequenceNumber lastCheckpointValue, ExtendedSequenceNumber largestPermittedCheckpointValue) { - when(mockDataFetcher.getRecords(anyInt())).thenReturn( + when(mockGetRecordsRetrievalStrategy.getRecords(anyInt())).thenReturn( new GetRecordsResult().withRecords(records)); when(mockCheckpointer.getLastCheckpointValue()).thenReturn(lastCheckpointValue); when(mockCheckpointer.getLargestPermittedCheckpointValue()).thenReturn(largestPermittedCheckpointValue); processTask.call(); verify(throttlingReporter).success(); verify(throttlingReporter, never()).throttled(); - + verify(mockGetRecordsRetrievalStrategy).getRecords(anyInt()); ArgumentCaptor priCaptor = ArgumentCaptor.forClass(ProcessRecordsInput.class); verify(mockRecordProcessor).processRecords(priCaptor.capture()); processedRecords = priCaptor.getValue().getRecords();