Merge branch 'getrecords-timeout' of ssh://git.amazon.com/pkg/AmazonKinesisClientGithubMirror into getrecords-retry

Changed ConsumerStateTest and ProcessTask due to the constructor change
This commit is contained in:
Wei 2017-09-01 10:34:53 -07:00
commit 4fe0e57998
11 changed files with 469 additions and 131 deletions

View file

@ -0,0 +1,110 @@
package com.amazonaws.services.kinesis.clientlibrary.lib.worker;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.CompletionService;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorCompletionService;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import com.amazonaws.services.kinesis.model.GetRecordsResult;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import lombok.NonNull;
import lombok.extern.apachecommons.CommonsLog;
/**
*
*/
@CommonsLog
public class AsynchronousGetRecordsRetrievalStrategy implements GetRecordsRetrievalStrategy {
private static final int TIME_TO_KEEP_ALIVE = 5;
private static final int CORE_THREAD_POOL_COUNT = 1;
private final KinesisDataFetcher dataFetcher;
private final ExecutorService executorService;
private final int retryGetRecordsInSeconds;
private final String shardId;
final CompletionService<GetRecordsResult> completionService;
public AsynchronousGetRecordsRetrievalStrategy(@NonNull final KinesisDataFetcher dataFetcher,
final int retryGetRecordsInSeconds, final int maxGetRecordsThreadPool, String shardId) {
this(dataFetcher, buildExector(maxGetRecordsThreadPool, shardId), retryGetRecordsInSeconds, shardId);
}
public AsynchronousGetRecordsRetrievalStrategy(final KinesisDataFetcher dataFetcher,
final ExecutorService executorService, final int retryGetRecordsInSeconds, String shardId) {
this(dataFetcher, executorService, retryGetRecordsInSeconds, new ExecutorCompletionService<>(executorService),
shardId);
}
AsynchronousGetRecordsRetrievalStrategy(KinesisDataFetcher dataFetcher, ExecutorService executorService,
int retryGetRecordsInSeconds, CompletionService<GetRecordsResult> completionService, String shardId) {
this.dataFetcher = dataFetcher;
this.executorService = executorService;
this.retryGetRecordsInSeconds = retryGetRecordsInSeconds;
this.completionService = completionService;
this.shardId = shardId;
}
@Override
public GetRecordsResult getRecords(final int maxRecords) {
if (executorService.isShutdown()) {
throw new IllegalStateException("Strategy has been shutdown");
}
GetRecordsResult result = null;
Set<Future<GetRecordsResult>> futures = new HashSet<>();
while (true) {
try {
futures.add(completionService.submit(() -> dataFetcher.getRecords(maxRecords)));
} catch (RejectedExecutionException e) {
log.warn("Out of resources, unable to start additional requests.");
}
try {
Future<GetRecordsResult> resultFuture = completionService.poll(retryGetRecordsInSeconds,
TimeUnit.SECONDS);
if (resultFuture != null) {
result = resultFuture.get();
break;
}
} catch (ExecutionException e) {
log.error("ExecutionException thrown while trying to get records", e);
} catch (InterruptedException e) {
log.error("Thread was interrupted", e);
break;
}
}
futures.stream().peek(f -> f.cancel(true)).filter(Future::isCancelled).forEach(f -> {
try {
completionService.take();
} catch (InterruptedException e) {
log.error("Exception thrown while trying to empty the threadpool.");
}
});
return result;
}
@Override
public void shutdown() {
executorService.shutdownNow();
}
@Override
public boolean isShutdown() {
return executorService.isShutdown();
}
private static ExecutorService buildExector(int maxGetRecordsThreadPool, String shardId) {
String threadNameFormat = "get-records-worker-" + shardId + "-%d";
return new ThreadPoolExecutor(CORE_THREAD_POOL_COUNT, maxGetRecordsThreadPool, TIME_TO_KEEP_ALIVE,
TimeUnit.SECONDS, new LinkedBlockingQueue<>(1),
new ThreadFactoryBuilder().setDaemon(true).setNameFormat(threadNameFormat).build(),
new ThreadPoolExecutor.AbortPolicy());
}
}

View file

