From e0ae95dd095523ac3f810757f2e4edebf0cccf8e Mon Sep 17 00:00:00 2001 From: Sahil Palvia Date: Thu, 31 Aug 2017 17:01:48 -0700 Subject: [PATCH 1/2] Adding unit tests, using take to clear the cancelled calls. --- ...synchronousGetRecordsRetrivalStrategy.java | 8 +- ...hronousGetRecordsRetrivalStrategyTest.java | 145 ++++++++++++++++++ 2 files changed, 152 insertions(+), 1 deletion(-) create mode 100644 src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/AsynchronousGetRecordsRetrivalStrategyTest.java diff --git a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/AsynchronousGetRecordsRetrivalStrategy.java b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/AsynchronousGetRecordsRetrivalStrategy.java index 415ec6d1..543bec8d 100644 --- a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/AsynchronousGetRecordsRetrivalStrategy.java +++ b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/AsynchronousGetRecordsRetrivalStrategy.java @@ -87,7 +87,13 @@ public class AsynchronousGetRecordsRetrivalStrategy implements GetRecordsRetriva break; } } - futures.stream().peek(f -> f.cancel(true)).filter(Future::isCancelled).forEach(f -> completionService.poll()); + futures.stream().peek(f -> f.cancel(true)).filter(Future::isCancelled).forEach(f -> { + try { + completionService.take(); + } catch (InterruptedException e) { + log.error("Exception thrown while trying to empty the threadpool."); + } + }); return result; } diff --git a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/AsynchronousGetRecordsRetrivalStrategyTest.java b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/AsynchronousGetRecordsRetrivalStrategyTest.java new file mode 100644 index 00000000..f3399121 --- /dev/null +++ b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/AsynchronousGetRecordsRetrivalStrategyTest.java @@ -0,0 +1,145 @@ +package com.amazonaws.services.kinesis.clientlibrary.lib.worker; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyObject; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.atLeast; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.util.concurrent.CompletionService; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.RejectedExecutionHandler; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.Spy; +import org.mockito.runners.MockitoJUnitRunner; + +import com.amazonaws.services.kinesis.clientlibrary.proxies.IKinesisProxy; +import com.amazonaws.services.kinesis.model.GetRecordsResult; +import com.google.common.util.concurrent.ThreadFactoryBuilder; + +/** + * + */ +@RunWith(MockitoJUnitRunner.class) +public class AsynchronousGetRecordsRetrivalStrategyTest { + private static final int CORE_POOL_SIZE = 1; + private static final int MAX_POOL_SIZE = 2; + private static final int TIME_TO_LIVE = 5; + private static final int RETRY_GET_RECORDS_IN_SECONDS = 2; + private static final int SLEEP_GET_RECORDS_IN_SECONDS = 10; + + @Mock + private IKinesisProxy mockKinesisProxy; + + @Mock + private ShardInfo mockShardInfo; + + private AsynchronousGetRecordsRetrivalStrategy getRecordsRetrivalStrategy; + private KinesisDataFetcher dataFetcher; + private GetRecordsResult result; + private ExecutorService executorService; + private RejectedExecutionHandler rejectedExecutionHandler; + private int numberOfRecords = 10; + private CompletionService completionService; + + @Before + public void setup() { + dataFetcher = spy(new KinesisDataFetcherForTests(mockKinesisProxy, mockShardInfo)); + rejectedExecutionHandler = spy(new ThreadPoolExecutor.AbortPolicy()); + executorService = spy(new ThreadPoolExecutor( + CORE_POOL_SIZE, + MAX_POOL_SIZE, + TIME_TO_LIVE, + TimeUnit.SECONDS, + new LinkedBlockingQueue<>(1), + new ThreadFactoryBuilder().setDaemon(true).setNameFormat("getrecords-worker-%d").build(), + rejectedExecutionHandler)); + getRecordsRetrivalStrategy = new AsynchronousGetRecordsRetrivalStrategy(dataFetcher, executorService, RETRY_GET_RECORDS_IN_SECONDS); + completionService = spy(getRecordsRetrivalStrategy.getCompletionService()); + result = null; + } + + @Test + public void oneRequestMultithreadTest() { + GetRecordsResult getRecordsResult = getRecordsRetrivalStrategy.getRecords(numberOfRecords); + verify(dataFetcher, atLeast(getLeastNumberOfCalls())).getRecords(eq(numberOfRecords)); + verify(executorService, atLeast(getLeastNumberOfCalls())).execute(any()); + assertNull(getRecordsResult); + } + + @Test + public void multiRequestTest() { + result = mock(GetRecordsResult.class); + + GetRecordsResult getRecordsResult = getRecordsRetrivalStrategy.getRecords(numberOfRecords); + verify(dataFetcher, atLeast(getLeastNumberOfCalls())).getRecords(numberOfRecords); + verify(executorService, atLeast(getLeastNumberOfCalls())).execute(any()); + assertEquals(result, getRecordsResult); + + result = null; + getRecordsResult = getRecordsRetrivalStrategy.getRecords(numberOfRecords); + assertNull(getRecordsResult); + } + + /*@Test + public void testInterrupted() throws InterruptedException, ExecutionException { + Future mockFuture = mock(Future.class); + System.out.println(completionService); + when(completionService.submit(any())).thenReturn(mockFuture); + when(completionService.poll()).thenReturn(mockFuture); + doThrow(InterruptedException.class).when(mockFuture).get(); + GetRecordsResult getRecordsResult = getRecordsRetrivalStrategy.getRecords(numberOfRecords); + verify(mockFuture).get(); + assertNull(getRecordsResult); + }*/ + + private int getLeastNumberOfCalls() { + int leastNumberOfCalls = 0; + for (int i = MAX_POOL_SIZE; i > 0; i--) { + if (i * RETRY_GET_RECORDS_IN_SECONDS <= SLEEP_GET_RECORDS_IN_SECONDS) { + leastNumberOfCalls = i; + break; + } + } + return leastNumberOfCalls; + } + + @After + public void shutdown() { + getRecordsRetrivalStrategy.shutdown(); + verify(executorService).shutdownNow(); + } + + private class KinesisDataFetcherForTests extends KinesisDataFetcher { + public KinesisDataFetcherForTests(final IKinesisProxy kinesisProxy, final ShardInfo shardInfo) { + super(kinesisProxy, shardInfo); + } + + @Override + public GetRecordsResult getRecords(final int maxRecords) { + try { + Thread.sleep(SLEEP_GET_RECORDS_IN_SECONDS * 1000); + } catch (InterruptedException e) { + // Do nothing + } + return result; + } + } +} From 7472cec60cabf0539eee90728ace71c0df474c4e Mon Sep 17 00:00:00 2001 From: "Pfifer, Justin" Date: Fri, 1 Sep 2017 09:17:59 -0700 Subject: [PATCH 2/2] Fixed a Spelling Error, and Slight Refactor for Tests * Renamed the retrieval strategy classes to fix a spelling error. * Modified the strategy interface to support shutdown, and determination of whether a strategy has been shutdown. * Moved the existing tests for the async strategy to an integration test. * Modified the async strategy to allow injection of more state components * Modified the async strategy to throw an exception if an attempt is made to use it after shutdown. cr https://code.amazon.com/reviews/CR-590341 --- ...nchronousGetRecordsRetrievalStrategy.java} | 65 +++++---- .../worker/GetRecordsRetrievalStrategy.java | 33 +++++ .../worker/GetRecordsRetrivalStrategy.java | 10 -- .../clientlibrary/lib/worker/ProcessTask.java | 10 +- ...nchronousGetRecordsRetrievalStrategy.java} | 14 +- ...ordsRetrievalStrategyIntegrationTest.java} | 76 +++++----- ...ronousGetRecordsRetrievalStrategyTest.java | 137 ++++++++++++++++++ .../lib/worker/KinesisDataFetcherTest.java | 14 +- .../lib/worker/ProcessTaskTest.java | 16 +- 9 files changed, 281 insertions(+), 94 deletions(-) rename src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/{AsynchronousGetRecordsRetrivalStrategy.java => AsynchronousGetRecordsRetrievalStrategy.java} (54%) create mode 100644 src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/GetRecordsRetrievalStrategy.java delete mode 100644 src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/GetRecordsRetrivalStrategy.java rename src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/{SynchronousGetRecordsRetrivalStrategy.java => SynchronousGetRecordsRetrievalStrategy.java} (55%) rename src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/{AsynchronousGetRecordsRetrivalStrategyTest.java => AsynchronousGetRecordsRetrievalStrategyIntegrationTest.java} (84%) create mode 100644 src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/AsynchronousGetRecordsRetrievalStrategyTest.java diff --git a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/AsynchronousGetRecordsRetrivalStrategy.java b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/AsynchronousGetRecordsRetrievalStrategy.java similarity index 54% rename from src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/AsynchronousGetRecordsRetrivalStrategy.java rename to src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/AsynchronousGetRecordsRetrievalStrategy.java index 543bec8d..db5ea042 100644 --- a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/AsynchronousGetRecordsRetrivalStrategy.java +++ b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/AsynchronousGetRecordsRetrievalStrategy.java @@ -2,7 +2,6 @@ package com.amazonaws.services.kinesis.clientlibrary.lib.worker; import java.util.HashSet; import java.util.Set; -import java.util.concurrent.Callable; import java.util.concurrent.CompletionService; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorCompletionService; @@ -16,7 +15,6 @@ import java.util.concurrent.TimeUnit; import com.amazonaws.services.kinesis.model.GetRecordsResult; import com.google.common.util.concurrent.ThreadFactoryBuilder; -import lombok.Getter; import lombok.NonNull; import lombok.extern.apachecommons.CommonsLog; @@ -24,58 +22,53 @@ import lombok.extern.apachecommons.CommonsLog; * */ @CommonsLog -public class AsynchronousGetRecordsRetrivalStrategy implements GetRecordsRetrivalStrategy { +public class AsynchronousGetRecordsRetrievalStrategy implements GetRecordsRetrievalStrategy { private static final int TIME_TO_KEEP_ALIVE = 5; private static final int CORE_THREAD_POOL_COUNT = 1; private final KinesisDataFetcher dataFetcher; private final ExecutorService executorService; private final int retryGetRecordsInSeconds; - @Getter - private final CompletionService completionService; + private final String shardId; + final CompletionService completionService; - public AsynchronousGetRecordsRetrivalStrategy(@NonNull final KinesisDataFetcher dataFetcher, - final int retryGetRecordsInSeconds, - final int maxGetRecordsThreadPool) { - this (dataFetcher, - new ThreadPoolExecutor( - CORE_THREAD_POOL_COUNT, - maxGetRecordsThreadPool, - TIME_TO_KEEP_ALIVE, - TimeUnit.SECONDS, - new LinkedBlockingQueue<>(1), - new ThreadFactoryBuilder().setDaemon(true).setNameFormat("getrecords-worker-%d").build(), - new ThreadPoolExecutor.AbortPolicy()), - retryGetRecordsInSeconds); + public AsynchronousGetRecordsRetrievalStrategy(@NonNull final KinesisDataFetcher dataFetcher, + final int retryGetRecordsInSeconds, final int maxGetRecordsThreadPool, String shardId) { + this(dataFetcher, buildExector(maxGetRecordsThreadPool, shardId), retryGetRecordsInSeconds, shardId); } - public AsynchronousGetRecordsRetrivalStrategy(final KinesisDataFetcher dataFetcher, - final ExecutorService executorService, - final int retryGetRecordsInSeconds) { + public AsynchronousGetRecordsRetrievalStrategy(final KinesisDataFetcher dataFetcher, + final ExecutorService executorService, final int retryGetRecordsInSeconds, String shardId) { + this(dataFetcher, executorService, retryGetRecordsInSeconds, new ExecutorCompletionService<>(executorService), + shardId); + } + + AsynchronousGetRecordsRetrievalStrategy(KinesisDataFetcher dataFetcher, ExecutorService executorService, + int retryGetRecordsInSeconds, CompletionService completionService, String shardId) { this.dataFetcher = dataFetcher; this.executorService = executorService; this.retryGetRecordsInSeconds = retryGetRecordsInSeconds; - this.completionService = new ExecutorCompletionService<>(executorService); + this.completionService = completionService; + this.shardId = shardId; } @Override public GetRecordsResult getRecords(final int maxRecords) { + if (executorService.isShutdown()) { + throw new IllegalStateException("Strategy has been shutdown"); + } GetRecordsResult result = null; Set> futures = new HashSet<>(); while (true) { try { - futures.add(completionService.submit(new Callable() { - @Override - public GetRecordsResult call() throws Exception { - return dataFetcher.getRecords(maxRecords); - } - })); + futures.add(completionService.submit(() -> dataFetcher.getRecords(maxRecords))); } catch (RejectedExecutionException e) { log.warn("Out of resources, unable to start additional requests."); } try { - Future resultFuture = completionService.poll(retryGetRecordsInSeconds, TimeUnit.SECONDS); + Future resultFuture = completionService.poll(retryGetRecordsInSeconds, + TimeUnit.SECONDS); if (resultFuture != null) { result = resultFuture.get(); break; @@ -97,7 +90,21 @@ public class AsynchronousGetRecordsRetrivalStrategy implements GetRecordsRetriva return result; } + @Override public void shutdown() { executorService.shutdownNow(); } + + @Override + public boolean isShutdown() { + return executorService.isShutdown(); + } + + private static ExecutorService buildExector(int maxGetRecordsThreadPool, String shardId) { + String threadNameFormat = "get-records-worker-" + shardId + "-%d"; + return new ThreadPoolExecutor(CORE_THREAD_POOL_COUNT, maxGetRecordsThreadPool, TIME_TO_KEEP_ALIVE, + TimeUnit.SECONDS, new LinkedBlockingQueue<>(1), + new ThreadFactoryBuilder().setDaemon(true).setNameFormat(threadNameFormat).build(), + new ThreadPoolExecutor.AbortPolicy()); + } } diff --git a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/GetRecordsRetrievalStrategy.java b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/GetRecordsRetrievalStrategy.java new file mode 100644 index 00000000..a391ac59 --- /dev/null +++ b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/GetRecordsRetrievalStrategy.java @@ -0,0 +1,33 @@ +package com.amazonaws.services.kinesis.clientlibrary.lib.worker; + +import com.amazonaws.services.kinesis.model.GetRecordsResult; + +/** + * Represents a strategy to retrieve records from Kinesis. Allows for variations on how records are retrieved from + * Kinesis. + */ +public interface GetRecordsRetrievalStrategy { + /** + * Gets a set of records from Kinesis. + * + * @param maxRecords + * passed to Kinesis, and can be used to restrict the number of records returned from Kinesis. + * @return the resulting records. + * @throws IllegalStateException + * if the strategy has been shutdown. + */ + GetRecordsResult getRecords(int maxRecords); + + /** + * Releases any resources used by the strategy. Once the strategy is shutdown it is no longer safe to call + * {@link #getRecords(int)}. + */ + void shutdown(); + + /** + * Returns whether this strategy has been shutdown. + * + * @return true if the strategy has been shutdown, false otherwise. + */ + boolean isShutdown(); +} diff --git a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/GetRecordsRetrivalStrategy.java b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/GetRecordsRetrivalStrategy.java deleted file mode 100644 index ed3d3f93..00000000 --- a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/GetRecordsRetrivalStrategy.java +++ /dev/null @@ -1,10 +0,0 @@ -package com.amazonaws.services.kinesis.clientlibrary.lib.worker; - -import com.amazonaws.services.kinesis.model.GetRecordsResult; - -/** - * - */ -public interface GetRecordsRetrivalStrategy { - GetRecordsResult getRecords(int maxRecords); -} diff --git a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ProcessTask.java b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ProcessTask.java index 42e36afa..ddb582dc 100644 --- a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ProcessTask.java +++ b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ProcessTask.java @@ -62,7 +62,7 @@ class ProcessTask implements ITask { private final Shard shard; private final ThrottlingReporter throttlingReporter; - private final GetRecordsRetrivalStrategy getRecordsRetrivalStrategy; + private final GetRecordsRetrievalStrategy getRecordsRetrievalStrategy; /** * @param shardInfo @@ -83,7 +83,7 @@ class ProcessTask implements ITask { long backoffTimeMillis, boolean skipShardSyncAtWorkerInitializationIfLeasesExist) { this(shardInfo, streamConfig, recordProcessor, recordProcessorCheckpointer, dataFetcher, backoffTimeMillis, skipShardSyncAtWorkerInitializationIfLeasesExist, - new ThrottlingReporter(MAX_CONSECUTIVE_THROTTLES, shardInfo.getShardId()), new SynchronousGetRecordsRetrivalStrategy(dataFetcher)); + new ThrottlingReporter(MAX_CONSECUTIVE_THROTTLES, shardInfo.getShardId()), new SynchronousGetRecordsRetrievalStrategy(dataFetcher)); } /** @@ -105,7 +105,7 @@ class ProcessTask implements ITask { public ProcessTask(ShardInfo shardInfo, StreamConfig streamConfig, IRecordProcessor recordProcessor, RecordProcessorCheckpointer recordProcessorCheckpointer, KinesisDataFetcher dataFetcher, long backoffTimeMillis, boolean skipShardSyncAtWorkerInitializationIfLeasesExist, - ThrottlingReporter throttlingReporter, GetRecordsRetrivalStrategy getRecordsRetrivalStrategy) { + ThrottlingReporter throttlingReporter, GetRecordsRetrievalStrategy getRecordsRetrievalStrategy) { super(); this.shardInfo = shardInfo; this.recordProcessor = recordProcessor; @@ -115,7 +115,7 @@ class ProcessTask implements ITask { this.backoffTimeMillis = backoffTimeMillis; this.throttlingReporter = throttlingReporter; IKinesisProxy kinesisProxy = this.streamConfig.getStreamProxy(); - this.getRecordsRetrivalStrategy = getRecordsRetrivalStrategy; + this.getRecordsRetrievalStrategy = getRecordsRetrievalStrategy; // If skipShardSyncAtWorkerInitializationIfLeasesExist is set, we will not get the shard for // this ProcessTask. In this case, duplicate KPL user records in the event of resharding will // not be dropped during deaggregation of Amazon Kinesis records. This is only applicable if @@ -371,7 +371,7 @@ class ProcessTask implements ITask { * @return list of data records from Kinesis */ private GetRecordsResult getRecordsResultAndRecordMillisBehindLatest() { - final GetRecordsResult getRecordsResult = getRecordsRetrivalStrategy.getRecords(streamConfig.getMaxRecords()); + final GetRecordsResult getRecordsResult = getRecordsRetrievalStrategy.getRecords(streamConfig.getMaxRecords()); if (getRecordsResult == null) { // Stream no longer exists diff --git a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/SynchronousGetRecordsRetrivalStrategy.java b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/SynchronousGetRecordsRetrievalStrategy.java similarity index 55% rename from src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/SynchronousGetRecordsRetrivalStrategy.java rename to src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/SynchronousGetRecordsRetrievalStrategy.java index 567eef0b..77a60448 100644 --- a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/SynchronousGetRecordsRetrivalStrategy.java +++ b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/SynchronousGetRecordsRetrievalStrategy.java @@ -8,7 +8,7 @@ import lombok.NonNull; * */ @Data -public class SynchronousGetRecordsRetrivalStrategy implements GetRecordsRetrivalStrategy { +public class SynchronousGetRecordsRetrievalStrategy implements GetRecordsRetrievalStrategy { @NonNull private final KinesisDataFetcher dataFetcher; @@ -16,4 +16,16 @@ public class SynchronousGetRecordsRetrivalStrategy implements GetRecordsRetrival public GetRecordsResult getRecords(final int maxRecords) { return dataFetcher.getRecords(maxRecords); } + + @Override + public void shutdown() { + // + // Does nothing as this retriever doesn't manage any resources + // + } + + @Override + public boolean isShutdown() { + return false; + } } diff --git a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/AsynchronousGetRecordsRetrivalStrategyTest.java b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/AsynchronousGetRecordsRetrievalStrategyIntegrationTest.java similarity index 84% rename from src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/AsynchronousGetRecordsRetrivalStrategyTest.java rename to src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/AsynchronousGetRecordsRetrievalStrategyIntegrationTest.java index f3399121..f8411651 100644 --- a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/AsynchronousGetRecordsRetrivalStrategyTest.java +++ b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/AsynchronousGetRecordsRetrievalStrategyIntegrationTest.java @@ -1,17 +1,27 @@ +/* + * Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Amazon Software License (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/asl/ + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ package com.amazonaws.services.kinesis.clientlibrary.lib.worker; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNull; -import static org.mockito.Matchers.any; -import static org.mockito.Matchers.anyObject; -import static org.mockito.Matchers.eq; -import static org.mockito.Mockito.atLeast; -import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.doThrow; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; +import com.amazonaws.services.kinesis.clientlibrary.proxies.IKinesisProxy; +import com.amazonaws.services.kinesis.model.GetRecordsResult; +import com.google.common.util.concurrent.ThreadFactoryBuilder; +import org.junit.After; +import org.junit.Before; +import org.junit.Ignore; +import org.junit.Test; +import org.mockito.Mock; import java.util.concurrent.CompletionService; import java.util.concurrent.ExecutionException; @@ -22,23 +32,19 @@ import java.util.concurrent.RejectedExecutionHandler; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.mockito.Mock; -import org.mockito.Spy; -import org.mockito.runners.MockitoJUnitRunner; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.atLeast; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; -import com.amazonaws.services.kinesis.clientlibrary.proxies.IKinesisProxy; -import com.amazonaws.services.kinesis.model.GetRecordsResult; -import com.google.common.util.concurrent.ThreadFactoryBuilder; +public class AsynchronousGetRecordsRetrievalStrategyIntegrationTest { -/** - * - */ -@RunWith(MockitoJUnitRunner.class) -public class AsynchronousGetRecordsRetrivalStrategyTest { private static final int CORE_POOL_SIZE = 1; private static final int MAX_POOL_SIZE = 2; private static final int TIME_TO_LIVE = 5; @@ -51,7 +57,7 @@ public class AsynchronousGetRecordsRetrivalStrategyTest { @Mock private ShardInfo mockShardInfo; - private AsynchronousGetRecordsRetrivalStrategy getRecordsRetrivalStrategy; + private AsynchronousGetRecordsRetrievalStrategy getRecordsRetrivalStrategy; private KinesisDataFetcher dataFetcher; private GetRecordsResult result; private ExecutorService executorService; @@ -71,8 +77,8 @@ public class AsynchronousGetRecordsRetrivalStrategyTest { new LinkedBlockingQueue<>(1), new ThreadFactoryBuilder().setDaemon(true).setNameFormat("getrecords-worker-%d").build(), rejectedExecutionHandler)); - getRecordsRetrivalStrategy = new AsynchronousGetRecordsRetrivalStrategy(dataFetcher, executorService, RETRY_GET_RECORDS_IN_SECONDS); - completionService = spy(getRecordsRetrivalStrategy.getCompletionService()); + getRecordsRetrivalStrategy = new AsynchronousGetRecordsRetrievalStrategy(dataFetcher, executorService, RETRY_GET_RECORDS_IN_SECONDS, "shardId-0001"); + completionService = spy(getRecordsRetrivalStrategy.completionService); result = null; } @@ -97,18 +103,19 @@ public class AsynchronousGetRecordsRetrivalStrategyTest { getRecordsResult = getRecordsRetrivalStrategy.getRecords(numberOfRecords); assertNull(getRecordsResult); } - - /*@Test + + @Test + @Ignore public void testInterrupted() throws InterruptedException, ExecutionException { + Future mockFuture = mock(Future.class); - System.out.println(completionService); when(completionService.submit(any())).thenReturn(mockFuture); when(completionService.poll()).thenReturn(mockFuture); doThrow(InterruptedException.class).when(mockFuture).get(); GetRecordsResult getRecordsResult = getRecordsRetrivalStrategy.getRecords(numberOfRecords); verify(mockFuture).get(); assertNull(getRecordsResult); - }*/ + } private int getLeastNumberOfCalls() { int leastNumberOfCalls = 0; @@ -142,4 +149,5 @@ public class AsynchronousGetRecordsRetrivalStrategyTest { return result; } } + } diff --git a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/AsynchronousGetRecordsRetrievalStrategyTest.java b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/AsynchronousGetRecordsRetrievalStrategyTest.java new file mode 100644 index 00000000..dfba0351 --- /dev/null +++ b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/AsynchronousGetRecordsRetrievalStrategyTest.java @@ -0,0 +1,137 @@ +package com.amazonaws.services.kinesis.clientlibrary.lib.worker; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyBoolean; +import static org.mockito.Matchers.anyLong; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.util.concurrent.CompletionService; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; +import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.TimeUnit; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.runners.MockitoJUnitRunner; + +import com.amazonaws.services.kinesis.model.GetRecordsResult; + +/** + * + */ +@RunWith(MockitoJUnitRunner.class) +public class AsynchronousGetRecordsRetrievalStrategyTest { + + private static final long RETRY_GET_RECORDS_IN_SECONDS = 5; + private static final String SHARD_ID = "ShardId-0001"; + @Mock + private KinesisDataFetcher dataFetcher; + @Mock + private ExecutorService executorService; + @Mock + private CompletionService completionService; + @Mock + private Future successfulFuture; + @Mock + private Future blockedFuture; + @Mock + private GetRecordsResult expectedResults; + + @Test + public void testSingleSuccessfulRequestFuture() throws Exception { + AsynchronousGetRecordsRetrievalStrategy strategy = new AsynchronousGetRecordsRetrievalStrategy(dataFetcher, + executorService, (int) RETRY_GET_RECORDS_IN_SECONDS, completionService, SHARD_ID); + + when(executorService.isShutdown()).thenReturn(false); + when(completionService.submit(any())).thenReturn(successfulFuture); + when(completionService.poll(anyLong(), any())).thenReturn(successfulFuture); + when(successfulFuture.get()).thenReturn(expectedResults); + + GetRecordsResult result = strategy.getRecords(10); + + verify(executorService).isShutdown(); + verify(completionService).submit(any()); + verify(completionService).poll(eq(RETRY_GET_RECORDS_IN_SECONDS), eq(TimeUnit.SECONDS)); + verify(successfulFuture).get(); + verify(successfulFuture).cancel(eq(true)); + verify(successfulFuture).isCancelled(); + verify(completionService, never()).take(); + + assertThat(result, equalTo(expectedResults)); + } + + @Test + public void testBlockedAndSuccessfulFuture() throws Exception { + AsynchronousGetRecordsRetrievalStrategy strategy = new AsynchronousGetRecordsRetrievalStrategy(dataFetcher, + executorService, (int) RETRY_GET_RECORDS_IN_SECONDS, completionService, SHARD_ID); + + when(executorService.isShutdown()).thenReturn(false); + when(completionService.submit(any())).thenReturn(blockedFuture).thenReturn(successfulFuture); + when(completionService.poll(anyLong(), any())).thenReturn(null).thenReturn(successfulFuture); + when(successfulFuture.get()).thenReturn(expectedResults); + when(successfulFuture.cancel(anyBoolean())).thenReturn(false); + when(blockedFuture.cancel(anyBoolean())).thenReturn(true); + when(successfulFuture.isCancelled()).thenReturn(false); + when(blockedFuture.isCancelled()).thenReturn(true); + + GetRecordsResult actualResults = strategy.getRecords(10); + + verify(completionService, times(2)).submit(any()); + verify(completionService, times(2)).poll(eq(RETRY_GET_RECORDS_IN_SECONDS), eq(TimeUnit.SECONDS)); + verify(successfulFuture).get(); + verify(blockedFuture, never()).get(); + verify(successfulFuture).cancel(eq(true)); + verify(blockedFuture).cancel(eq(true)); + verify(successfulFuture).isCancelled(); + verify(blockedFuture).isCancelled(); + verify(completionService).take(); + + assertThat(actualResults, equalTo(expectedResults)); + } + + @Test(expected = IllegalStateException.class) + public void testStrategyIsShutdown() throws Exception { + AsynchronousGetRecordsRetrievalStrategy strategy = new AsynchronousGetRecordsRetrievalStrategy(dataFetcher, + executorService, (int) RETRY_GET_RECORDS_IN_SECONDS, completionService, SHARD_ID); + + when(executorService.isShutdown()).thenReturn(true); + + strategy.getRecords(10); + } + + @Test + public void testPoolOutOfResources() throws Exception { + AsynchronousGetRecordsRetrievalStrategy strategy = new AsynchronousGetRecordsRetrievalStrategy(dataFetcher, + executorService, (int) RETRY_GET_RECORDS_IN_SECONDS, completionService, SHARD_ID); + + when(executorService.isShutdown()).thenReturn(false); + when(completionService.submit(any())).thenReturn(blockedFuture).thenThrow(new RejectedExecutionException("Rejected!")).thenReturn(successfulFuture); + when(completionService.poll(anyLong(), any())).thenReturn(null).thenReturn(null).thenReturn(successfulFuture); + when(successfulFuture.get()).thenReturn(expectedResults); + when(successfulFuture.cancel(anyBoolean())).thenReturn(false); + when(blockedFuture.cancel(anyBoolean())).thenReturn(true); + when(successfulFuture.isCancelled()).thenReturn(false); + when(blockedFuture.isCancelled()).thenReturn(true); + + GetRecordsResult actualResult = strategy.getRecords(10); + + verify(completionService, times(3)).submit(any()); + verify(completionService, times(3)).poll(eq(RETRY_GET_RECORDS_IN_SECONDS), eq(TimeUnit.SECONDS)); + verify(successfulFuture).cancel(eq(true)); + verify(blockedFuture).cancel(eq(true)); + verify(successfulFuture).isCancelled(); + verify(blockedFuture).isCancelled(); + verify(completionService).take(); + + assertThat(actualResult, equalTo(expectedResults)); + } + +} diff --git a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/KinesisDataFetcherTest.java b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/KinesisDataFetcherTest.java index 01aaf3bb..2597d76b 100644 --- a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/KinesisDataFetcherTest.java +++ b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/KinesisDataFetcherTest.java @@ -117,7 +117,7 @@ public class KinesisDataFetcherTest { ICheckpoint checkpoint = mock(ICheckpoint.class); KinesisDataFetcher fetcher = new KinesisDataFetcher(kinesis, SHARD_INFO); - GetRecordsRetrivalStrategy getRecordsRetrivalStrategy = new SynchronousGetRecordsRetrivalStrategy(fetcher); + GetRecordsRetrievalStrategy getRecordsRetrievalStrategy = new SynchronousGetRecordsRetrievalStrategy(fetcher); String iteratorA = "foo"; String iteratorB = "bar"; @@ -139,10 +139,10 @@ public class KinesisDataFetcherTest { fetcher.initialize(seqA, null); fetcher.advanceIteratorTo(seqA, null); - Assert.assertEquals(recordsA, getRecordsRetrivalStrategy.getRecords(MAX_RECORDS).getRecords()); + Assert.assertEquals(recordsA, getRecordsRetrievalStrategy.getRecords(MAX_RECORDS).getRecords()); fetcher.advanceIteratorTo(seqB, null); - Assert.assertEquals(recordsB, getRecordsRetrivalStrategy.getRecords(MAX_RECORDS).getRecords()); + Assert.assertEquals(recordsB, getRecordsRetrievalStrategy.getRecords(MAX_RECORDS).getRecords()); } @Test @@ -182,9 +182,9 @@ public class KinesisDataFetcherTest { // Create data fectcher and initialize it with latest type checkpoint KinesisDataFetcher dataFetcher = new KinesisDataFetcher(mockProxy, SHARD_INFO); dataFetcher.initialize(SentinelCheckpoint.LATEST.toString(), INITIAL_POSITION_LATEST); - GetRecordsRetrivalStrategy getRecordsRetrivalStrategy = new SynchronousGetRecordsRetrivalStrategy(dataFetcher); + GetRecordsRetrievalStrategy getRecordsRetrievalStrategy = new SynchronousGetRecordsRetrievalStrategy(dataFetcher); // Call getRecords of dataFetcher which will throw an exception - getRecordsRetrivalStrategy.getRecords(maxRecords); + getRecordsRetrievalStrategy.getRecords(maxRecords); // Test shard has reached the end Assert.assertTrue("Shard should reach the end", dataFetcher.isShardEndReached()); @@ -208,9 +208,9 @@ public class KinesisDataFetcherTest { when(checkpoint.getCheckpoint(SHARD_ID)).thenReturn(new ExtendedSequenceNumber(seqNo)); KinesisDataFetcher fetcher = new KinesisDataFetcher(kinesis, SHARD_INFO); - GetRecordsRetrivalStrategy getRecordsRetrivalStrategy = new SynchronousGetRecordsRetrivalStrategy(fetcher); + GetRecordsRetrievalStrategy getRecordsRetrievalStrategy = new SynchronousGetRecordsRetrievalStrategy(fetcher); fetcher.initialize(seqNo, initialPositionInStream); - List actualRecords = getRecordsRetrivalStrategy.getRecords(MAX_RECORDS).getRecords(); + List actualRecords = getRecordsRetrievalStrategy.getRecords(MAX_RECORDS).getRecords(); Assert.assertEquals(expectedRecords, actualRecords); } diff --git a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ProcessTaskTest.java b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ProcessTaskTest.java index f704a7c4..0c47e9b9 100644 --- a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ProcessTaskTest.java +++ b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ProcessTaskTest.java @@ -77,7 +77,7 @@ public class ProcessTaskTest { @Mock private ThrottlingReporter throttlingReporter; @Mock - private GetRecordsRetrivalStrategy mockGetRecordsRetrivalStrategy; + private GetRecordsRetrievalStrategy mockGetRecordsRetrievalStrategy; private List processedRecords; private ExtendedSequenceNumber newLargestPermittedCheckpointValue; @@ -96,20 +96,20 @@ public class ProcessTaskTest { final ShardInfo shardInfo = new ShardInfo(shardId, null, null, null); processTask = new ProcessTask( shardInfo, config, mockRecordProcessor, mockCheckpointer, mockDataFetcher, taskBackoffTimeMillis, - KinesisClientLibConfiguration.DEFAULT_SKIP_SHARD_SYNC_AT_STARTUP_IF_LEASES_EXIST, throttlingReporter, mockGetRecordsRetrivalStrategy); + KinesisClientLibConfiguration.DEFAULT_SKIP_SHARD_SYNC_AT_STARTUP_IF_LEASES_EXIST, throttlingReporter, mockGetRecordsRetrievalStrategy); } @Test public void testProcessTaskWithProvisionedThroughputExceededException() { // Set data fetcher to throw exception doReturn(false).when(mockDataFetcher).isShardEndReached(); - doThrow(new ProvisionedThroughputExceededException("Test Exception")).when(mockGetRecordsRetrivalStrategy) + doThrow(new ProvisionedThroughputExceededException("Test Exception")).when(mockGetRecordsRetrievalStrategy) .getRecords(maxRecords); TaskResult result = processTask.call(); verify(throttlingReporter).throttled(); verify(throttlingReporter, never()).success(); - verify(mockGetRecordsRetrivalStrategy).getRecords(eq(maxRecords)); + verify(mockGetRecordsRetrievalStrategy).getRecords(eq(maxRecords)); assertTrue("Result should contain ProvisionedThroughputExceededException", result.getException() instanceof ProvisionedThroughputExceededException); } @@ -117,10 +117,10 @@ public class ProcessTaskTest { @Test public void testProcessTaskWithNonExistentStream() { // Data fetcher returns a null Result when the stream does not exist - doReturn(null).when(mockGetRecordsRetrivalStrategy).getRecords(maxRecords); + doReturn(null).when(mockGetRecordsRetrievalStrategy).getRecords(maxRecords); TaskResult result = processTask.call(); - verify(mockGetRecordsRetrivalStrategy).getRecords(eq(maxRecords)); + verify(mockGetRecordsRetrievalStrategy).getRecords(eq(maxRecords)); assertNull("Task should not throw an exception", result.getException()); } @@ -304,14 +304,14 @@ public class ProcessTaskTest { private void testWithRecords(List records, ExtendedSequenceNumber lastCheckpointValue, ExtendedSequenceNumber largestPermittedCheckpointValue) { - when(mockGetRecordsRetrivalStrategy.getRecords(anyInt())).thenReturn( + when(mockGetRecordsRetrievalStrategy.getRecords(anyInt())).thenReturn( new GetRecordsResult().withRecords(records)); when(mockCheckpointer.getLastCheckpointValue()).thenReturn(lastCheckpointValue); when(mockCheckpointer.getLargestPermittedCheckpointValue()).thenReturn(largestPermittedCheckpointValue); processTask.call(); verify(throttlingReporter).success(); verify(throttlingReporter, never()).throttled(); - verify(mockGetRecordsRetrivalStrategy).getRecords(anyInt()); + verify(mockGetRecordsRetrievalStrategy).getRecords(anyInt()); ArgumentCaptor priCaptor = ArgumentCaptor.forClass(ProcessRecordsInput.class); verify(mockRecordProcessor).processRecords(priCaptor.capture()); processedRecords = priCaptor.getValue().getRecords();