Fixed a Spelling Error, and Slight Refactor for Tests

* Renamed the retrieval strategy classes to fix a spelling error.
* Modified the strategy interface to support shutdown, and determination
  of whether a strategy has been shutdown.
* Moved the existing tests for the async strategy to an integration
  test.
* Modified the async strategy to allow injection of more state
  components
* Modified the async strategy to throw an exception if an attempt is
  made to use it after shutdown.

cr https://code.amazon.com/reviews/CR-590341
This commit is contained in:
Pfifer, Justin 2017-09-01 09:17:59 -07:00
parent e0ae95dd09
commit 7472cec60c
9 changed files with 281 additions and 94 deletions

View file

@ -2,7 +2,6 @@ package com.amazonaws.services.kinesis.clientlibrary.lib.worker;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletionService;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorCompletionService;
@ -16,7 +15,6 @@ import java.util.concurrent.TimeUnit;
import com.amazonaws.services.kinesis.model.GetRecordsResult;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import lombok.Getter;
import lombok.NonNull;
import lombok.extern.apachecommons.CommonsLog;
@ -24,58 +22,53 @@ import lombok.extern.apachecommons.CommonsLog;
*
*/
@CommonsLog
public class AsynchronousGetRecordsRetrivalStrategy implements GetRecordsRetrivalStrategy {
public class AsynchronousGetRecordsRetrievalStrategy implements GetRecordsRetrievalStrategy {
private static final int TIME_TO_KEEP_ALIVE = 5;
private static final int CORE_THREAD_POOL_COUNT = 1;
private final KinesisDataFetcher dataFetcher;
private final ExecutorService executorService;
private final int retryGetRecordsInSeconds;
@Getter
private final CompletionService<GetRecordsResult> completionService;
private final String shardId;
final CompletionService<GetRecordsResult> completionService;
public AsynchronousGetRecordsRetrivalStrategy(@NonNull final KinesisDataFetcher dataFetcher,
final int retryGetRecordsInSeconds,
final int maxGetRecordsThreadPool) {
this (dataFetcher,
new ThreadPoolExecutor(
CORE_THREAD_POOL_COUNT,
maxGetRecordsThreadPool,
TIME_TO_KEEP_ALIVE,
TimeUnit.SECONDS,
new LinkedBlockingQueue<>(1),
new ThreadFactoryBuilder().setDaemon(true).setNameFormat("getrecords-worker-%d").build(),
new ThreadPoolExecutor.AbortPolicy()),
retryGetRecordsInSeconds);
public AsynchronousGetRecordsRetrievalStrategy(@NonNull final KinesisDataFetcher dataFetcher,
final int retryGetRecordsInSeconds, final int maxGetRecordsThreadPool, String shardId) {
this(dataFetcher, buildExector(maxGetRecordsThreadPool, shardId), retryGetRecordsInSeconds, shardId);
}
public AsynchronousGetRecordsRetrivalStrategy(final KinesisDataFetcher dataFetcher,
final ExecutorService executorService,
final int retryGetRecordsInSeconds) {
public AsynchronousGetRecordsRetrievalStrategy(final KinesisDataFetcher dataFetcher,
final ExecutorService executorService, final int retryGetRecordsInSeconds, String shardId) {
this(dataFetcher, executorService, retryGetRecordsInSeconds, new ExecutorCompletionService<>(executorService),
shardId);
}
AsynchronousGetRecordsRetrievalStrategy(KinesisDataFetcher dataFetcher, ExecutorService executorService,
int retryGetRecordsInSeconds, CompletionService<GetRecordsResult> completionService, String shardId) {
this.dataFetcher = dataFetcher;
this.executorService = executorService;
this.retryGetRecordsInSeconds = retryGetRecordsInSeconds;
this.completionService = new ExecutorCompletionService<>(executorService);
this.completionService = completionService;
this.shardId = shardId;
}
@Override
public GetRecordsResult getRecords(final int maxRecords) {
if (executorService.isShutdown()) {
throw new IllegalStateException("Strategy has been shutdown");
}
GetRecordsResult result = null;
Set<Future<GetRecordsResult>> futures = new HashSet<>();
while (true) {
try {
futures.add(completionService.submit(new Callable<GetRecordsResult>() {
@Override
public GetRecordsResult call() throws Exception {
return dataFetcher.getRecords(maxRecords);
}
}));
futures.add(completionService.submit(() -> dataFetcher.getRecords(maxRecords)));
} catch (RejectedExecutionException e) {
log.warn("Out of resources, unable to start additional requests.");
}
try {
Future<GetRecordsResult> resultFuture = completionService.poll(retryGetRecordsInSeconds, TimeUnit.SECONDS);
Future<GetRecordsResult> resultFuture = completionService.poll(retryGetRecordsInSeconds,
TimeUnit.SECONDS);
if (resultFuture != null) {
result = resultFuture.get();
break;
@ -97,7 +90,21 @@ public class AsynchronousGetRecordsRetrivalStrategy implements GetRecordsRetriva
return result;
}
@Override
public void shutdown() {
executorService.shutdownNow();
}
@Override
public boolean isShutdown() {
return executorService.isShutdown();
}
private static ExecutorService buildExector(int maxGetRecordsThreadPool, String shardId) {
String threadNameFormat = "get-records-worker-" + shardId + "-%d";
return new ThreadPoolExecutor(CORE_THREAD_POOL_COUNT, maxGetRecordsThreadPool, TIME_TO_KEEP_ALIVE,
TimeUnit.SECONDS, new LinkedBlockingQueue<>(1),
new ThreadFactoryBuilder().setDaemon(true).setNameFormat(threadNameFormat).build(),
new ThreadPoolExecutor.AbortPolicy());
}
}

View file

@ -0,0 +1,33 @@
package com.amazonaws.services.kinesis.clientlibrary.lib.worker;
import com.amazonaws.services.kinesis.model.GetRecordsResult;
/**
* Represents a strategy to retrieve records from Kinesis. Allows for variations on how records are retrieved from
* Kinesis.
*/
public interface GetRecordsRetrievalStrategy {
/**
* Gets a set of records from Kinesis.
*
* @param maxRecords
* passed to Kinesis, and can be used to restrict the number of records returned from Kinesis.
* @return the resulting records.
* @throws IllegalStateException
* if the strategy has been shutdown.
*/
GetRecordsResult getRecords(int maxRecords);
/**
* Releases any resources used by the strategy. Once the strategy is shutdown it is no longer safe to call
* {@link #getRecords(int)}.
*/
void shutdown();
/**
* Returns whether this strategy has been shutdown.
*
* @return true if the strategy has been shutdown, false otherwise.
*/
boolean isShutdown();
}

View file

@ -1,10 +0,0 @@
package com.amazonaws.services.kinesis.clientlibrary.lib.worker;
import com.amazonaws.services.kinesis.model.GetRecordsResult;
/**
*
*/
public interface GetRecordsRetrivalStrategy {
GetRecordsResult getRecords(int maxRecords);
}

View file

@ -62,7 +62,7 @@ class ProcessTask implements ITask {
private final Shard shard;
private final ThrottlingReporter throttlingReporter;
private final GetRecordsRetrivalStrategy getRecordsRetrivalStrategy;
private final GetRecordsRetrievalStrategy getRecordsRetrievalStrategy;
/**
* @param shardInfo
@ -83,7 +83,7 @@ class ProcessTask implements ITask {
long backoffTimeMillis, boolean skipShardSyncAtWorkerInitializationIfLeasesExist) {
this(shardInfo, streamConfig, recordProcessor, recordProcessorCheckpointer, dataFetcher, backoffTimeMillis,
skipShardSyncAtWorkerInitializationIfLeasesExist,
new ThrottlingReporter(MAX_CONSECUTIVE_THROTTLES, shardInfo.getShardId()), new SynchronousGetRecordsRetrivalStrategy(dataFetcher));
new ThrottlingReporter(MAX_CONSECUTIVE_THROTTLES, shardInfo.getShardId()), new SynchronousGetRecordsRetrievalStrategy(dataFetcher));
}
/**
@ -105,7 +105,7 @@ class ProcessTask implements ITask {
public ProcessTask(ShardInfo shardInfo, StreamConfig streamConfig, IRecordProcessor recordProcessor,
RecordProcessorCheckpointer recordProcessorCheckpointer, KinesisDataFetcher dataFetcher,
long backoffTimeMillis, boolean skipShardSyncAtWorkerInitializationIfLeasesExist,
ThrottlingReporter throttlingReporter, GetRecordsRetrivalStrategy getRecordsRetrivalStrategy) {
ThrottlingReporter throttlingReporter, GetRecordsRetrievalStrategy getRecordsRetrievalStrategy) {
super();
this.shardInfo = shardInfo;
this.recordProcessor = recordProcessor;
@ -115,7 +115,7 @@ class ProcessTask implements ITask {
this.backoffTimeMillis = backoffTimeMillis;
this.throttlingReporter = throttlingReporter;
IKinesisProxy kinesisProxy = this.streamConfig.getStreamProxy();
this.getRecordsRetrivalStrategy = getRecordsRetrivalStrategy;
this.getRecordsRetrievalStrategy = getRecordsRetrievalStrategy;
// If skipShardSyncAtWorkerInitializationIfLeasesExist is set, we will not get the shard for
// this ProcessTask. In this case, duplicate KPL user records in the event of resharding will
// not be dropped during deaggregation of Amazon Kinesis records. This is only applicable if
@ -371,7 +371,7 @@ class ProcessTask implements ITask {
* @return list of data records from Kinesis
*/
private GetRecordsResult getRecordsResultAndRecordMillisBehindLatest() {
final GetRecordsResult getRecordsResult = getRecordsRetrivalStrategy.getRecords(streamConfig.getMaxRecords());
final GetRecordsResult getRecordsResult = getRecordsRetrievalStrategy.getRecords(streamConfig.getMaxRecords());
if (getRecordsResult == null) {
// Stream no longer exists

View file

@ -8,7 +8,7 @@ import lombok.NonNull;
*
*/
@Data
public class SynchronousGetRecordsRetrivalStrategy implements GetRecordsRetrivalStrategy {
public class SynchronousGetRecordsRetrievalStrategy implements GetRecordsRetrievalStrategy {
@NonNull
private final KinesisDataFetcher dataFetcher;
@ -16,4 +16,16 @@ public class SynchronousGetRecordsRetrivalStrategy implements GetRecordsRetrival
public GetRecordsResult getRecords(final int maxRecords) {
return dataFetcher.getRecords(maxRecords);
}
@Override
public void shutdown() {
//
// Does nothing as this retriever doesn't manage any resources
//
}
@Override
public boolean isShutdown() {
return false;
}
}

View file

@ -1,17 +1,27 @@
/*
* Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Amazon Software License (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/asl/
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/
package com.amazonaws.services.kinesis.clientlibrary.lib.worker;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyObject;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.atLeast;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import com.amazonaws.services.kinesis.clientlibrary.proxies.IKinesisProxy;
import com.amazonaws.services.kinesis.model.GetRecordsResult;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.junit.After;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;
import org.mockito.Mock;
import java.util.concurrent.CompletionService;
import java.util.concurrent.ExecutionException;
@ -22,23 +32,19 @@ import java.util.concurrent.RejectedExecutionHandler;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.Spy;
import org.mockito.runners.MockitoJUnitRunner;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.atLeast;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import com.amazonaws.services.kinesis.clientlibrary.proxies.IKinesisProxy;
import com.amazonaws.services.kinesis.model.GetRecordsResult;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
public class AsynchronousGetRecordsRetrievalStrategyIntegrationTest {
/**
*
*/
@RunWith(MockitoJUnitRunner.class)
public class AsynchronousGetRecordsRetrivalStrategyTest {
private static final int CORE_POOL_SIZE = 1;
private static final int MAX_POOL_SIZE = 2;
private static final int TIME_TO_LIVE = 5;
@ -51,7 +57,7 @@ public class AsynchronousGetRecordsRetrivalStrategyTest {
@Mock
private ShardInfo mockShardInfo;
private AsynchronousGetRecordsRetrivalStrategy getRecordsRetrivalStrategy;
private AsynchronousGetRecordsRetrievalStrategy getRecordsRetrivalStrategy;
private KinesisDataFetcher dataFetcher;
private GetRecordsResult result;
private ExecutorService executorService;
@ -71,8 +77,8 @@ public class AsynchronousGetRecordsRetrivalStrategyTest {
new LinkedBlockingQueue<>(1),
new ThreadFactoryBuilder().setDaemon(true).setNameFormat("getrecords-worker-%d").build(),
rejectedExecutionHandler));
getRecordsRetrivalStrategy = new AsynchronousGetRecordsRetrivalStrategy(dataFetcher, executorService, RETRY_GET_RECORDS_IN_SECONDS);
completionService = spy(getRecordsRetrivalStrategy.getCompletionService());
getRecordsRetrivalStrategy = new AsynchronousGetRecordsRetrievalStrategy(dataFetcher, executorService, RETRY_GET_RECORDS_IN_SECONDS, "shardId-0001");
completionService = spy(getRecordsRetrivalStrategy.completionService);
result = null;
}
@ -98,17 +104,18 @@ public class AsynchronousGetRecordsRetrivalStrategyTest {
assertNull(getRecordsResult);
}
/*@Test
@Test
@Ignore
public void testInterrupted() throws InterruptedException, ExecutionException {
Future<GetRecordsResult> mockFuture = mock(Future.class);
System.out.println(completionService);
when(completionService.submit(any())).thenReturn(mockFuture);
when(completionService.poll()).thenReturn(mockFuture);
doThrow(InterruptedException.class).when(mockFuture).get();
GetRecordsResult getRecordsResult = getRecordsRetrivalStrategy.getRecords(numberOfRecords);
verify(mockFuture).get();
assertNull(getRecordsResult);
}*/
}
private int getLeastNumberOfCalls() {
int leastNumberOfCalls = 0;
@ -142,4 +149,5 @@ public class AsynchronousGetRecordsRetrivalStrategyTest {
return result;
}
}
}

View file

@ -0,0 +1,137 @@
package com.amazonaws.services.kinesis.clientlibrary.lib.worker;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyBoolean;
import static org.mockito.Matchers.anyLong;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import java.util.concurrent.CompletionService;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.TimeUnit;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.runners.MockitoJUnitRunner;
import com.amazonaws.services.kinesis.model.GetRecordsResult;
/**
*
*/
@RunWith(MockitoJUnitRunner.class)
public class AsynchronousGetRecordsRetrievalStrategyTest {
private static final long RETRY_GET_RECORDS_IN_SECONDS = 5;
private static final String SHARD_ID = "ShardId-0001";
@Mock
private KinesisDataFetcher dataFetcher;
@Mock
private ExecutorService executorService;
@Mock
private CompletionService<GetRecordsResult> completionService;
@Mock
private Future<GetRecordsResult> successfulFuture;
@Mock
private Future<GetRecordsResult> blockedFuture;
@Mock
private GetRecordsResult expectedResults;
@Test
public void testSingleSuccessfulRequestFuture() throws Exception {
AsynchronousGetRecordsRetrievalStrategy strategy = new AsynchronousGetRecordsRetrievalStrategy(dataFetcher,
executorService, (int) RETRY_GET_RECORDS_IN_SECONDS, completionService, SHARD_ID);
when(executorService.isShutdown()).thenReturn(false);
when(completionService.submit(any())).thenReturn(successfulFuture);
when(completionService.poll(anyLong(), any())).thenReturn(successfulFuture);
when(successfulFuture.get()).thenReturn(expectedResults);
GetRecordsResult result = strategy.getRecords(10);
verify(executorService).isShutdown();
verify(completionService).submit(any());
verify(completionService).poll(eq(RETRY_GET_RECORDS_IN_SECONDS), eq(TimeUnit.SECONDS));
verify(successfulFuture).get();
verify(successfulFuture).cancel(eq(true));
verify(successfulFuture).isCancelled();
verify(completionService, never()).take();
assertThat(result, equalTo(expectedResults));
}
@Test
public void testBlockedAndSuccessfulFuture() throws Exception {
AsynchronousGetRecordsRetrievalStrategy strategy = new AsynchronousGetRecordsRetrievalStrategy(dataFetcher,
executorService, (int) RETRY_GET_RECORDS_IN_SECONDS, completionService, SHARD_ID);
when(executorService.isShutdown()).thenReturn(false);
when(completionService.submit(any())).thenReturn(blockedFuture).thenReturn(successfulFuture);
when(completionService.poll(anyLong(), any())).thenReturn(null).thenReturn(successfulFuture);
when(successfulFuture.get()).thenReturn(expectedResults);
when(successfulFuture.cancel(anyBoolean())).thenReturn(false);
when(blockedFuture.cancel(anyBoolean())).thenReturn(true);
when(successfulFuture.isCancelled()).thenReturn(false);
when(blockedFuture.isCancelled()).thenReturn(true);
GetRecordsResult actualResults = strategy.getRecords(10);
verify(completionService, times(2)).submit(any());
verify(completionService, times(2)).poll(eq(RETRY_GET_RECORDS_IN_SECONDS), eq(TimeUnit.SECONDS));
verify(successfulFuture).get();
verify(blockedFuture, never()).get();
verify(successfulFuture).cancel(eq(true));
verify(blockedFuture).cancel(eq(true));
verify(successfulFuture).isCancelled();
verify(blockedFuture).isCancelled();
verify(completionService).take();
assertThat(actualResults, equalTo(expectedResults));
}
@Test(expected = IllegalStateException.class)
public void testStrategyIsShutdown() throws Exception {
AsynchronousGetRecordsRetrievalStrategy strategy = new AsynchronousGetRecordsRetrievalStrategy(dataFetcher,
executorService, (int) RETRY_GET_RECORDS_IN_SECONDS, completionService, SHARD_ID);
when(executorService.isShutdown()).thenReturn(true);
strategy.getRecords(10);
}
@Test
public void testPoolOutOfResources() throws Exception {
AsynchronousGetRecordsRetrievalStrategy strategy = new AsynchronousGetRecordsRetrievalStrategy(dataFetcher,
executorService, (int) RETRY_GET_RECORDS_IN_SECONDS, completionService, SHARD_ID);
when(executorService.isShutdown()).thenReturn(false);
when(completionService.submit(any())).thenReturn(blockedFuture).thenThrow(new RejectedExecutionException("Rejected!")).thenReturn(successfulFuture);
when(completionService.poll(anyLong(), any())).thenReturn(null).thenReturn(null).thenReturn(successfulFuture);
when(successfulFuture.get()).thenReturn(expectedResults);
when(successfulFuture.cancel(anyBoolean())).thenReturn(false);
when(blockedFuture.cancel(anyBoolean())).thenReturn(true);
when(successfulFuture.isCancelled()).thenReturn(false);
when(blockedFuture.isCancelled()).thenReturn(true);
GetRecordsResult actualResult = strategy.getRecords(10);
verify(completionService, times(3)).submit(any());
verify(completionService, times(3)).poll(eq(RETRY_GET_RECORDS_IN_SECONDS), eq(TimeUnit.SECONDS));
verify(successfulFuture).cancel(eq(true));
verify(blockedFuture).cancel(eq(true));
verify(successfulFuture).isCancelled();
verify(blockedFuture).isCancelled();
verify(completionService).take();
assertThat(actualResult, equalTo(expectedResults));
}
}

View file

@ -117,7 +117,7 @@ public class KinesisDataFetcherTest {
ICheckpoint checkpoint = mock(ICheckpoint.class);
KinesisDataFetcher fetcher = new KinesisDataFetcher(kinesis, SHARD_INFO);
GetRecordsRetrivalStrategy getRecordsRetrivalStrategy = new SynchronousGetRecordsRetrivalStrategy(fetcher);
GetRecordsRetrievalStrategy getRecordsRetrievalStrategy = new SynchronousGetRecordsRetrievalStrategy(fetcher);
String iteratorA = "foo";
String iteratorB = "bar";
@ -139,10 +139,10 @@ public class KinesisDataFetcherTest {
fetcher.initialize(seqA, null);
fetcher.advanceIteratorTo(seqA, null);
Assert.assertEquals(recordsA, getRecordsRetrivalStrategy.getRecords(MAX_RECORDS).getRecords());
Assert.assertEquals(recordsA, getRecordsRetrievalStrategy.getRecords(MAX_RECORDS).getRecords());
fetcher.advanceIteratorTo(seqB, null);
Assert.assertEquals(recordsB, getRecordsRetrivalStrategy.getRecords(MAX_RECORDS).getRecords());
Assert.assertEquals(recordsB, getRecordsRetrievalStrategy.getRecords(MAX_RECORDS).getRecords());
}
@Test
@ -182,9 +182,9 @@ public class KinesisDataFetcherTest {
// Create data fectcher and initialize it with latest type checkpoint
KinesisDataFetcher dataFetcher = new KinesisDataFetcher(mockProxy, SHARD_INFO);
dataFetcher.initialize(SentinelCheckpoint.LATEST.toString(), INITIAL_POSITION_LATEST);
GetRecordsRetrivalStrategy getRecordsRetrivalStrategy = new SynchronousGetRecordsRetrivalStrategy(dataFetcher);
GetRecordsRetrievalStrategy getRecordsRetrievalStrategy = new SynchronousGetRecordsRetrievalStrategy(dataFetcher);
// Call getRecords of dataFetcher which will throw an exception
getRecordsRetrivalStrategy.getRecords(maxRecords);
getRecordsRetrievalStrategy.getRecords(maxRecords);
// Test shard has reached the end
Assert.assertTrue("Shard should reach the end", dataFetcher.isShardEndReached());
@ -208,9 +208,9 @@ public class KinesisDataFetcherTest {
when(checkpoint.getCheckpoint(SHARD_ID)).thenReturn(new ExtendedSequenceNumber(seqNo));
KinesisDataFetcher fetcher = new KinesisDataFetcher(kinesis, SHARD_INFO);
GetRecordsRetrivalStrategy getRecordsRetrivalStrategy = new SynchronousGetRecordsRetrivalStrategy(fetcher);
GetRecordsRetrievalStrategy getRecordsRetrievalStrategy = new SynchronousGetRecordsRetrievalStrategy(fetcher);
fetcher.initialize(seqNo, initialPositionInStream);
List<Record> actualRecords = getRecordsRetrivalStrategy.getRecords(MAX_RECORDS).getRecords();
List<Record> actualRecords = getRecordsRetrievalStrategy.getRecords(MAX_RECORDS).getRecords();
Assert.assertEquals(expectedRecords, actualRecords);
}

View file

@ -77,7 +77,7 @@ public class ProcessTaskTest {
@Mock
private ThrottlingReporter throttlingReporter;
@Mock
private GetRecordsRetrivalStrategy mockGetRecordsRetrivalStrategy;
private GetRecordsRetrievalStrategy mockGetRecordsRetrievalStrategy;
private List<Record> processedRecords;
private ExtendedSequenceNumber newLargestPermittedCheckpointValue;
@ -96,20 +96,20 @@ public class ProcessTaskTest {
final ShardInfo shardInfo = new ShardInfo(shardId, null, null, null);
processTask = new ProcessTask(
shardInfo, config, mockRecordProcessor, mockCheckpointer, mockDataFetcher, taskBackoffTimeMillis,
KinesisClientLibConfiguration.DEFAULT_SKIP_SHARD_SYNC_AT_STARTUP_IF_LEASES_EXIST, throttlingReporter, mockGetRecordsRetrivalStrategy);
KinesisClientLibConfiguration.DEFAULT_SKIP_SHARD_SYNC_AT_STARTUP_IF_LEASES_EXIST, throttlingReporter, mockGetRecordsRetrievalStrategy);
}
@Test
public void testProcessTaskWithProvisionedThroughputExceededException() {
// Set data fetcher to throw exception
doReturn(false).when(mockDataFetcher).isShardEndReached();
doThrow(new ProvisionedThroughputExceededException("Test Exception")).when(mockGetRecordsRetrivalStrategy)
doThrow(new ProvisionedThroughputExceededException("Test Exception")).when(mockGetRecordsRetrievalStrategy)
.getRecords(maxRecords);
TaskResult result = processTask.call();
verify(throttlingReporter).throttled();
verify(throttlingReporter, never()).success();
verify(mockGetRecordsRetrivalStrategy).getRecords(eq(maxRecords));
verify(mockGetRecordsRetrievalStrategy).getRecords(eq(maxRecords));
assertTrue("Result should contain ProvisionedThroughputExceededException",
result.getException() instanceof ProvisionedThroughputExceededException);
}
@ -117,10 +117,10 @@ public class ProcessTaskTest {
@Test
public void testProcessTaskWithNonExistentStream() {
// Data fetcher returns a null Result when the stream does not exist
doReturn(null).when(mockGetRecordsRetrivalStrategy).getRecords(maxRecords);
doReturn(null).when(mockGetRecordsRetrievalStrategy).getRecords(maxRecords);
TaskResult result = processTask.call();
verify(mockGetRecordsRetrivalStrategy).getRecords(eq(maxRecords));
verify(mockGetRecordsRetrievalStrategy).getRecords(eq(maxRecords));
assertNull("Task should not throw an exception", result.getException());
}
@ -304,14 +304,14 @@ public class ProcessTaskTest {
private void testWithRecords(List<Record> records,
ExtendedSequenceNumber lastCheckpointValue,
ExtendedSequenceNumber largestPermittedCheckpointValue) {
when(mockGetRecordsRetrivalStrategy.getRecords(anyInt())).thenReturn(
when(mockGetRecordsRetrievalStrategy.getRecords(anyInt())).thenReturn(
new GetRecordsResult().withRecords(records));
when(mockCheckpointer.getLastCheckpointValue()).thenReturn(lastCheckpointValue);
when(mockCheckpointer.getLargestPermittedCheckpointValue()).thenReturn(largestPermittedCheckpointValue);
processTask.call();
verify(throttlingReporter).success();
verify(throttlingReporter, never()).throttled();
verify(mockGetRecordsRetrivalStrategy).getRecords(anyInt());
verify(mockGetRecordsRetrievalStrategy).getRecords(anyInt());
ArgumentCaptor<ProcessRecordsInput> priCaptor = ArgumentCaptor.forClass(ProcessRecordsInput.class);
verify(mockRecordProcessor).processRecords(priCaptor.capture());
processedRecords = priCaptor.getValue().getRecords();