@ -1,97 +0,0 @@
package com.amazonaws.services.kinesis.clientlibrary.lib.worker;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletionService;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorCompletionService;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import com.amazonaws.services.kinesis.model.GetRecordsResult;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import lombok.Getter;
import lombok.NonNull;
import lombok.extern.apachecommons.CommonsLog;
/**
*
*/
@CommonsLog
public class AsynchronousGetRecordsRetrivalStrategy implements GetRecordsRetrivalStrategy {
private static final int TIME_TO_KEEP_ALIVE = 5;
private static final int CORE_THREAD_POOL_COUNT = 1;
private final KinesisDataFetcher dataFetcher;
private final ExecutorService executorService;
private final int retryGetRecordsInSeconds;
@Getter
private final CompletionService<GetRecordsResult> completionService;
public AsynchronousGetRecordsRetrivalStrategy(@NonNull final KinesisDataFetcher dataFetcher,
final int retryGetRecordsInSeconds,
final int maxGetRecordsThreadPool) {
this (dataFetcher,
new ThreadPoolExecutor(
CORE_THREAD_POOL_COUNT,
maxGetRecordsThreadPool,
TIME_TO_KEEP_ALIVE,
TimeUnit.SECONDS,
new LinkedBlockingQueue<>(1),
new ThreadFactoryBuilder().setDaemon(true).setNameFormat("getrecords-worker-%d").build(),
new ThreadPoolExecutor.AbortPolicy()),
retryGetRecordsInSeconds);
}
public AsynchronousGetRecordsRetrivalStrategy(final KinesisDataFetcher dataFetcher,
final ExecutorService executorService,
final int retryGetRecordsInSeconds) {
this.dataFetcher = dataFetcher;
this.executorService = executorService;
this.retryGetRecordsInSeconds = retryGetRecordsInSeconds;
this.completionService = new ExecutorCompletionService<>(executorService);
}
@Override
public GetRecordsResult getRecords(final int maxRecords) {
GetRecordsResult result = null;
Set<Future<GetRecordsResult>> futures = new HashSet<>();
while (true) {
try {
futures.add(completionService.submit(new Callable<GetRecordsResult>() {
@Override
public GetRecordsResult call() throws Exception {
return dataFetcher.getRecords(maxRecords);
}
}));
} catch (RejectedExecutionException e) {
log.warn("Out of resources, unable to start additional requests.");
}
try {
Future<GetRecordsResult> resultFuture = completionService.poll(retryGetRecordsInSeconds, TimeUnit.SECONDS);
if (resultFuture != null) {
result = resultFuture.get();
break;
}
} catch (ExecutionException e) {
log.error("ExecutionException thrown while trying to get records", e);
} catch (InterruptedException e) {
log.error("Thread was interrupted", e);
break;
}
}
futures.stream().peek(f -> f.cancel(true)).filter(Future::isCancelled).forEach(f -> completionService.poll());
return result;
}
public void shutdown() {
executorService.shutdownNow();
}
}

View file

@ -0,0 +1,33 @@
package com.amazonaws.services.kinesis.clientlibrary.lib.worker;
import com.amazonaws.services.kinesis.model.GetRecordsResult;
/**
* Represents a strategy to retrieve records from Kinesis. Allows for variations on how records are retrieved from
* Kinesis.
*/
public interface GetRecordsRetrievalStrategy {
/**
* Gets a set of records from Kinesis.
*
* @param maxRecords
* passed to Kinesis, and can be used to restrict the number of records returned from Kinesis.
* @return the resulting records.
* @throws IllegalStateException
* if the strategy has been shutdown.
*/
GetRecordsResult getRecords(int maxRecords);
/**
* Releases any resources used by the strategy. Once the strategy is shutdown it is no longer safe to call
* {@link #getRecords(int)}.
*/
void shutdown();
/**
* Returns whether this strategy has been shutdown.
*
* @return true if the strategy has been shutdown, false otherwise.
*/
boolean isShutdown();
}

View file

@ -1,10 +0,0 @@
package com.amazonaws.services.kinesis.clientlibrary.lib.worker;
import com.amazonaws.services.kinesis.model.GetRecordsResult;
/**
*
*/
public interface GetRecordsRetrivalStrategy {
GetRecordsResult getRecords(int maxRecords);
}

View file

