Only advance the shard iterator when we accept a result to return

This changes the retriever strategy to only advance the shard iterator
when we have accepted a result to return.  This is for the
asynchronous retriever where multiple threads may contend for the same
iterator slot.  This ensures only the result selected for the response
will advance the shard iterator.
This commit is contained in:
Pfifer, Justin 2017-09-26 12:39:53 -07:00
parent 23c46267cf
commit 4b20556f37
8 changed files with 228 additions and 42 deletions

View file

@ -6,7 +6,7 @@
<artifactId>amazon-kinesis-client</artifactId>
<packaging>jar</packaging>
<name>Amazon Kinesis Client Library for Java</name>
<version>1.8.4</version>
<version>1.8.5-SNAPSHOT</version>
<description>The Amazon Kinesis Client Library for Java enables Java developers to easily consume and process data
from Amazon Kinesis.
</description>

View file

@ -49,7 +49,7 @@ public class AsynchronousGetRecordsRetrievalStrategy implements GetRecordsRetrie
private final ExecutorService executorService;
private final int retryGetRecordsInSeconds;
private final String shardId;
final Supplier<CompletionService<GetRecordsResult>> completionServiceSupplier;
final Supplier<CompletionService<DataFetcherResult>> completionServiceSupplier;
public AsynchronousGetRecordsRetrievalStrategy(@NonNull final KinesisDataFetcher dataFetcher,
final int retryGetRecordsInSeconds, final int maxGetRecordsThreadPool, String shardId) {
@ -63,7 +63,7 @@ public class AsynchronousGetRecordsRetrievalStrategy implements GetRecordsRetrie
}
AsynchronousGetRecordsRetrievalStrategy(KinesisDataFetcher dataFetcher, ExecutorService executorService,
int retryGetRecordsInSeconds, Supplier<CompletionService<GetRecordsResult>> completionServiceSupplier,
int retryGetRecordsInSeconds, Supplier<CompletionService<DataFetcherResult>> completionServiceSupplier,
String shardId) {
this.dataFetcher = dataFetcher;
this.executorService = executorService;
@ -78,9 +78,9 @@ public class AsynchronousGetRecordsRetrievalStrategy implements GetRecordsRetrie
throw new IllegalStateException("Strategy has been shutdown");
}
GetRecordsResult result = null;
CompletionService<GetRecordsResult> completionService = completionServiceSupplier.get();
Set<Future<GetRecordsResult>> futures = new HashSet<>();
Callable<GetRecordsResult> retrieverCall = createRetrieverCallable(maxRecords);
CompletionService<DataFetcherResult> completionService = completionServiceSupplier.get();
Set<Future<DataFetcherResult>> futures = new HashSet<>();
Callable<DataFetcherResult> retrieverCall = createRetrieverCallable(maxRecords);
while (true) {
try {
futures.add(completionService.submit(retrieverCall));
@ -89,10 +89,15 @@ public class AsynchronousGetRecordsRetrievalStrategy implements GetRecordsRetrie
}
try {
Future<GetRecordsResult> resultFuture = completionService.poll(retryGetRecordsInSeconds,
Future<DataFetcherResult> resultFuture = completionService.poll(retryGetRecordsInSeconds,
TimeUnit.SECONDS);
if (resultFuture != null) {
result = resultFuture.get();
//
// Fix to ensure that we only let the shard iterator advance when we intend to return the result
// to the caller. This ensures that the shard iterator is consistently advanced in step with
// what the caller sees.
//
result = resultFuture.get().accept();
break;
}
} catch (ExecutionException e) {
@ -106,7 +111,7 @@ public class AsynchronousGetRecordsRetrievalStrategy implements GetRecordsRetrie
return result;
}
private Callable<GetRecordsResult> createRetrieverCallable(int maxRecords) {
private Callable<DataFetcherResult> createRetrieverCallable(int maxRecords) {
ThreadSafeMetricsDelegatingScope metricsScope = new ThreadSafeMetricsDelegatingScope(MetricsHelper.getMetricsScope());
return () -> {
try {

View file

@ -0,0 +1,37 @@
/*
* Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. Licensed under the Amazon Software License
* (the "License"). You may not use this file except in compliance with the License. A copy of the License is located at
* http://aws.amazon.com/asl/ or in the "license" file accompanying this file. This file is distributed on an "AS IS"
* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific
* language governing permissions and limitations under the License.
*/
package com.amazonaws.services.kinesis.clientlibrary.lib.worker;
import com.amazonaws.services.kinesis.model.GetRecordsResult;
/**
* Represents a result produced by the data fetcher, and lets the receiver decide whether to accept it.
* <p>
* Obtaining a result and accepting it are two separate steps: {@link #getResult()} exposes the response, while
* {@link #accept()} is the step that advances the shard iterator. A caller that ends up holding several candidate
* results for the same iterator position should accept only the one it intends to return.
*/
public interface DataFetcherResult {
/**
* Returns the result of the request to Kinesis.
*
* @return The result of the request, this can be null if the request failed.
*/
GetRecordsResult getResult();
/**
* Accepts the result, and advances the shard iterator. A result from the data fetcher must be accepted before any
* further progress can be made.
*
* @return the result of the request, this can be null if the request failed.
*/
GetRecordsResult accept();
/**
* Indicates whether this result is at the end of the shard or not.
*
* @return true if the result is at the end of a shard, false otherwise
*/
boolean isShardEnd();
}

View file

@ -14,6 +14,7 @@
*/
package com.amazonaws.services.kinesis.clientlibrary.lib.worker;
import lombok.Data;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
@ -26,6 +27,7 @@ import com.amazonaws.services.kinesis.clientlibrary.proxies.MetricsCollectingKin
import com.amazonaws.services.kinesis.clientlibrary.types.ExtendedSequenceNumber;
import java.util.Date;
import java.util.function.Consumer;
/**
* Used to get data from Amazon Kinesis. Tracks iterator state internally.
@ -57,30 +59,69 @@ class KinesisDataFetcher {
* @param maxRecords Max records to fetch
* @return list of records of up to maxRecords size
*/
public GetRecordsResult getRecords(int maxRecords) {
public DataFetcherResult getRecords(int maxRecords) {
if (!isInitialized) {
throw new IllegalArgumentException("KinesisDataFetcher.getRecords called before initialization.");
}
GetRecordsResult response = null;
DataFetcherResult response;
if (nextIterator != null) {
try {
response = kinesisProxy.get(nextIterator, maxRecords);
nextIterator = response.getNextShardIterator();
response = new AdvancingResult(kinesisProxy.get(nextIterator, maxRecords));
} catch (ResourceNotFoundException e) {
LOG.info("Caught ResourceNotFoundException when fetching records for shard " + shardId);
nextIterator = null;
}
if (nextIterator == null) {
isShardEndReached = true;
response = TERMINAL_RESULT;
}
} else {
isShardEndReached = true;
response = TERMINAL_RESULT;
}
return response;
}
/**
* Result that carries no {@code GetRecordsResult}: both {@code getResult()} and {@code accept()} return null.
* It is an instance field rather than a static constant because {@code accept()} and {@code isShardEnd()} write and
* read the enclosing fetcher's {@code isShardEndReached} flag.
*/
final DataFetcherResult TERMINAL_RESULT = new DataFetcherResult() {
// There is never a payload for the terminal result.
@Override
public GetRecordsResult getResult() {
return null;
}
// Accepting the terminal result marks the enclosing fetcher as having reached the end of the shard.
@Override
public GetRecordsResult accept() {
isShardEndReached = true;
return getResult();
}
// NOTE(review): this reports the enclosing fetcher's flag rather than a constant true, so it only returns true
// once the flag has been set (for example by accept() above) - confirm that is the intended contract.
@Override
public boolean isShardEnd() {
return isShardEndReached;
}
};
/**
* Result that wraps a {@link GetRecordsResult} and defers moving the enclosing fetcher's iterator until
* {@link #accept()} is called. Within this class, {@code accept()} is the only method that writes to the enclosing
* fetcher's {@code nextIterator} and {@code isShardEndReached} fields.
*/
@Data
private class AdvancingResult implements DataFetcherResult {
// The wrapped response; never reassigned after construction.
final GetRecordsResult result;
// Plain accessor: returns the wrapped response without touching the fetcher's iterator.
@Override
public GetRecordsResult getResult() {
return result;
}
// Moves the enclosing fetcher on to the iterator carried by the wrapped response. A null next iterator is
// treated as the end of the shard.
// NOTE(review): assumes result is non-null; a null result would throw a NullPointerException here - TODO confirm
// that the proxy call that supplies it never returns null.
@Override
public GetRecordsResult accept() {
nextIterator = result.getNextShardIterator();
if (nextIterator == null) {
isShardEndReached = true;
}
return getResult();
}
// Reports the enclosing fetcher's shard-end flag, which accept() sets when there is no next iterator.
@Override
public boolean isShardEnd() {
return isShardEndReached;
}
}
/**
* Initializes this KinesisDataFetcher's iterator based on the checkpointed sequence number.
* @param initialCheckpoint Current checkpoint sequence number for this shard.

View file

@ -28,7 +28,7 @@ public class SynchronousGetRecordsRetrievalStrategy implements GetRecordsRetriev
@Override
public GetRecordsResult getRecords(final int maxRecords) {
return dataFetcher.getRecords(maxRecords);
return dataFetcher.getRecords(maxRecords).accept();
}
@Override

View file

@ -36,7 +36,10 @@ import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;
import static org.junit.Assert.assertEquals;
import static org.hamcrest.CoreMatchers.nullValue;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.core.IsEqual.equalTo;
import static org.junit.Assert.assertNull;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
@ -58,17 +61,19 @@ public class AsynchronousGetRecordsRetrievalStrategyIntegrationTest {
@Mock
private IKinesisProxy mockKinesisProxy;
@Mock
private ShardInfo mockShardInfo;
@Mock
private Supplier<CompletionService<GetRecordsResult>> completionServiceSupplier;
private Supplier<CompletionService<DataFetcherResult>> completionServiceSupplier;
@Mock
private DataFetcherResult result;
@Mock
private GetRecordsResult recordsResult;
private CompletionService<GetRecordsResult> completionService;
private CompletionService<DataFetcherResult> completionService;
private AsynchronousGetRecordsRetrievalStrategy getRecordsRetrivalStrategy;
private KinesisDataFetcher dataFetcher;
private GetRecordsResult result;
private ExecutorService executorService;
private RejectedExecutionHandler rejectedExecutionHandler;
private int numberOfRecords = 10;
@ -86,14 +91,15 @@ public class AsynchronousGetRecordsRetrievalStrategyIntegrationTest {
new LinkedBlockingQueue<>(1),
new ThreadFactoryBuilder().setDaemon(true).setNameFormat("getrecords-worker-%d").build(),
rejectedExecutionHandler));
completionService = spy(new ExecutorCompletionService<GetRecordsResult>(executorService));
completionService = spy(new ExecutorCompletionService<DataFetcherResult>(executorService));
when(completionServiceSupplier.get()).thenReturn(completionService);
getRecordsRetrivalStrategy = new AsynchronousGetRecordsRetrievalStrategy(dataFetcher, executorService, RETRY_GET_RECORDS_IN_SECONDS, completionServiceSupplier, "shardId-0001");
result = null;
when(result.accept()).thenReturn(recordsResult);
}
@Test
public void oneRequestMultithreadTest() {
when(result.accept()).thenReturn(null);
GetRecordsResult getRecordsResult = getRecordsRetrivalStrategy.getRecords(numberOfRecords);
verify(dataFetcher, atLeast(getLeastNumberOfCalls())).getRecords(eq(numberOfRecords));
verify(executorService, atLeast(getLeastNumberOfCalls())).execute(any());
@ -102,27 +108,25 @@ public class AsynchronousGetRecordsRetrievalStrategyIntegrationTest {
@Test
public void multiRequestTest() {
result = mock(GetRecordsResult.class);
ExecutorCompletionService<GetRecordsResult> completionService1 = spy(new ExecutorCompletionService<GetRecordsResult>(executorService));
ExecutorCompletionService<DataFetcherResult> completionService1 = spy(new ExecutorCompletionService<DataFetcherResult>(executorService));
when(completionServiceSupplier.get()).thenReturn(completionService1);
GetRecordsResult getRecordsResult = getRecordsRetrivalStrategy.getRecords(numberOfRecords);
verify(dataFetcher, atLeast(getLeastNumberOfCalls())).getRecords(numberOfRecords);
verify(executorService, atLeast(getLeastNumberOfCalls())).execute(any());
assertEquals(result, getRecordsResult);
assertThat(getRecordsResult, equalTo(recordsResult));
result = null;
ExecutorCompletionService<GetRecordsResult> completionService2 = spy(new ExecutorCompletionService<GetRecordsResult>(executorService));
when(result.accept()).thenReturn(null);
ExecutorCompletionService<DataFetcherResult> completionService2 = spy(new ExecutorCompletionService<DataFetcherResult>(executorService));
when(completionServiceSupplier.get()).thenReturn(completionService2);
getRecordsResult = getRecordsRetrivalStrategy.getRecords(numberOfRecords);
assertNull(getRecordsResult);
assertThat(getRecordsResult, nullValue(GetRecordsResult.class));
}
@Test
@Ignore
public void testInterrupted() throws InterruptedException, ExecutionException {
Future<GetRecordsResult> mockFuture = mock(Future.class);
Future<DataFetcherResult> mockFuture = mock(Future.class);
when(completionService.submit(any())).thenReturn(mockFuture);
when(completionService.poll()).thenReturn(mockFuture);
doThrow(InterruptedException.class).when(mockFuture).get();
@ -154,7 +158,7 @@ public class AsynchronousGetRecordsRetrievalStrategyIntegrationTest {
}
@Override
public GetRecordsResult getRecords(final int maxRecords) {
public DataFetcherResult getRecords(final int maxRecords) {
try {
Thread.sleep(SLEEP_GET_RECORDS_IN_SECONDS * 1000);
} catch (InterruptedException e) {

View file

@ -53,19 +53,23 @@ public class AsynchronousGetRecordsRetrievalStrategyTest {
@Mock
private ExecutorService executorService;
@Mock
private Supplier<CompletionService<GetRecordsResult>> completionServiceSupplier;
private Supplier<CompletionService<DataFetcherResult>> completionServiceSupplier;
@Mock
private CompletionService<GetRecordsResult> completionService;
private CompletionService<DataFetcherResult> completionService;
@Mock
private Future<GetRecordsResult> successfulFuture;
private Future<DataFetcherResult> successfulFuture;
@Mock
private Future<GetRecordsResult> blockedFuture;
private Future<DataFetcherResult> blockedFuture;
@Mock
private DataFetcherResult dataFetcherResult;
@Mock
private GetRecordsResult expectedResults;
@Before
public void before() {
when(completionServiceSupplier.get()).thenReturn(completionService);
when(dataFetcherResult.getResult()).thenReturn(expectedResults);
when(dataFetcherResult.accept()).thenReturn(expectedResults);
}
@Test
@ -76,7 +80,7 @@ public class AsynchronousGetRecordsRetrievalStrategyTest {
when(executorService.isShutdown()).thenReturn(false);
when(completionService.submit(any())).thenReturn(successfulFuture);
when(completionService.poll(anyLong(), any())).thenReturn(successfulFuture);
when(successfulFuture.get()).thenReturn(expectedResults);
when(successfulFuture.get()).thenReturn(dataFetcherResult);
GetRecordsResult result = strategy.getRecords(10);
@ -97,7 +101,7 @@ public class AsynchronousGetRecordsRetrievalStrategyTest {
when(executorService.isShutdown()).thenReturn(false);
when(completionService.submit(any())).thenReturn(blockedFuture).thenReturn(successfulFuture);
when(completionService.poll(anyLong(), any())).thenReturn(null).thenReturn(successfulFuture);
when(successfulFuture.get()).thenReturn(expectedResults);
when(successfulFuture.get()).thenReturn(dataFetcherResult);
when(successfulFuture.cancel(anyBoolean())).thenReturn(false);
when(blockedFuture.cancel(anyBoolean())).thenReturn(true);
when(successfulFuture.isCancelled()).thenReturn(false);
@ -133,7 +137,7 @@ public class AsynchronousGetRecordsRetrievalStrategyTest {
when(executorService.isShutdown()).thenReturn(false);
when(completionService.submit(any())).thenReturn(blockedFuture).thenThrow(new RejectedExecutionException("Rejected!")).thenReturn(successfulFuture);
when(completionService.poll(anyLong(), any())).thenReturn(null).thenReturn(null).thenReturn(successfulFuture);
when(successfulFuture.get()).thenReturn(expectedResults);
when(successfulFuture.get()).thenReturn(dataFetcherResult);
when(successfulFuture.cancel(anyBoolean())).thenReturn(false);
when(blockedFuture.cancel(anyBoolean())).thenReturn(true);
when(successfulFuture.isCancelled()).thenReturn(false);

View file

@ -14,9 +14,20 @@
*/
package com.amazonaws.services.kinesis.clientlibrary.lib.worker;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.not;
import static org.hamcrest.CoreMatchers.nullValue;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.mockito.Matchers.anyInt;
import static org.mockito.Matchers.anyString;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import java.util.ArrayList;
@ -39,12 +50,19 @@ import com.amazonaws.services.kinesis.clientlibrary.proxies.KinesisProxy;
import com.amazonaws.services.kinesis.clientlibrary.types.ExtendedSequenceNumber;
import com.amazonaws.services.kinesis.metrics.impl.MetricsHelper;
import com.amazonaws.services.kinesis.metrics.impl.NullMetricsFactory;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.runners.MockitoJUnitRunner;
/**
* Unit tests for KinesisDataFetcher.
*/
@RunWith(MockitoJUnitRunner.class)
public class KinesisDataFetcherTest {
@Mock
private KinesisProxy kinesisProxy;
private static final int MAX_RECORDS = 1;
private static final String SHARD_ID = "shardId-1";
private static final String AT_SEQUENCE_NUMBER = ShardIteratorType.AT_SEQUENCE_NUMBER.toString();
@ -55,6 +73,7 @@ public class KinesisDataFetcherTest {
InitialPositionInStreamExtended.newInitialPosition(InitialPositionInStream.TRIM_HORIZON);
private static final InitialPositionInStreamExtended INITIAL_POSITION_AT_TIMESTAMP =
InitialPositionInStreamExtended.newInitialPositionAtTimestamp(new Date(1000));
;
/**
* @throws java.lang.Exception
@ -190,6 +209,82 @@ public class KinesisDataFetcherTest {
Assert.assertTrue("Shard should reach the end", dataFetcher.isShardEndReached());
}
/**
 * Walks a fetcher through three chained iterators and checks, at every step, that calling
 * {@code getRecords} alone leaves the iterator where it was and that only {@code accept()} moves it on.
 * Once the last result (which has no next iterator) is accepted, further fetches must return the
 * terminal result without calling the proxy.
 */
@Test
public void testFetcherDoesNotAdvanceWithoutAccept() {
    final String initialIterator = "InitialIterator";
    final String nextIteratorOne = "NextIteratorOne";
    final String nextIteratorTwo = "NextIteratorTwo";

    // One mocked response per iterator position.
    final GetRecordsResult firstPage = mock(GetRecordsResult.class);
    final GetRecordsResult secondPage = mock(GetRecordsResult.class);
    final GetRecordsResult lastPage = mock(GetRecordsResult.class);

    when(kinesisProxy.getIterator(anyString(), anyString())).thenReturn(initialIterator);

    // initialIterator -> firstPage -> nextIteratorOne
    when(firstPage.getNextShardIterator()).thenReturn(nextIteratorOne);
    when(kinesisProxy.get(eq(initialIterator), anyInt())).thenReturn(firstPage);

    // nextIteratorOne -> secondPage -> nextIteratorTwo
    when(kinesisProxy.get(eq(nextIteratorOne), anyInt())).thenReturn(secondPage);
    when(secondPage.getNextShardIterator()).thenReturn(nextIteratorTwo);

    // nextIteratorTwo -> lastPage -> no further iterator
    when(kinesisProxy.get(eq(nextIteratorTwo), anyInt())).thenReturn(lastPage);
    when(lastPage.getNextShardIterator()).thenReturn(null);

    final InitialPositionInStreamExtended trimHorizon =
            InitialPositionInStreamExtended.newInitialPosition(InitialPositionInStream.TRIM_HORIZON);
    final KinesisDataFetcher fetcher = new KinesisDataFetcher(kinesisProxy, SHARD_INFO);
    fetcher.initialize("TRIM_HORIZON", trimHorizon);

    // Each position is fetched twice: once without accepting, once with.
    assertNoAdvance(fetcher, firstPage, initialIterator);
    assertAdvanced(fetcher, firstPage, initialIterator, nextIteratorOne);

    assertNoAdvance(fetcher, secondPage, nextIteratorOne);
    assertAdvanced(fetcher, secondPage, nextIteratorOne, nextIteratorTwo);

    assertNoAdvance(fetcher, lastPage, nextIteratorTwo);
    assertAdvanced(fetcher, lastPage, nextIteratorTwo, null);

    verify(kinesisProxy, times(2)).get(eq(initialIterator), anyInt());
    verify(kinesisProxy, times(2)).get(eq(nextIteratorOne), anyInt());
    verify(kinesisProxy, times(2)).get(eq(nextIteratorTwo), anyInt());

    // Clear recorded interactions so the never() check below only sees the terminal fetch.
    reset(kinesisProxy);

    final DataFetcherResult terminal = fetcher.getRecords(100);
    assertThat(terminal.isShardEnd(), equalTo(true));
    assertThat(terminal.getResult(), nullValue());
    assertThat(terminal, equalTo(fetcher.TERMINAL_RESULT));

    verify(kinesisProxy, never()).get(anyString(), anyInt());
}
/**
 * Fetches once more from {@code previousValue}, checks that nothing has moved yet, then accepts the
 * result and checks that the fetcher's iterator is now {@code nextValue}.
 *
 * @param dataFetcher fetcher under test
 * @param expectedResult response expected for the current iterator
 * @param previousValue iterator the fetcher should hold before the result is accepted
 * @param nextValue iterator the fetcher should hold after the result is accepted; null means shard end
 * @return the result that was accepted
 */
private DataFetcherResult assertAdvanced(KinesisDataFetcher dataFetcher, GetRecordsResult expectedResult, String previousValue, String nextValue) {
    final DataFetcherResult pending = dataFetcher.getRecords(100);

    // Before accept(): the payload is visible but the fetcher has not moved.
    assertThat(pending.getResult(), equalTo(expectedResult));
    assertThat(dataFetcher.getNextIterator(), equalTo(previousValue));
    assertThat(dataFetcher.isShardEndReached(), equalTo(false));

    // After accept(): same payload, iterator moved on.
    final GetRecordsResult accepted = pending.accept();
    assertThat(accepted, equalTo(expectedResult));
    assertThat(dataFetcher.getNextIterator(), equalTo(nextValue));
    if (nextValue == null) {
        assertThat(dataFetcher.isShardEndReached(), equalTo(true));
    }

    // Second proxy call for this iterator; the first is made by assertNoAdvance.
    verify(kinesisProxy, times(2)).get(eq(previousValue), anyInt());

    return pending;
}
/**
 * Fetches from {@code previousValue} without accepting the result, and checks that the fetcher's
 * iterator is the same before and after the fetch.
 *
 * @param dataFetcher fetcher under test
 * @param expectedResult response expected for the current iterator
 * @param previousValue iterator the fetcher should hold throughout
 * @return the result that was fetched but not accepted
 */
private DataFetcherResult assertNoAdvance(KinesisDataFetcher dataFetcher, GetRecordsResult expectedResult, String previousValue) {
    // Precondition: the fetcher is positioned where the caller expects.
    assertThat(dataFetcher.getNextIterator(), equalTo(previousValue));

    final DataFetcherResult unaccepted = dataFetcher.getRecords(100);
    final GetRecordsResult payload = unaccepted.getResult();
    assertThat(payload, equalTo(expectedResult));

    // Fetching alone must not move the iterator.
    assertThat(dataFetcher.getNextIterator(), equalTo(previousValue));

    // Exactly one proxy call so far for this iterator.
    verify(kinesisProxy).get(eq(previousValue), anyInt());

    return unaccepted;
}
private void testInitializeAndFetch(String iteratorType,
String seqNo,
InitialPositionInStreamExtended initialPositionInStream) throws Exception {