Merge pull request #230 from pfifer/data-fetcher-accept

Only advance the shard iterator when we accept a result to return
This commit is contained in:
Justin Pfifer 2017-09-26 13:47:57 -07:00 committed by GitHub
commit 6b25de917b
8 changed files with 228 additions and 42 deletions

View file

@ -6,7 +6,7 @@
<artifactId>amazon-kinesis-client</artifactId> <artifactId>amazon-kinesis-client</artifactId>
<packaging>jar</packaging> <packaging>jar</packaging>
<name>Amazon Kinesis Client Library for Java</name> <name>Amazon Kinesis Client Library for Java</name>
<version>1.8.4</version> <version>1.8.5-SNAPSHOT</version>
<description>The Amazon Kinesis Client Library for Java enables Java developers to easily consume and process data <description>The Amazon Kinesis Client Library for Java enables Java developers to easily consume and process data
from Amazon Kinesis. from Amazon Kinesis.
</description> </description>

View file

@ -49,7 +49,7 @@ public class AsynchronousGetRecordsRetrievalStrategy implements GetRecordsRetrie
private final ExecutorService executorService; private final ExecutorService executorService;
private final int retryGetRecordsInSeconds; private final int retryGetRecordsInSeconds;
private final String shardId; private final String shardId;
final Supplier<CompletionService<GetRecordsResult>> completionServiceSupplier; final Supplier<CompletionService<DataFetcherResult>> completionServiceSupplier;
public AsynchronousGetRecordsRetrievalStrategy(@NonNull final KinesisDataFetcher dataFetcher, public AsynchronousGetRecordsRetrievalStrategy(@NonNull final KinesisDataFetcher dataFetcher,
final int retryGetRecordsInSeconds, final int maxGetRecordsThreadPool, String shardId) { final int retryGetRecordsInSeconds, final int maxGetRecordsThreadPool, String shardId) {
@ -63,7 +63,7 @@ public class AsynchronousGetRecordsRetrievalStrategy implements GetRecordsRetrie
} }
AsynchronousGetRecordsRetrievalStrategy(KinesisDataFetcher dataFetcher, ExecutorService executorService, AsynchronousGetRecordsRetrievalStrategy(KinesisDataFetcher dataFetcher, ExecutorService executorService,
int retryGetRecordsInSeconds, Supplier<CompletionService<GetRecordsResult>> completionServiceSupplier, int retryGetRecordsInSeconds, Supplier<CompletionService<DataFetcherResult>> completionServiceSupplier,
String shardId) { String shardId) {
this.dataFetcher = dataFetcher; this.dataFetcher = dataFetcher;
this.executorService = executorService; this.executorService = executorService;
@ -78,9 +78,9 @@ public class AsynchronousGetRecordsRetrievalStrategy implements GetRecordsRetrie
throw new IllegalStateException("Strategy has been shutdown"); throw new IllegalStateException("Strategy has been shutdown");
} }
GetRecordsResult result = null; GetRecordsResult result = null;
CompletionService<GetRecordsResult> completionService = completionServiceSupplier.get(); CompletionService<DataFetcherResult> completionService = completionServiceSupplier.get();
Set<Future<GetRecordsResult>> futures = new HashSet<>(); Set<Future<DataFetcherResult>> futures = new HashSet<>();
Callable<GetRecordsResult> retrieverCall = createRetrieverCallable(maxRecords); Callable<DataFetcherResult> retrieverCall = createRetrieverCallable(maxRecords);
while (true) { while (true) {
try { try {
futures.add(completionService.submit(retrieverCall)); futures.add(completionService.submit(retrieverCall));
@ -89,10 +89,15 @@ public class AsynchronousGetRecordsRetrievalStrategy implements GetRecordsRetrie
} }
try { try {
Future<GetRecordsResult> resultFuture = completionService.poll(retryGetRecordsInSeconds, Future<DataFetcherResult> resultFuture = completionService.poll(retryGetRecordsInSeconds,
TimeUnit.SECONDS); TimeUnit.SECONDS);
if (resultFuture != null) { if (resultFuture != null) {
result = resultFuture.get(); //
// Fix to ensure that we only let the shard iterator advance when we intend to return the result
// to the caller. This ensures that the shard iterator is consistently advance in step with
// what the caller sees.
//
result = resultFuture.get().accept();
break; break;
} }
} catch (ExecutionException e) { } catch (ExecutionException e) {
@ -106,7 +111,7 @@ public class AsynchronousGetRecordsRetrievalStrategy implements GetRecordsRetrie
return result; return result;
} }
private Callable<GetRecordsResult> createRetrieverCallable(int maxRecords) { private Callable<DataFetcherResult> createRetrieverCallable(int maxRecords) {
ThreadSafeMetricsDelegatingScope metricsScope = new ThreadSafeMetricsDelegatingScope(MetricsHelper.getMetricsScope()); ThreadSafeMetricsDelegatingScope metricsScope = new ThreadSafeMetricsDelegatingScope(MetricsHelper.getMetricsScope());
return () -> { return () -> {
try { try {

View file

@ -0,0 +1,37 @@
/*
* Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. Licensed under the Amazon Software License
* (the "License"). You may not use this file except in compliance with the License. A copy of the License is located at
* http://aws.amazon.com/asl/ or in the "license" file accompanying this file. This file is distributed on an "AS IS"
* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific
* language governing permissions and limitations under the License.
*/
package com.amazonaws.services.kinesis.clientlibrary.lib.worker;
import com.amazonaws.services.kinesis.model.GetRecordsResult;
/**
* Represents the result from the DataFetcher, and allows the receiver to accept a result
*/
public interface DataFetcherResult {
/**
* The result of the request to Kinesis
*
* @return The result of the request, this can be null if the request failed.
*/
GetRecordsResult getResult();
/**
* Accepts the result, and advances the shard iterator. A result from the data fetcher must be accepted before any
* further progress can be made.
*
* @return the result of the request, this can be null if the request failed.
*/
GetRecordsResult accept();
/**
* Indicates whether this result is at the end of the shard or not
*
* @return true if the result is at the end of a shard, false otherwise
*/
boolean isShardEnd();
}

View file

@ -14,6 +14,7 @@
*/ */
package com.amazonaws.services.kinesis.clientlibrary.lib.worker; package com.amazonaws.services.kinesis.clientlibrary.lib.worker;
import lombok.Data;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
@ -26,6 +27,7 @@ import com.amazonaws.services.kinesis.clientlibrary.proxies.MetricsCollectingKin
import com.amazonaws.services.kinesis.clientlibrary.types.ExtendedSequenceNumber; import com.amazonaws.services.kinesis.clientlibrary.types.ExtendedSequenceNumber;
import java.util.Date; import java.util.Date;
import java.util.function.Consumer;
/** /**
* Used to get data from Amazon Kinesis. Tracks iterator state internally. * Used to get data from Amazon Kinesis. Tracks iterator state internally.
@ -57,30 +59,69 @@ class KinesisDataFetcher {
* @param maxRecords Max records to fetch * @param maxRecords Max records to fetch
* @return list of records of up to maxRecords size * @return list of records of up to maxRecords size
*/ */
public GetRecordsResult getRecords(int maxRecords) { public DataFetcherResult getRecords(int maxRecords) {
if (!isInitialized) { if (!isInitialized) {
throw new IllegalArgumentException("KinesisDataFetcher.getRecords called before initialization."); throw new IllegalArgumentException("KinesisDataFetcher.getRecords called before initialization.");
} }
GetRecordsResult response = null; DataFetcherResult response;
if (nextIterator != null) { if (nextIterator != null) {
try { try {
response = kinesisProxy.get(nextIterator, maxRecords); response = new AdvancingResult(kinesisProxy.get(nextIterator, maxRecords));
nextIterator = response.getNextShardIterator();
} catch (ResourceNotFoundException e) { } catch (ResourceNotFoundException e) {
LOG.info("Caught ResourceNotFoundException when fetching records for shard " + shardId); LOG.info("Caught ResourceNotFoundException when fetching records for shard " + shardId);
nextIterator = null; response = TERMINAL_RESULT;
}
if (nextIterator == null) {
isShardEndReached = true;
} }
} else { } else {
isShardEndReached = true; response = TERMINAL_RESULT;
} }
return response; return response;
} }
final DataFetcherResult TERMINAL_RESULT = new DataFetcherResult() {
@Override
public GetRecordsResult getResult() {
return null;
}
@Override
public GetRecordsResult accept() {
isShardEndReached = true;
return getResult();
}
@Override
public boolean isShardEnd() {
return isShardEndReached;
}
};
@Data
private class AdvancingResult implements DataFetcherResult {
final GetRecordsResult result;
@Override
public GetRecordsResult getResult() {
return result;
}
@Override
public GetRecordsResult accept() {
nextIterator = result.getNextShardIterator();
if (nextIterator == null) {
isShardEndReached = true;
}
return getResult();
}
@Override
public boolean isShardEnd() {
return isShardEndReached;
}
}
/** /**
* Initializes this KinesisDataFetcher's iterator based on the checkpointed sequence number. * Initializes this KinesisDataFetcher's iterator based on the checkpointed sequence number.
* @param initialCheckpoint Current checkpoint sequence number for this shard. * @param initialCheckpoint Current checkpoint sequence number for this shard.

View file

@ -28,7 +28,7 @@ public class SynchronousGetRecordsRetrievalStrategy implements GetRecordsRetriev
@Override @Override
public GetRecordsResult getRecords(final int maxRecords) { public GetRecordsResult getRecords(final int maxRecords) {
return dataFetcher.getRecords(maxRecords); return dataFetcher.getRecords(maxRecords).accept();
} }
@Override @Override

View file

@ -36,7 +36,10 @@ import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.function.Supplier; import java.util.function.Supplier;
import static org.junit.Assert.assertEquals;
import static org.hamcrest.CoreMatchers.nullValue;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.core.IsEqual.equalTo;
import static org.junit.Assert.assertNull; import static org.junit.Assert.assertNull;
import static org.mockito.Matchers.any; import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq; import static org.mockito.Matchers.eq;
@ -58,17 +61,19 @@ public class AsynchronousGetRecordsRetrievalStrategyIntegrationTest {
@Mock @Mock
private IKinesisProxy mockKinesisProxy; private IKinesisProxy mockKinesisProxy;
@Mock @Mock
private ShardInfo mockShardInfo; private ShardInfo mockShardInfo;
@Mock @Mock
private Supplier<CompletionService<GetRecordsResult>> completionServiceSupplier; private Supplier<CompletionService<DataFetcherResult>> completionServiceSupplier;
@Mock
private DataFetcherResult result;
@Mock
private GetRecordsResult recordsResult;
private CompletionService<GetRecordsResult> completionService; private CompletionService<DataFetcherResult> completionService;
private AsynchronousGetRecordsRetrievalStrategy getRecordsRetrivalStrategy; private AsynchronousGetRecordsRetrievalStrategy getRecordsRetrivalStrategy;
private KinesisDataFetcher dataFetcher; private KinesisDataFetcher dataFetcher;
private GetRecordsResult result;
private ExecutorService executorService; private ExecutorService executorService;
private RejectedExecutionHandler rejectedExecutionHandler; private RejectedExecutionHandler rejectedExecutionHandler;
private int numberOfRecords = 10; private int numberOfRecords = 10;
@ -86,14 +91,15 @@ public class AsynchronousGetRecordsRetrievalStrategyIntegrationTest {
new LinkedBlockingQueue<>(1), new LinkedBlockingQueue<>(1),
new ThreadFactoryBuilder().setDaemon(true).setNameFormat("getrecords-worker-%d").build(), new ThreadFactoryBuilder().setDaemon(true).setNameFormat("getrecords-worker-%d").build(),
rejectedExecutionHandler)); rejectedExecutionHandler));
completionService = spy(new ExecutorCompletionService<GetRecordsResult>(executorService)); completionService = spy(new ExecutorCompletionService<DataFetcherResult>(executorService));
when(completionServiceSupplier.get()).thenReturn(completionService); when(completionServiceSupplier.get()).thenReturn(completionService);
getRecordsRetrivalStrategy = new AsynchronousGetRecordsRetrievalStrategy(dataFetcher, executorService, RETRY_GET_RECORDS_IN_SECONDS, completionServiceSupplier, "shardId-0001"); getRecordsRetrivalStrategy = new AsynchronousGetRecordsRetrievalStrategy(dataFetcher, executorService, RETRY_GET_RECORDS_IN_SECONDS, completionServiceSupplier, "shardId-0001");
result = null; when(result.accept()).thenReturn(recordsResult);
} }
@Test @Test
public void oneRequestMultithreadTest() { public void oneRequestMultithreadTest() {
when(result.accept()).thenReturn(null);
GetRecordsResult getRecordsResult = getRecordsRetrivalStrategy.getRecords(numberOfRecords); GetRecordsResult getRecordsResult = getRecordsRetrivalStrategy.getRecords(numberOfRecords);
verify(dataFetcher, atLeast(getLeastNumberOfCalls())).getRecords(eq(numberOfRecords)); verify(dataFetcher, atLeast(getLeastNumberOfCalls())).getRecords(eq(numberOfRecords));
verify(executorService, atLeast(getLeastNumberOfCalls())).execute(any()); verify(executorService, atLeast(getLeastNumberOfCalls())).execute(any());
@ -102,27 +108,25 @@ public class AsynchronousGetRecordsRetrievalStrategyIntegrationTest {
@Test @Test
public void multiRequestTest() { public void multiRequestTest() {
result = mock(GetRecordsResult.class); ExecutorCompletionService<DataFetcherResult> completionService1 = spy(new ExecutorCompletionService<DataFetcherResult>(executorService));
ExecutorCompletionService<GetRecordsResult> completionService1 = spy(new ExecutorCompletionService<GetRecordsResult>(executorService));
when(completionServiceSupplier.get()).thenReturn(completionService1); when(completionServiceSupplier.get()).thenReturn(completionService1);
GetRecordsResult getRecordsResult = getRecordsRetrivalStrategy.getRecords(numberOfRecords); GetRecordsResult getRecordsResult = getRecordsRetrivalStrategy.getRecords(numberOfRecords);
verify(dataFetcher, atLeast(getLeastNumberOfCalls())).getRecords(numberOfRecords); verify(dataFetcher, atLeast(getLeastNumberOfCalls())).getRecords(numberOfRecords);
verify(executorService, atLeast(getLeastNumberOfCalls())).execute(any()); verify(executorService, atLeast(getLeastNumberOfCalls())).execute(any());
assertEquals(result, getRecordsResult); assertThat(getRecordsResult, equalTo(recordsResult));
result = null; when(result.accept()).thenReturn(null);
ExecutorCompletionService<GetRecordsResult> completionService2 = spy(new ExecutorCompletionService<GetRecordsResult>(executorService)); ExecutorCompletionService<DataFetcherResult> completionService2 = spy(new ExecutorCompletionService<DataFetcherResult>(executorService));
when(completionServiceSupplier.get()).thenReturn(completionService2); when(completionServiceSupplier.get()).thenReturn(completionService2);
getRecordsResult = getRecordsRetrivalStrategy.getRecords(numberOfRecords); getRecordsResult = getRecordsRetrivalStrategy.getRecords(numberOfRecords);
assertNull(getRecordsResult); assertThat(getRecordsResult, nullValue(GetRecordsResult.class));
} }
@Test @Test
@Ignore @Ignore
public void testInterrupted() throws InterruptedException, ExecutionException { public void testInterrupted() throws InterruptedException, ExecutionException {
Future<GetRecordsResult> mockFuture = mock(Future.class); Future<DataFetcherResult> mockFuture = mock(Future.class);
when(completionService.submit(any())).thenReturn(mockFuture); when(completionService.submit(any())).thenReturn(mockFuture);
when(completionService.poll()).thenReturn(mockFuture); when(completionService.poll()).thenReturn(mockFuture);
doThrow(InterruptedException.class).when(mockFuture).get(); doThrow(InterruptedException.class).when(mockFuture).get();
@ -154,7 +158,7 @@ public class AsynchronousGetRecordsRetrievalStrategyIntegrationTest {
} }
@Override @Override
public GetRecordsResult getRecords(final int maxRecords) { public DataFetcherResult getRecords(final int maxRecords) {
try { try {
Thread.sleep(SLEEP_GET_RECORDS_IN_SECONDS * 1000); Thread.sleep(SLEEP_GET_RECORDS_IN_SECONDS * 1000);
} catch (InterruptedException e) { } catch (InterruptedException e) {

View file

@ -53,19 +53,23 @@ public class AsynchronousGetRecordsRetrievalStrategyTest {
@Mock @Mock
private ExecutorService executorService; private ExecutorService executorService;
@Mock @Mock
private Supplier<CompletionService<GetRecordsResult>> completionServiceSupplier; private Supplier<CompletionService<DataFetcherResult>> completionServiceSupplier;
@Mock @Mock
private CompletionService<GetRecordsResult> completionService; private CompletionService<DataFetcherResult> completionService;
@Mock @Mock
private Future<GetRecordsResult> successfulFuture; private Future<DataFetcherResult> successfulFuture;
@Mock @Mock
private Future<GetRecordsResult> blockedFuture; private Future<DataFetcherResult> blockedFuture;
@Mock
private DataFetcherResult dataFetcherResult;
@Mock @Mock
private GetRecordsResult expectedResults; private GetRecordsResult expectedResults;
@Before @Before
public void before() { public void before() {
when(completionServiceSupplier.get()).thenReturn(completionService); when(completionServiceSupplier.get()).thenReturn(completionService);
when(dataFetcherResult.getResult()).thenReturn(expectedResults);
when(dataFetcherResult.accept()).thenReturn(expectedResults);
} }
@Test @Test
@ -76,7 +80,7 @@ public class AsynchronousGetRecordsRetrievalStrategyTest {
when(executorService.isShutdown()).thenReturn(false); when(executorService.isShutdown()).thenReturn(false);
when(completionService.submit(any())).thenReturn(successfulFuture); when(completionService.submit(any())).thenReturn(successfulFuture);
when(completionService.poll(anyLong(), any())).thenReturn(successfulFuture); when(completionService.poll(anyLong(), any())).thenReturn(successfulFuture);
when(successfulFuture.get()).thenReturn(expectedResults); when(successfulFuture.get()).thenReturn(dataFetcherResult);
GetRecordsResult result = strategy.getRecords(10); GetRecordsResult result = strategy.getRecords(10);
@ -97,7 +101,7 @@ public class AsynchronousGetRecordsRetrievalStrategyTest {
when(executorService.isShutdown()).thenReturn(false); when(executorService.isShutdown()).thenReturn(false);
when(completionService.submit(any())).thenReturn(blockedFuture).thenReturn(successfulFuture); when(completionService.submit(any())).thenReturn(blockedFuture).thenReturn(successfulFuture);
when(completionService.poll(anyLong(), any())).thenReturn(null).thenReturn(successfulFuture); when(completionService.poll(anyLong(), any())).thenReturn(null).thenReturn(successfulFuture);
when(successfulFuture.get()).thenReturn(expectedResults); when(successfulFuture.get()).thenReturn(dataFetcherResult);
when(successfulFuture.cancel(anyBoolean())).thenReturn(false); when(successfulFuture.cancel(anyBoolean())).thenReturn(false);
when(blockedFuture.cancel(anyBoolean())).thenReturn(true); when(blockedFuture.cancel(anyBoolean())).thenReturn(true);
when(successfulFuture.isCancelled()).thenReturn(false); when(successfulFuture.isCancelled()).thenReturn(false);
@ -133,7 +137,7 @@ public class AsynchronousGetRecordsRetrievalStrategyTest {
when(executorService.isShutdown()).thenReturn(false); when(executorService.isShutdown()).thenReturn(false);
when(completionService.submit(any())).thenReturn(blockedFuture).thenThrow(new RejectedExecutionException("Rejected!")).thenReturn(successfulFuture); when(completionService.submit(any())).thenReturn(blockedFuture).thenThrow(new RejectedExecutionException("Rejected!")).thenReturn(successfulFuture);
when(completionService.poll(anyLong(), any())).thenReturn(null).thenReturn(null).thenReturn(successfulFuture); when(completionService.poll(anyLong(), any())).thenReturn(null).thenReturn(null).thenReturn(successfulFuture);
when(successfulFuture.get()).thenReturn(expectedResults); when(successfulFuture.get()).thenReturn(dataFetcherResult);
when(successfulFuture.cancel(anyBoolean())).thenReturn(false); when(successfulFuture.cancel(anyBoolean())).thenReturn(false);
when(blockedFuture.cancel(anyBoolean())).thenReturn(true); when(blockedFuture.cancel(anyBoolean())).thenReturn(true);
when(successfulFuture.isCancelled()).thenReturn(false); when(successfulFuture.isCancelled()).thenReturn(false);

View file

@ -14,9 +14,20 @@
*/ */
package com.amazonaws.services.kinesis.clientlibrary.lib.worker; package com.amazonaws.services.kinesis.clientlibrary.lib.worker;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.not;
import static org.hamcrest.CoreMatchers.nullValue;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.mockito.Matchers.anyInt;
import static org.mockito.Matchers.anyString;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import java.util.ArrayList; import java.util.ArrayList;
@ -39,12 +50,19 @@ import com.amazonaws.services.kinesis.clientlibrary.proxies.KinesisProxy;
import com.amazonaws.services.kinesis.clientlibrary.types.ExtendedSequenceNumber; import com.amazonaws.services.kinesis.clientlibrary.types.ExtendedSequenceNumber;
import com.amazonaws.services.kinesis.metrics.impl.MetricsHelper; import com.amazonaws.services.kinesis.metrics.impl.MetricsHelper;
import com.amazonaws.services.kinesis.metrics.impl.NullMetricsFactory; import com.amazonaws.services.kinesis.metrics.impl.NullMetricsFactory;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.runners.MockitoJUnitRunner;
/** /**
* Unit tests for KinesisDataFetcher. * Unit tests for KinesisDataFetcher.
*/ */
@RunWith(MockitoJUnitRunner.class)
public class KinesisDataFetcherTest { public class KinesisDataFetcherTest {
@Mock
private KinesisProxy kinesisProxy;
private static final int MAX_RECORDS = 1; private static final int MAX_RECORDS = 1;
private static final String SHARD_ID = "shardId-1"; private static final String SHARD_ID = "shardId-1";
private static final String AT_SEQUENCE_NUMBER = ShardIteratorType.AT_SEQUENCE_NUMBER.toString(); private static final String AT_SEQUENCE_NUMBER = ShardIteratorType.AT_SEQUENCE_NUMBER.toString();
@ -55,6 +73,7 @@ public class KinesisDataFetcherTest {
InitialPositionInStreamExtended.newInitialPosition(InitialPositionInStream.TRIM_HORIZON); InitialPositionInStreamExtended.newInitialPosition(InitialPositionInStream.TRIM_HORIZON);
private static final InitialPositionInStreamExtended INITIAL_POSITION_AT_TIMESTAMP = private static final InitialPositionInStreamExtended INITIAL_POSITION_AT_TIMESTAMP =
InitialPositionInStreamExtended.newInitialPositionAtTimestamp(new Date(1000)); InitialPositionInStreamExtended.newInitialPositionAtTimestamp(new Date(1000));
;
/** /**
* @throws java.lang.Exception * @throws java.lang.Exception
@ -190,6 +209,82 @@ public class KinesisDataFetcherTest {
Assert.assertTrue("Shard should reach the end", dataFetcher.isShardEndReached()); Assert.assertTrue("Shard should reach the end", dataFetcher.isShardEndReached());
} }
@Test
public void testFetcherDoesNotAdvanceWithoutAccept() {
final String INITIAL_ITERATOR = "InitialIterator";
final String NEXT_ITERATOR_ONE = "NextIteratorOne";
final String NEXT_ITERATOR_TWO = "NextIteratorTwo";
when(kinesisProxy.getIterator(anyString(), anyString())).thenReturn(INITIAL_ITERATOR);
GetRecordsResult iteratorOneResults = mock(GetRecordsResult.class);
when(iteratorOneResults.getNextShardIterator()).thenReturn(NEXT_ITERATOR_ONE);
when(kinesisProxy.get(eq(INITIAL_ITERATOR), anyInt())).thenReturn(iteratorOneResults);
GetRecordsResult iteratorTwoResults = mock(GetRecordsResult.class);
when(kinesisProxy.get(eq(NEXT_ITERATOR_ONE), anyInt())).thenReturn(iteratorTwoResults);
when(iteratorTwoResults.getNextShardIterator()).thenReturn(NEXT_ITERATOR_TWO);
GetRecordsResult finalResult = mock(GetRecordsResult.class);
when(kinesisProxy.get(eq(NEXT_ITERATOR_TWO), anyInt())).thenReturn(finalResult);
when(finalResult.getNextShardIterator()).thenReturn(null);
KinesisDataFetcher dataFetcher = new KinesisDataFetcher(kinesisProxy, SHARD_INFO);
dataFetcher.initialize("TRIM_HORIZON", InitialPositionInStreamExtended.newInitialPosition(InitialPositionInStream.TRIM_HORIZON));
assertNoAdvance(dataFetcher, iteratorOneResults, INITIAL_ITERATOR);
assertAdvanced(dataFetcher, iteratorOneResults, INITIAL_ITERATOR, NEXT_ITERATOR_ONE);
assertNoAdvance(dataFetcher, iteratorTwoResults, NEXT_ITERATOR_ONE);
assertAdvanced(dataFetcher, iteratorTwoResults, NEXT_ITERATOR_ONE, NEXT_ITERATOR_TWO);
assertNoAdvance(dataFetcher, finalResult, NEXT_ITERATOR_TWO);
assertAdvanced(dataFetcher, finalResult, NEXT_ITERATOR_TWO, null);
verify(kinesisProxy, times(2)).get(eq(INITIAL_ITERATOR), anyInt());
verify(kinesisProxy, times(2)).get(eq(NEXT_ITERATOR_ONE), anyInt());
verify(kinesisProxy, times(2)).get(eq(NEXT_ITERATOR_TWO), anyInt());
reset(kinesisProxy);
DataFetcherResult terminal = dataFetcher.getRecords(100);
assertThat(terminal.isShardEnd(), equalTo(true));
assertThat(terminal.getResult(), nullValue());
assertThat(terminal, equalTo(dataFetcher.TERMINAL_RESULT));
verify(kinesisProxy, never()).get(anyString(), anyInt());
}
private DataFetcherResult assertAdvanced(KinesisDataFetcher dataFetcher, GetRecordsResult expectedResult, String previousValue, String nextValue) {
DataFetcherResult acceptResult = dataFetcher.getRecords(100);
assertThat(acceptResult.getResult(), equalTo(expectedResult));
assertThat(dataFetcher.getNextIterator(), equalTo(previousValue));
assertThat(dataFetcher.isShardEndReached(), equalTo(false));
assertThat(acceptResult.accept(), equalTo(expectedResult));
assertThat(dataFetcher.getNextIterator(), equalTo(nextValue));
if (nextValue == null) {
assertThat(dataFetcher.isShardEndReached(), equalTo(true));
}
verify(kinesisProxy, times(2)).get(eq(previousValue), anyInt());
return acceptResult;
}
private DataFetcherResult assertNoAdvance(KinesisDataFetcher dataFetcher, GetRecordsResult expectedResult, String previousValue) {
assertThat(dataFetcher.getNextIterator(), equalTo(previousValue));
DataFetcherResult noAcceptResult = dataFetcher.getRecords(100);
assertThat(noAcceptResult.getResult(), equalTo(expectedResult));
assertThat(dataFetcher.getNextIterator(), equalTo(previousValue));
verify(kinesisProxy).get(eq(previousValue), anyInt());
return noAcceptResult;
}
private void testInitializeAndFetch(String iteratorType, private void testInitializeAndFetch(String iteratorType,
String seqNo, String seqNo,
InitialPositionInStreamExtended initialPositionInStream) throws Exception { InitialPositionInStreamExtended initialPositionInStream) throws Exception {