@ -63,7 +63,7 @@ class ProcessTask implements ITask {
private final Shard shard; private final Shard shard;
private final ThrottlingReporter throttlingReporter; private final ThrottlingReporter throttlingReporter;
private final GetRecordsRetrivalStrategy getRecordsRetrivalStrategy; private final GetRecordsRetrievalStrategy getRecordsRetrievalStrategy;
/** /**
* @param shardInfo * @param shardInfo
@ -84,7 +84,7 @@ class ProcessTask implements ITask {
long backoffTimeMillis, boolean skipShardSyncAtWorkerInitializationIfLeasesExist) { long backoffTimeMillis, boolean skipShardSyncAtWorkerInitializationIfLeasesExist) {
this(shardInfo, streamConfig, recordProcessor, recordProcessorCheckpointer, dataFetcher, backoffTimeMillis, this(shardInfo, streamConfig, recordProcessor, recordProcessorCheckpointer, dataFetcher, backoffTimeMillis,
skipShardSyncAtWorkerInitializationIfLeasesExist, skipShardSyncAtWorkerInitializationIfLeasesExist,
new ThrottlingReporter(MAX_CONSECUTIVE_THROTTLES, shardInfo.getShardId()), new SynchronousGetRecordsRetrivalStrategy(dataFetcher)); new ThrottlingReporter(MAX_CONSECUTIVE_THROTTLES, shardInfo.getShardId()), new SynchronousGetRecordsRetrievalStrategy(dataFetcher));
} }
/** /**
@ -112,7 +112,7 @@ class ProcessTask implements ITask {
this(shardInfo, streamConfig, recordProcessor, recordProcessorCheckpointer, dataFetcher, backoffTimeMillis, this(shardInfo, streamConfig, recordProcessor, recordProcessorCheckpointer, dataFetcher, backoffTimeMillis,
skipShardSyncAtWorkerInitializationIfLeasesExist, skipShardSyncAtWorkerInitializationIfLeasesExist,
new ThrottlingReporter(MAX_CONSECUTIVE_THROTTLES, shardInfo.getShardId()), new ThrottlingReporter(MAX_CONSECUTIVE_THROTTLES, shardInfo.getShardId()),
new AsynchronousGetRecordsRetrivalStrategy(dataFetcher, retryGetRecordsInSeconds, maxGetRecordsThreadPool)); new AsynchronousGetRecordsRetrievalStrategy(dataFetcher, retryGetRecordsInSeconds, maxGetRecordsThreadPool, shardInfo.getShardId()));
} }
/** /**
@ -134,7 +134,7 @@ class ProcessTask implements ITask {
public ProcessTask(ShardInfo shardInfo, StreamConfig streamConfig, IRecordProcessor recordProcessor, public ProcessTask(ShardInfo shardInfo, StreamConfig streamConfig, IRecordProcessor recordProcessor,
RecordProcessorCheckpointer recordProcessorCheckpointer, KinesisDataFetcher dataFetcher, RecordProcessorCheckpointer recordProcessorCheckpointer, KinesisDataFetcher dataFetcher,
long backoffTimeMillis, boolean skipShardSyncAtWorkerInitializationIfLeasesExist, long backoffTimeMillis, boolean skipShardSyncAtWorkerInitializationIfLeasesExist,
ThrottlingReporter throttlingReporter, GetRecordsRetrivalStrategy getRecordsRetrivalStrategy) { ThrottlingReporter throttlingReporter, GetRecordsRetrievalStrategy getRecordsRetrievalStrategy) {
super(); super();
this.shardInfo = shardInfo; this.shardInfo = shardInfo;
this.recordProcessor = recordProcessor; this.recordProcessor = recordProcessor;
@ -144,7 +144,7 @@ class ProcessTask implements ITask {
this.backoffTimeMillis = backoffTimeMillis; this.backoffTimeMillis = backoffTimeMillis;
this.throttlingReporter = throttlingReporter; this.throttlingReporter = throttlingReporter;
IKinesisProxy kinesisProxy = this.streamConfig.getStreamProxy(); IKinesisProxy kinesisProxy = this.streamConfig.getStreamProxy();
this.getRecordsRetrivalStrategy = getRecordsRetrivalStrategy; this.getRecordsRetrievalStrategy = getRecordsRetrievalStrategy;
// If skipShardSyncAtWorkerInitializationIfLeasesExist is set, we will not get the shard for // If skipShardSyncAtWorkerInitializationIfLeasesExist is set, we will not get the shard for
// this ProcessTask. In this case, duplicate KPL user records in the event of resharding will // this ProcessTask. In this case, duplicate KPL user records in the event of resharding will
// not be dropped during deaggregation of Amazon Kinesis records. This is only applicable if // not be dropped during deaggregation of Amazon Kinesis records. This is only applicable if
@ -400,7 +400,7 @@ class ProcessTask implements ITask {
* @return list of data records from Kinesis * @return list of data records from Kinesis
*/ */
private GetRecordsResult getRecordsResultAndRecordMillisBehindLatest() { private GetRecordsResult getRecordsResultAndRecordMillisBehindLatest() {
final GetRecordsResult getRecordsResult = getRecordsRetrivalStrategy.getRecords(streamConfig.getMaxRecords()); final GetRecordsResult getRecordsResult = getRecordsRetrievalStrategy.getRecords(streamConfig.getMaxRecords());
if (getRecordsResult == null) { if (getRecordsResult == null) {
// Stream no longer exists // Stream no longer exists

View file

@ -8,7 +8,7 @@ import lombok.NonNull;
* *
*/ */
@Data @Data
public class SynchronousGetRecordsRetrivalStrategy implements GetRecordsRetrivalStrategy { public class SynchronousGetRecordsRetrievalStrategy implements GetRecordsRetrievalStrategy {
@NonNull @NonNull
private final KinesisDataFetcher dataFetcher; private final KinesisDataFetcher dataFetcher;
@ -16,4 +16,16 @@ public class SynchronousGetRecordsRetrivalStrategy implements GetRecordsRetrival
public GetRecordsResult getRecords(final int maxRecords) { public GetRecordsResult getRecords(final int maxRecords) {
return dataFetcher.getRecords(maxRecords); return dataFetcher.getRecords(maxRecords);
} }
@Override
public void shutdown() {
//
// Does nothing as this retriever doesn't manage any resources
//
}
@Override
public boolean isShutdown() {
return false;
}
} }

View file

@ -0,0 +1,153 @@
/*
* Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Amazon Software License (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/asl/
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/
package com.amazonaws.services.kinesis.clientlibrary.lib.worker;
import com.amazonaws.services.kinesis.clientlibrary.proxies.IKinesisProxy;
import com.amazonaws.services.kinesis.model.GetRecordsResult;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.junit.After;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;
import org.mockito.Mock;
import java.util.concurrent.CompletionService;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.RejectedExecutionHandler;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.atLeast;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
public class AsynchronousGetRecordsRetrievalStrategyIntegrationTest {
private static final int CORE_POOL_SIZE = 1;
private static final int MAX_POOL_SIZE = 2;
private static final int TIME_TO_LIVE = 5;
private static final int RETRY_GET_RECORDS_IN_SECONDS = 2;
private static final int SLEEP_GET_RECORDS_IN_SECONDS = 10;
@Mock
private IKinesisProxy mockKinesisProxy;
@Mock
private ShardInfo mockShardInfo;
private AsynchronousGetRecordsRetrievalStrategy getRecordsRetrivalStrategy;
private KinesisDataFetcher dataFetcher;
private GetRecordsResult result;
private ExecutorService executorService;
private RejectedExecutionHandler rejectedExecutionHandler;
private int numberOfRecords = 10;
private CompletionService<GetRecordsResult> completionService;
@Before
public void setup() {
dataFetcher = spy(new KinesisDataFetcherForTests(mockKinesisProxy, mockShardInfo));
rejectedExecutionHandler = spy(new ThreadPoolExecutor.AbortPolicy());
executorService = spy(new ThreadPoolExecutor(
CORE_POOL_SIZE,
MAX_POOL_SIZE,
TIME_TO_LIVE,
TimeUnit.SECONDS,
new LinkedBlockingQueue<>(1),
new ThreadFactoryBuilder().setDaemon(true).setNameFormat("getrecords-worker-%d").build(),
rejectedExecutionHandler));
getRecordsRetrivalStrategy = new AsynchronousGetRecordsRetrievalStrategy(dataFetcher, executorService, RETRY_GET_RECORDS_IN_SECONDS, "shardId-0001");
completionService = spy(getRecordsRetrivalStrategy.completionService);
result = null;
}
@Test
public void oneRequestMultithreadTest() {
GetRecordsResult getRecordsResult = getRecordsRetrivalStrategy.getRecords(numberOfRecords);
verify(dataFetcher, atLeast(getLeastNumberOfCalls())).getRecords(eq(numberOfRecords));
verify(executorService, atLeast(getLeastNumberOfCalls())).execute(any());
assertNull(getRecordsResult);
}
@Test
public void multiRequestTest() {
result = mock(GetRecordsResult.class);
GetRecordsResult getRecordsResult = getRecordsRetrivalStrategy.getRecords(numberOfRecords);
verify(dataFetcher, atLeast(getLeastNumberOfCalls())).getRecords(numberOfRecords);
verify(executorService, atLeast(getLeastNumberOfCalls())).execute(any());
assertEquals(result, getRecordsResult);
result = null;
getRecordsResult = getRecordsRetrivalStrategy.getRecords(numberOfRecords);
assertNull(getRecordsResult);
}
@Test
@Ignore
public void testInterrupted() throws InterruptedException, ExecutionException {
Future<GetRecordsResult> mockFuture = mock(Future.class);
when(completionService.submit(any())).thenReturn(mockFuture);
when(completionService.poll()).thenReturn(mockFuture);
doThrow(InterruptedException.class).when(mockFuture).get();
GetRecordsResult getRecordsResult = getRecordsRetrivalStrategy.getRecords(numberOfRecords);
verify(mockFuture).get();
assertNull(getRecordsResult);
}
private int getLeastNumberOfCalls() {
int leastNumberOfCalls = 0;
for (int i = MAX_POOL_SIZE; i > 0; i--) {
if (i * RETRY_GET_RECORDS_IN_SECONDS <= SLEEP_GET_RECORDS_IN_SECONDS) {
leastNumberOfCalls = i;
break;
}
}
return leastNumberOfCalls;
}
@After
public void shutdown() {
getRecordsRetrivalStrategy.shutdown();
verify(executorService).shutdownNow();
}
private class KinesisDataFetcherForTests extends KinesisDataFetcher {
public KinesisDataFetcherForTests(final IKinesisProxy kinesisProxy, final ShardInfo shardInfo) {
super(kinesisProxy, shardInfo);
}
@Override
public GetRecordsResult getRecords(final int maxRecords) {
try {
Thread.sleep(SLEEP_GET_RECORDS_IN_SECONDS * 1000);
} catch (InterruptedException e) {
// Do nothing
}
return result;
}
}
}

View file

@ -0,0 +1,137 @@
package com.amazonaws.services.kinesis.clientlibrary.lib.worker;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyBoolean;
import static org.mockito.Matchers.anyLong;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import java.util.concurrent.CompletionService;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.TimeUnit;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.runners.MockitoJUnitRunner;
import com.amazonaws.services.kinesis.model.GetRecordsResult;
/**
*
*/
@RunWith(MockitoJUnitRunner.class)
public class AsynchronousGetRecordsRetrievalStrategyTest {
private static final long RETRY_GET_RECORDS_IN_SECONDS = 5;
private static final String SHARD_ID = "ShardId-0001";
@Mock
private KinesisDataFetcher dataFetcher;
@Mock
private ExecutorService executorService;
@Mock
private CompletionService<GetRecordsResult> completionService;
@Mock
private Future<GetRecordsResult> successfulFuture;
@Mock
private Future<GetRecordsResult> blockedFuture;
@Mock
private GetRecordsResult expectedResults;
@Test
public void testSingleSuccessfulRequestFuture() throws Exception {
AsynchronousGetRecordsRetrievalStrategy strategy = new AsynchronousGetRecordsRetrievalStrategy(dataFetcher,
executorService, (int) RETRY_GET_RECORDS_IN_SECONDS, completionService, SHARD_ID);
when(executorService.isShutdown()).thenReturn(false);
when(completionService.submit(any())).thenReturn(successfulFuture);
when(completionService.poll(anyLong(), any())).thenReturn(successfulFuture);
when(successfulFuture.get()).thenReturn(expectedResults);
GetRecordsResult result = strategy.getRecords(10);
verify(executorService).isShutdown();
verify(completionService).submit(any());
verify(completionService).poll(eq(RETRY_GET_RECORDS_IN_SECONDS), eq(TimeUnit.SECONDS));
verify(successfulFuture).get();
verify(successfulFuture).cancel(eq(true));
verify(successfulFuture).isCancelled();
verify(completionService, never()).take();
assertThat(result, equalTo(expectedResults));
}
@Test
public void testBlockedAndSuccessfulFuture() throws Exception {
AsynchronousGetRecordsRetrievalStrategy strategy = new AsynchronousGetRecordsRetrievalStrategy(dataFetcher,
executorService, (int) RETRY_GET_RECORDS_IN_SECONDS, completionService, SHARD_ID);
when(executorService.isShutdown()).thenReturn(false);
when(completionService.submit(any())).thenReturn(blockedFuture).thenReturn(successfulFuture);
when(completionService.poll(anyLong(), any())).thenReturn(null).thenReturn(successfulFuture);
when(successfulFuture.get()).thenReturn(expectedResults);
when(successfulFuture.cancel(anyBoolean())).thenReturn(false);
when(blockedFuture.cancel(anyBoolean())).thenReturn(true);
when(successfulFuture.isCancelled()).thenReturn(false);
when(blockedFuture.isCancelled()).thenReturn(true);
GetRecordsResult actualResults = strategy.getRecords(10);
verify(completionService, times(2)).submit(any());
verify(completionService, times(2)).poll(eq(RETRY_GET_RECORDS_IN_SECONDS), eq(TimeUnit.SECONDS));
verify(successfulFuture).get();
verify(blockedFuture, never()).get();
verify(successfulFuture).cancel(eq(true));
verify(blockedFuture).cancel(eq(true));
verify(successfulFuture).isCancelled();
verify(blockedFuture).isCancelled();
verify(completionService).take();
assertThat(actualResults, equalTo(expectedResults));
}
@Test(expected = IllegalStateException.class)
public void testStrategyIsShutdown() throws Exception {
AsynchronousGetRecordsRetrievalStrategy strategy = new AsynchronousGetRecordsRetrievalStrategy(dataFetcher,
executorService, (int) RETRY_GET_RECORDS_IN_SECONDS, completionService, SHARD_ID);
when(executorService.isShutdown()).thenReturn(true);
strategy.getRecords(10);
}
@Test
public void testPoolOutOfResources() throws Exception {
AsynchronousGetRecordsRetrievalStrategy strategy = new AsynchronousGetRecordsRetrievalStrategy(dataFetcher,
executorService, (int) RETRY_GET_RECORDS_IN_SECONDS, completionService, SHARD_ID);
when(executorService.isShutdown()).thenReturn(false);
when(completionService.submit(any())).thenReturn(blockedFuture).thenThrow(new RejectedExecutionException("Rejected!")).thenReturn(successfulFuture);
when(completionService.poll(anyLong(), any())).thenReturn(null).thenReturn(null).thenReturn(successfulFuture);
when(successfulFuture.get()).thenReturn(expectedResults);
when(successfulFuture.cancel(anyBoolean())).thenReturn(false);
when(blockedFuture.cancel(anyBoolean())).thenReturn(true);
when(successfulFuture.isCancelled()).thenReturn(false);
when(blockedFuture.isCancelled()).thenReturn(true);
GetRecordsResult actualResult = strategy.getRecords(10);
verify(completionService, times(3)).submit(any());
verify(completionService, times(3)).poll(eq(RETRY_GET_RECORDS_IN_SECONDS), eq(TimeUnit.SECONDS));
verify(successfulFuture).cancel(eq(true));
verify(blockedFuture).cancel(eq(true));
verify(successfulFuture).isCancelled();
verify(blockedFuture).isCancelled();
verify(completionService).take();
assertThat(actualResult, equalTo(expectedResults));
}
}

View file

@ -168,7 +168,7 @@ public class ConsumerStatesTest {
assertThat(task, procTask(KinesisDataFetcher.class, "dataFetcher", equalTo(dataFetcher))); assertThat(task, procTask(KinesisDataFetcher.class, "dataFetcher", equalTo(dataFetcher)));
assertThat(task, procTask(StreamConfig.class, "streamConfig", equalTo(streamConfig))); assertThat(task, procTask(StreamConfig.class, "streamConfig", equalTo(streamConfig)));
assertThat(task, procTask(Long.class, "backoffTimeMillis", equalTo(taskBackoffTimeMillis))); assertThat(task, procTask(Long.class, "backoffTimeMillis", equalTo(taskBackoffTimeMillis)));
assertThat(task, procTask(GetRecordsRetrivalStrategy.class, "getRecordsRetrivalStrategy", instanceOf(SynchronousGetRecordsRetrivalStrategy.class) )); assertThat(task, procTask(GetRecordsRetrievalStrategy.class, "getRecordsRetrivalStrategy", instanceOf(SynchronousGetRecordsRetrievalStrategy.class) ));
assertThat(state.successTransition(), equalTo(ShardConsumerState.PROCESSING.getConsumerState())); assertThat(state.successTransition(), equalTo(ShardConsumerState.PROCESSING.getConsumerState()));
@ -199,7 +199,7 @@ public class ConsumerStatesTest {
assertThat(task, procTask(KinesisDataFetcher.class, "dataFetcher", equalTo(dataFetcher))); assertThat(task, procTask(KinesisDataFetcher.class, "dataFetcher", equalTo(dataFetcher)));
assertThat(task, procTask(StreamConfig.class, "streamConfig", equalTo(streamConfig))); assertThat(task, procTask(StreamConfig.class, "streamConfig", equalTo(streamConfig)));
assertThat(task, procTask(Long.class, "backoffTimeMillis", equalTo(taskBackoffTimeMillis))); assertThat(task, procTask(Long.class, "backoffTimeMillis", equalTo(taskBackoffTimeMillis)));
assertThat(task, procTask(GetRecordsRetrivalStrategy.class, "getRecordsRetrivalStrategy", instanceOf(AsynchronousGetRecordsRetrivalStrategy.class) )); assertThat(task, procTask(GetRecordsRetrievalStrategy.class, "getRecordsRetrivalStrategy", instanceOf(AsynchronousGetRecordsRetrievalStrategy.class) ));
assertThat(state.successTransition(), equalTo(ShardConsumerState.PROCESSING.getConsumerState())); assertThat(state.successTransition(), equalTo(ShardConsumerState.PROCESSING.getConsumerState()));

View file

@ -117,7 +117,7 @@ public class KinesisDataFetcherTest {
ICheckpoint checkpoint = mock(ICheckpoint.class); ICheckpoint checkpoint = mock(ICheckpoint.class);
KinesisDataFetcher fetcher = new KinesisDataFetcher(kinesis, SHARD_INFO); KinesisDataFetcher fetcher = new KinesisDataFetcher(kinesis, SHARD_INFO);
GetRecordsRetrivalStrategy getRecordsRetrivalStrategy = new SynchronousGetRecordsRetrivalStrategy(fetcher); GetRecordsRetrievalStrategy getRecordsRetrievalStrategy = new SynchronousGetRecordsRetrievalStrategy(fetcher);
String iteratorA = "foo"; String iteratorA = "foo";
String iteratorB = "bar"; String iteratorB = "bar";
@ -139,10 +139,10 @@ public class KinesisDataFetcherTest {
fetcher.initialize(seqA, null); fetcher.initialize(seqA, null);
fetcher.advanceIteratorTo(seqA, null); fetcher.advanceIteratorTo(seqA, null);
Assert.assertEquals(recordsA, getRecordsRetrivalStrategy.getRecords(MAX_RECORDS).getRecords()); Assert.assertEquals(recordsA, getRecordsRetrievalStrategy.getRecords(MAX_RECORDS).getRecords());
fetcher.advanceIteratorTo(seqB, null); fetcher.advanceIteratorTo(seqB, null);
Assert.assertEquals(recordsB, getRecordsRetrivalStrategy.getRecords(MAX_RECORDS).getRecords()); Assert.assertEquals(recordsB, getRecordsRetrievalStrategy.getRecords(MAX_RECORDS).getRecords());
} }
@Test @Test
@ -182,9 +182,9 @@ public class KinesisDataFetcherTest {
// Create data fectcher and initialize it with latest type checkpoint // Create data fectcher and initialize it with latest type checkpoint
KinesisDataFetcher dataFetcher = new KinesisDataFetcher(mockProxy, SHARD_INFO); KinesisDataFetcher dataFetcher = new KinesisDataFetcher(mockProxy, SHARD_INFO);
dataFetcher.initialize(SentinelCheckpoint.LATEST.toString(), INITIAL_POSITION_LATEST); dataFetcher.initialize(SentinelCheckpoint.LATEST.toString(), INITIAL_POSITION_LATEST);
GetRecordsRetrivalStrategy getRecordsRetrivalStrategy = new SynchronousGetRecordsRetrivalStrategy(dataFetcher); GetRecordsRetrievalStrategy getRecordsRetrievalStrategy = new SynchronousGetRecordsRetrievalStrategy(dataFetcher);
// Call getRecords of dataFetcher which will throw an exception // Call getRecords of dataFetcher which will throw an exception
getRecordsRetrivalStrategy.getRecords(maxRecords); getRecordsRetrievalStrategy.getRecords(maxRecords);
// Test shard has reached the end // Test shard has reached the end
Assert.assertTrue("Shard should reach the end", dataFetcher.isShardEndReached()); Assert.assertTrue("Shard should reach the end", dataFetcher.isShardEndReached());
@ -208,9 +208,9 @@ public class KinesisDataFetcherTest {
when(checkpoint.getCheckpoint(SHARD_ID)).thenReturn(new ExtendedSequenceNumber(seqNo)); when(checkpoint.getCheckpoint(SHARD_ID)).thenReturn(new ExtendedSequenceNumber(seqNo));
KinesisDataFetcher fetcher = new KinesisDataFetcher(kinesis, SHARD_INFO); KinesisDataFetcher fetcher = new KinesisDataFetcher(kinesis, SHARD_INFO);
GetRecordsRetrivalStrategy getRecordsRetrivalStrategy = new SynchronousGetRecordsRetrivalStrategy(fetcher); GetRecordsRetrievalStrategy getRecordsRetrievalStrategy = new SynchronousGetRecordsRetrievalStrategy(fetcher);
fetcher.initialize(seqNo, initialPositionInStream); fetcher.initialize(seqNo, initialPositionInStream);
List<Record> actualRecords = getRecordsRetrivalStrategy.getRecords(MAX_RECORDS).getRecords(); List<Record> actualRecords = getRecordsRetrievalStrategy.getRecords(MAX_RECORDS).getRecords();
Assert.assertEquals(expectedRecords, actualRecords); Assert.assertEquals(expectedRecords, actualRecords);
} }

View file

@ -77,7 +77,7 @@ public class ProcessTaskTest {
@Mock @Mock
private ThrottlingReporter throttlingReporter; private ThrottlingReporter throttlingReporter;
@Mock @Mock
private GetRecordsRetrivalStrategy mockGetRecordsRetrivalStrategy; private GetRecordsRetrievalStrategy mockGetRecordsRetrievalStrategy;
private List<Record> processedRecords; private List<Record> processedRecords;
private ExtendedSequenceNumber newLargestPermittedCheckpointValue; private ExtendedSequenceNumber newLargestPermittedCheckpointValue;
@ -96,20 +96,20 @@ public class ProcessTaskTest {
final ShardInfo shardInfo = new ShardInfo(shardId, null, null, null); final ShardInfo shardInfo = new ShardInfo(shardId, null, null, null);
processTask = new ProcessTask( processTask = new ProcessTask(
shardInfo, config, mockRecordProcessor, mockCheckpointer, mockDataFetcher, taskBackoffTimeMillis, shardInfo, config, mockRecordProcessor, mockCheckpointer, mockDataFetcher, taskBackoffTimeMillis,
KinesisClientLibConfiguration.DEFAULT_SKIP_SHARD_SYNC_AT_STARTUP_IF_LEASES_EXIST, throttlingReporter, mockGetRecordsRetrivalStrategy); KinesisClientLibConfiguration.DEFAULT_SKIP_SHARD_SYNC_AT_STARTUP_IF_LEASES_EXIST, throttlingReporter, mockGetRecordsRetrievalStrategy);
} }
@Test @Test
public void testProcessTaskWithProvisionedThroughputExceededException() { public void testProcessTaskWithProvisionedThroughputExceededException() {
// Set data fetcher to throw exception // Set data fetcher to throw exception
doReturn(false).when(mockDataFetcher).isShardEndReached(); doReturn(false).when(mockDataFetcher).isShardEndReached();
doThrow(new ProvisionedThroughputExceededException("Test Exception")).when(mockGetRecordsRetrivalStrategy) doThrow(new ProvisionedThroughputExceededException("Test Exception")).when(mockGetRecordsRetrievalStrategy)
.getRecords(maxRecords); .getRecords(maxRecords);
TaskResult result = processTask.call(); TaskResult result = processTask.call();
verify(throttlingReporter).throttled(); verify(throttlingReporter).throttled();
verify(throttlingReporter, never()).success(); verify(throttlingReporter, never()).success();
verify(mockGetRecordsRetrivalStrategy).getRecords(eq(maxRecords)); verify(mockGetRecordsRetrievalStrategy).getRecords(eq(maxRecords));
assertTrue("Result should contain ProvisionedThroughputExceededException", assertTrue("Result should contain ProvisionedThroughputExceededException",
result.getException() instanceof ProvisionedThroughputExceededException); result.getException() instanceof ProvisionedThroughputExceededException);
} }
@ -117,10 +117,10 @@ public class ProcessTaskTest {
@Test @Test
public void testProcessTaskWithNonExistentStream() { public void testProcessTaskWithNonExistentStream() {
// Data fetcher returns a null Result when the stream does not exist // Data fetcher returns a null Result when the stream does not exist
doReturn(null).when(mockGetRecordsRetrivalStrategy).getRecords(maxRecords); doReturn(null).when(mockGetRecordsRetrievalStrategy).getRecords(maxRecords);
TaskResult result = processTask.call(); TaskResult result = processTask.call();
verify(mockGetRecordsRetrivalStrategy).getRecords(eq(maxRecords)); verify(mockGetRecordsRetrievalStrategy).getRecords(eq(maxRecords));
assertNull("Task should not throw an exception", result.getException()); assertNull("Task should not throw an exception", result.getException());
} }
@ -304,14 +304,14 @@ public class ProcessTaskTest {
private void testWithRecords(List<Record> records, private void testWithRecords(List<Record> records,
ExtendedSequenceNumber lastCheckpointValue, ExtendedSequenceNumber lastCheckpointValue,
ExtendedSequenceNumber largestPermittedCheckpointValue) { ExtendedSequenceNumber largestPermittedCheckpointValue) {
when(mockGetRecordsRetrivalStrategy.getRecords(anyInt())).thenReturn( when(mockGetRecordsRetrievalStrategy.getRecords(anyInt())).thenReturn(
new GetRecordsResult().withRecords(records)); new GetRecordsResult().withRecords(records));
when(mockCheckpointer.getLastCheckpointValue()).thenReturn(lastCheckpointValue); when(mockCheckpointer.getLastCheckpointValue()).thenReturn(lastCheckpointValue);
when(mockCheckpointer.getLargestPermittedCheckpointValue()).thenReturn(largestPermittedCheckpointValue); when(mockCheckpointer.getLargestPermittedCheckpointValue()).thenReturn(largestPermittedCheckpointValue);
processTask.call(); processTask.call();
verify(throttlingReporter).success(); verify(throttlingReporter).success();
verify(throttlingReporter, never()).throttled(); verify(throttlingReporter, never()).throttled();
verify(mockGetRecordsRetrivalStrategy).getRecords(anyInt()); verify(mockGetRecordsRetrievalStrategy).getRecords(anyInt());
ArgumentCaptor<ProcessRecordsInput> priCaptor = ArgumentCaptor.forClass(ProcessRecordsInput.class); ArgumentCaptor<ProcessRecordsInput> priCaptor = ArgumentCaptor.forClass(ProcessRecordsInput.class);
verify(mockRecordProcessor).processRecords(priCaptor.capture()); verify(mockRecordProcessor).processRecords(priCaptor.capture());
processedRecords = priCaptor.getValue().getRecords(); processedRecords = priCaptor.getValue().getRecords();