Handle Expired Iterators Correctly

Fix for the lease losses in the PrefetchCache and AsyncGetRecordsStrategy caused due to ExpiredIteratorException. (#263)
This commit is contained in:
Sahil Palvia 2017-11-08 12:03:09 -08:00 committed by Justin Pfifer
parent 3de901ea93
commit 5c3ff2b31e
15 changed files with 248 additions and 44 deletions

View file

@ -16,7 +16,6 @@ package com.amazonaws.services.kinesis.clientlibrary.lib.worker;
import java.util.HashSet; import java.util.HashSet;
import java.util.Set; import java.util.Set;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.Callable; import java.util.concurrent.Callable;
import java.util.concurrent.CompletionService; import java.util.concurrent.CompletionService;
import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutionException;
@ -31,6 +30,7 @@ import java.util.function.Supplier;
import com.amazonaws.services.kinesis.metrics.impl.MetricsHelper; import com.amazonaws.services.kinesis.metrics.impl.MetricsHelper;
import com.amazonaws.services.kinesis.metrics.impl.ThreadSafeMetricsDelegatingScope; import com.amazonaws.services.kinesis.metrics.impl.ThreadSafeMetricsDelegatingScope;
import com.amazonaws.services.kinesis.model.ExpiredIteratorException;
import com.amazonaws.services.kinesis.model.GetRecordsResult; import com.amazonaws.services.kinesis.model.GetRecordsResult;
import com.google.common.util.concurrent.ThreadFactoryBuilder; import com.google.common.util.concurrent.ThreadFactoryBuilder;
@ -81,6 +81,7 @@ public class AsynchronousGetRecordsRetrievalStrategy implements GetRecordsRetrie
CompletionService<DataFetcherResult> completionService = completionServiceSupplier.get(); CompletionService<DataFetcherResult> completionService = completionServiceSupplier.get();
Set<Future<DataFetcherResult>> futures = new HashSet<>(); Set<Future<DataFetcherResult>> futures = new HashSet<>();
Callable<DataFetcherResult> retrieverCall = createRetrieverCallable(maxRecords); Callable<DataFetcherResult> retrieverCall = createRetrieverCallable(maxRecords);
try {
while (true) { while (true) {
try { try {
futures.add(completionService.submit(retrieverCall)); futures.add(completionService.submit(retrieverCall));
@ -101,13 +102,18 @@ public class AsynchronousGetRecordsRetrievalStrategy implements GetRecordsRetrie
break; break;
} }
} catch (ExecutionException e) { } catch (ExecutionException e) {
if (e.getCause() instanceof ExpiredIteratorException) {
throw (ExpiredIteratorException) e.getCause();
}
log.error("ExecutionException thrown while trying to get records", e); log.error("ExecutionException thrown while trying to get records", e);
} catch (InterruptedException e) { } catch (InterruptedException e) {
log.error("Thread was interrupted", e); log.error("Thread was interrupted", e);
break; break;
} }
} }
} finally {
futures.forEach(f -> f.cancel(true)); futures.forEach(f -> f.cancel(true));
}
return result; return result;
} }
@ -140,4 +146,9 @@ public class AsynchronousGetRecordsRetrievalStrategy implements GetRecordsRetrie
new ThreadFactoryBuilder().setDaemon(true).setNameFormat(threadNameFormat).build(), new ThreadFactoryBuilder().setDaemon(true).setNameFormat(threadNameFormat).build(),
new ThreadPoolExecutor.AbortPolicy()); new ThreadPoolExecutor.AbortPolicy());
} }
@Override
public KinesisDataFetcher getDataFetcher() {
return dataFetcher;
}
} }

View file

@ -44,4 +44,11 @@ public interface GetRecordsRetrievalStrategy {
* @return true if the strategy has been shutdown, false otherwise. * @return true if the strategy has been shutdown, false otherwise.
*/ */
boolean isShutdown(); boolean isShutdown();
/**
* Returns the KinesisDataFetcher used to getRecords from Kinesis.
*
* @return KinesisDataFetcher
*/
KinesisDataFetcher getDataFetcher();
} }

View file

@ -17,6 +17,7 @@ package com.amazonaws.services.kinesis.clientlibrary.lib.worker;
import java.util.Collections; import java.util.Collections;
import java.util.Date; import java.util.Date;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
@ -27,6 +28,8 @@ import com.amazonaws.services.kinesis.clientlibrary.types.ExtendedSequenceNumber
import com.amazonaws.services.kinesis.model.GetRecordsResult; import com.amazonaws.services.kinesis.model.GetRecordsResult;
import com.amazonaws.services.kinesis.model.ResourceNotFoundException; import com.amazonaws.services.kinesis.model.ResourceNotFoundException;
import com.amazonaws.services.kinesis.model.ShardIteratorType; import com.amazonaws.services.kinesis.model.ShardIteratorType;
import com.amazonaws.util.CollectionUtils;
import com.google.common.collect.Iterables;
import lombok.Data; import lombok.Data;
@ -42,6 +45,8 @@ class KinesisDataFetcher {
private final String shardId; private final String shardId;
private boolean isShardEndReached; private boolean isShardEndReached;
private boolean isInitialized; private boolean isInitialized;
private String lastKnownSequenceNumber;
private InitialPositionInStreamExtended initialPositionInStream;
/** /**
* *
@ -108,6 +113,9 @@ class KinesisDataFetcher {
@Override @Override
public GetRecordsResult accept() { public GetRecordsResult accept() {
nextIterator = result.getNextShardIterator(); nextIterator = result.getNextShardIterator();
if (!CollectionUtils.isNullOrEmpty(result.getRecords())) {
lastKnownSequenceNumber = Iterables.getLast(result.getRecords()).getSequenceNumber();
}
if (nextIterator == null) { if (nextIterator == null) {
isShardEndReached = true; isShardEndReached = true;
} }
@ -161,6 +169,8 @@ class KinesisDataFetcher {
if (nextIterator == null) { if (nextIterator == null) {
isShardEndReached = true; isShardEndReached = true;
} }
this.lastKnownSequenceNumber = sequenceNumber;
this.initialPositionInStream = initialPositionInStream;
} }
/** /**
@ -217,6 +227,17 @@ class KinesisDataFetcher {
return iterator; return iterator;
} }
/**
* Gets a new iterator from the last known sequence number i.e. the sequence number of the last record from the last
* getRecords call.
*/
public void restartIterator() {
if (StringUtils.isEmpty(lastKnownSequenceNumber) || initialPositionInStream == null) {
throw new IllegalStateException("Make sure to initialize the KinesisDataFetcher before restarting the iterator.");
}
advanceIteratorTo(lastKnownSequenceNumber, initialPositionInStream);
}
/** /**
* @return the shardEndReached * @return the shardEndReached
*/ */

View file

@ -23,10 +23,13 @@ import java.util.concurrent.LinkedBlockingQueue;
import org.apache.commons.lang.Validate; import org.apache.commons.lang.Validate;
import com.amazonaws.SdkClientException; import com.amazonaws.SdkClientException;
import com.amazonaws.services.cloudwatch.model.StandardUnit;
import com.amazonaws.services.kinesis.clientlibrary.types.ProcessRecordsInput; import com.amazonaws.services.kinesis.clientlibrary.types.ProcessRecordsInput;
import com.amazonaws.services.kinesis.metrics.impl.MetricsHelper; import com.amazonaws.services.kinesis.metrics.impl.MetricsHelper;
import com.amazonaws.services.kinesis.metrics.impl.ThreadSafeMetricsDelegatingFactory; import com.amazonaws.services.kinesis.metrics.impl.ThreadSafeMetricsDelegatingFactory;
import com.amazonaws.services.kinesis.metrics.interfaces.IMetricsFactory; import com.amazonaws.services.kinesis.metrics.interfaces.IMetricsFactory;
import com.amazonaws.services.kinesis.metrics.interfaces.MetricsLevel;
import com.amazonaws.services.kinesis.model.ExpiredIteratorException;
import com.amazonaws.services.kinesis.model.GetRecordsResult; import com.amazonaws.services.kinesis.model.GetRecordsResult;
import lombok.NonNull; import lombok.NonNull;
@ -42,6 +45,7 @@ import lombok.extern.apachecommons.CommonsLog;
*/ */
@CommonsLog @CommonsLog
public class PrefetchGetRecordsCache implements GetRecordsCache { public class PrefetchGetRecordsCache implements GetRecordsCache {
private static final String EXPIRED_ITERATOR_METRIC = "ExpiredIterator";
LinkedBlockingQueue<ProcessRecordsInput> getRecordsResultQueue; LinkedBlockingQueue<ProcessRecordsInput> getRecordsResultQueue;
private int maxPendingProcessRecordsInput; private int maxPendingProcessRecordsInput;
private int maxByteSize; private int maxByteSize;
@ -56,6 +60,8 @@ public class PrefetchGetRecordsCache implements GetRecordsCache {
private PrefetchCounters prefetchCounters; private PrefetchCounters prefetchCounters;
private boolean started = false; private boolean started = false;
private final String operation; private final String operation;
private final KinesisDataFetcher dataFetcher;
private final String shardId;
/** /**
* Constructor for the PrefetchGetRecordsCache. This cache prefetches records from Kinesis and stores them in a * Constructor for the PrefetchGetRecordsCache. This cache prefetches records from Kinesis and stores them in a
@ -76,9 +82,10 @@ public class PrefetchGetRecordsCache implements GetRecordsCache {
final int maxRecordsPerCall, final int maxRecordsPerCall,
@NonNull final GetRecordsRetrievalStrategy getRecordsRetrievalStrategy, @NonNull final GetRecordsRetrievalStrategy getRecordsRetrievalStrategy,
@NonNull final ExecutorService executorService, @NonNull final ExecutorService executorService,
long idleMillisBetweenCalls, final long idleMillisBetweenCalls,
@NonNull final IMetricsFactory metricsFactory, @NonNull final IMetricsFactory metricsFactory,
@NonNull String operation) { @NonNull final String operation,
@NonNull final String shardId) {
this.getRecordsRetrievalStrategy = getRecordsRetrievalStrategy; this.getRecordsRetrievalStrategy = getRecordsRetrievalStrategy;
this.maxRecordsPerCall = maxRecordsPerCall; this.maxRecordsPerCall = maxRecordsPerCall;
this.maxPendingProcessRecordsInput = maxPendingProcessRecordsInput; this.maxPendingProcessRecordsInput = maxPendingProcessRecordsInput;
@ -92,6 +99,8 @@ public class PrefetchGetRecordsCache implements GetRecordsCache {
this.defaultGetRecordsCacheDaemon = new DefaultGetRecordsCacheDaemon(); this.defaultGetRecordsCacheDaemon = new DefaultGetRecordsCacheDaemon();
Validate.notEmpty(operation, "Operation cannot be empty"); Validate.notEmpty(operation, "Operation cannot be empty");
this.operation = operation; this.operation = operation;
this.dataFetcher = this.getRecordsRetrievalStrategy.getDataFetcher();
this.shardId = shardId;
} }
@Override @Override
@ -162,6 +171,14 @@ public class PrefetchGetRecordsCache implements GetRecordsCache {
prefetchCounters.added(processRecordsInput); prefetchCounters.added(processRecordsInput);
} catch (InterruptedException e) { } catch (InterruptedException e) {
log.info("Thread was interrupted, indicating shutdown was called on the cache."); log.info("Thread was interrupted, indicating shutdown was called on the cache.");
} catch (ExpiredIteratorException e) {
log.info(String.format("ShardId %s: getRecords threw ExpiredIteratorException - restarting"
+ " after greatest seqNum passed to customer", shardId), e);
MetricsHelper.getMetricsScope().addData(EXPIRED_ITERATOR_METRIC, 1, StandardUnit.Count,
MetricsLevel.SUMMARY);
dataFetcher.restartIterator();
} catch (SdkClientException e) { } catch (SdkClientException e) {
log.error("Exception thrown while fetching records from Kinesis", e); log.error("Exception thrown while fetching records from Kinesis", e);
} catch (Throwable e) { } catch (Throwable e) {

View file

@ -29,7 +29,8 @@ public interface RecordsFetcherFactory {
* *
* @return GetRecordsCache used to get records from Kinesis. * @return GetRecordsCache used to get records from Kinesis.
*/ */
GetRecordsCache createRecordsFetcher(GetRecordsRetrievalStrategy getRecordsRetrievalStrategy, String shardId, IMetricsFactory metricsFactory); GetRecordsCache createRecordsFetcher(GetRecordsRetrievalStrategy getRecordsRetrievalStrategy, String shardId,
IMetricsFactory metricsFactory);
/** /**
* Sets the maximum number of ProcessRecordsInput objects the GetRecordsCache can hold, before further requests are * Sets the maximum number of ProcessRecordsInput objects the GetRecordsCache can hold, before further requests are

View file

@ -18,6 +18,7 @@ import java.util.concurrent.Executors;
import com.amazonaws.services.kinesis.metrics.interfaces.IMetricsFactory; import com.amazonaws.services.kinesis.metrics.interfaces.IMetricsFactory;
import com.google.common.util.concurrent.ThreadFactoryBuilder; import com.google.common.util.concurrent.ThreadFactoryBuilder;
import lombok.extern.apachecommons.CommonsLog; import lombok.extern.apachecommons.CommonsLog;
@CommonsLog @CommonsLog
@ -34,7 +35,8 @@ public class SimpleRecordsFetcherFactory implements RecordsFetcherFactory {
} }
@Override @Override
public GetRecordsCache createRecordsFetcher(GetRecordsRetrievalStrategy getRecordsRetrievalStrategy, String shardId, IMetricsFactory metricsFactory) { public GetRecordsCache createRecordsFetcher(GetRecordsRetrievalStrategy getRecordsRetrievalStrategy, String shardId,
IMetricsFactory metricsFactory) {
if(dataFetchingStrategy.equals(DataFetchingStrategy.DEFAULT)) { if(dataFetchingStrategy.equals(DataFetchingStrategy.DEFAULT)) {
return new BlockingGetRecordsCache(maxRecords, getRecordsRetrievalStrategy); return new BlockingGetRecordsCache(maxRecords, getRecordsRetrievalStrategy);
} else { } else {
@ -46,7 +48,8 @@ public class SimpleRecordsFetcherFactory implements RecordsFetcherFactory {
.build()), .build()),
idleMillisBetweenCalls, idleMillisBetweenCalls,
metricsFactory, metricsFactory,
"ProcessTask"); "ProcessTask",
shardId);
} }
} }

View file

@ -42,4 +42,9 @@ public class SynchronousGetRecordsRetrievalStrategy implements GetRecordsRetriev
public boolean isShutdown() { public boolean isShutdown() {
return false; return false;
} }
@Override
public KinesisDataFetcher getDataFetcher() {
return dataFetcher;
}
} }

View file

@ -38,17 +38,21 @@ import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.function.Supplier; import java.util.function.Supplier;
import com.amazonaws.services.kinesis.model.ExpiredIteratorException;
import org.junit.After; import org.junit.After;
import org.junit.Assert;
import org.junit.Before; import org.junit.Before;
import org.junit.Ignore; import org.junit.Ignore;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.runners.MockitoJUnitRunner; import org.mockito.runners.MockitoJUnitRunner;
import com.amazonaws.services.kinesis.clientlibrary.proxies.IKinesisProxy; import com.amazonaws.services.kinesis.clientlibrary.proxies.IKinesisProxy;
import com.amazonaws.services.kinesis.model.GetRecordsResult; import com.amazonaws.services.kinesis.model.GetRecordsResult;
import com.google.common.util.concurrent.ThreadFactoryBuilder; import com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.mockito.stubbing.Answer;
@RunWith(MockitoJUnitRunner.class) @RunWith(MockitoJUnitRunner.class)
public class AsynchronousGetRecordsRetrievalStrategyIntegrationTest { public class AsynchronousGetRecordsRetrievalStrategyIntegrationTest {
@ -134,6 +138,24 @@ public class AsynchronousGetRecordsRetrievalStrategyIntegrationTest {
assertNull(getRecordsResult); assertNull(getRecordsResult);
} }
@Test (expected = ExpiredIteratorException.class)
public void testExpiredIteratorExcpetion() throws InterruptedException {
when(dataFetcher.getRecords(eq(numberOfRecords))).thenAnswer(new Answer<DataFetcherResult>() {
@Override
public DataFetcherResult answer(final InvocationOnMock invocationOnMock) throws Throwable {
Thread.sleep(SLEEP_GET_RECORDS_IN_SECONDS * 1000);
throw new ExpiredIteratorException("ExpiredIterator");
}
});
try {
getRecordsRetrivalStrategy.getRecords(numberOfRecords);
} finally {
verify(dataFetcher, atLeast(getLeastNumberOfCalls())).getRecords(eq(numberOfRecords));
verify(executorService, atLeast(getLeastNumberOfCalls())).execute(any());
}
}
private int getLeastNumberOfCalls() { private int getLeastNumberOfCalls() {
int leastNumberOfCalls = 0; int leastNumberOfCalls = 0;
for (int i = MAX_POOL_SIZE; i > 0; i--) { for (int i = MAX_POOL_SIZE; i > 0; i--) {
@ -163,6 +185,7 @@ public class AsynchronousGetRecordsRetrievalStrategyIntegrationTest {
} catch (InterruptedException e) { } catch (InterruptedException e) {
// Do nothing // Do nothing
} }
return result; return result;
} }
} }

View file

@ -20,20 +20,25 @@ import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyBoolean; import static org.mockito.Matchers.anyBoolean;
import static org.mockito.Matchers.anyLong; import static org.mockito.Matchers.anyLong;
import static org.mockito.Matchers.eq; import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import java.util.concurrent.CompletionService; import java.util.concurrent.CompletionService;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future; import java.util.concurrent.Future;
import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.function.Supplier; import java.util.function.Supplier;
import com.amazonaws.services.kinesis.model.ExpiredIteratorException;
import org.junit.Assert;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.runners.MockitoJUnitRunner; import org.mockito.runners.MockitoJUnitRunner;
@ -154,4 +159,26 @@ public class AsynchronousGetRecordsRetrievalStrategyTest {
assertThat(actualResult, equalTo(expectedResults)); assertThat(actualResult, equalTo(expectedResults));
} }
@Test (expected = ExpiredIteratorException.class)
public void testExpiredIteratorExceptionCase() throws Exception {
AsynchronousGetRecordsRetrievalStrategy strategy = new AsynchronousGetRecordsRetrievalStrategy(dataFetcher,
executorService, (int) RETRY_GET_RECORDS_IN_SECONDS, completionServiceSupplier, SHARD_ID);
Future<DataFetcherResult> successfulFuture2 = mock(Future.class);
when(executorService.isShutdown()).thenReturn(false);
when(completionService.submit(any())).thenReturn(successfulFuture, successfulFuture2);
when(completionService.poll(anyLong(), any())).thenReturn(null).thenReturn(successfulFuture);
when(successfulFuture.get()).thenThrow(new ExecutionException(new ExpiredIteratorException("ExpiredException")));
try {
strategy.getRecords(10);
} finally {
verify(executorService).isShutdown();
verify(completionService, times(2)).submit(any());
verify(completionService, times(2)).poll(eq(RETRY_GET_RECORDS_IN_SECONDS), eq(TimeUnit.SECONDS));
verify(successfulFuture).cancel(eq(true));
verify(successfulFuture2).cancel(eq(true));
}
}
} }

View file

@ -32,6 +32,7 @@ import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections;
import java.util.Date; import java.util.Date;
import java.util.List; import java.util.List;
@ -274,6 +275,45 @@ public class KinesisDataFetcherTest {
verify(kinesisProxy, never()).get(anyString(), anyInt()); verify(kinesisProxy, never()).get(anyString(), anyInt());
} }
@Test
public void testRestartIterator() {
GetRecordsResult getRecordsResult = mock(GetRecordsResult.class);
GetRecordsResult restartGetRecordsResult = new GetRecordsResult();
Record record = mock(Record.class);
final String initialIterator = "InitialIterator";
final String nextShardIterator = "NextShardIterator";
final String restartShardIterator = "RestartIterator";
final String sequenceNumber = "SequenceNumber";
final String iteratorType = "AT_SEQUENCE_NUMBER";
KinesisProxy kinesisProxy = mock(KinesisProxy.class);
KinesisDataFetcher fetcher = new KinesisDataFetcher(kinesisProxy, SHARD_INFO);
when(kinesisProxy.getIterator(eq(SHARD_ID), eq(InitialPositionInStream.LATEST.toString()))).thenReturn(initialIterator);
when(kinesisProxy.get(eq(initialIterator), eq(10))).thenReturn(getRecordsResult);
when(getRecordsResult.getRecords()).thenReturn(Collections.singletonList(record));
when(getRecordsResult.getNextShardIterator()).thenReturn(nextShardIterator);
when(record.getSequenceNumber()).thenReturn(sequenceNumber);
fetcher.initialize(InitialPositionInStream.LATEST.toString(), INITIAL_POSITION_LATEST);
verify(kinesisProxy).getIterator(eq(SHARD_ID), eq(InitialPositionInStream.LATEST.toString()));
Assert.assertEquals(getRecordsResult, fetcher.getRecords(10).accept());
verify(kinesisProxy).get(eq(initialIterator), eq(10));
when(kinesisProxy.getIterator(eq(SHARD_ID), eq(iteratorType), eq(sequenceNumber))).thenReturn(restartShardIterator);
when(kinesisProxy.get(eq(restartShardIterator), eq(10))).thenReturn(restartGetRecordsResult);
fetcher.restartIterator();
Assert.assertEquals(restartGetRecordsResult, fetcher.getRecords(10).accept());
verify(kinesisProxy).getIterator(eq(SHARD_ID), eq(iteratorType), eq(sequenceNumber));
verify(kinesisProxy).get(eq(restartShardIterator), eq(10));
}
@Test (expected = IllegalStateException.class)
public void testRestartIteratorNotInitialized() {
KinesisDataFetcher dataFetcher = new KinesisDataFetcher(kinesisProxy, SHARD_INFO);
dataFetcher.restartIterator();
}
private DataFetcherResult assertAdvanced(KinesisDataFetcher dataFetcher, GetRecordsResult expectedResult, private DataFetcherResult assertAdvanced(KinesisDataFetcher dataFetcher, GetRecordsResult expectedResult,
String previousValue, String nextValue) { String previousValue, String nextValue) {
DataFetcherResult acceptResult = dataFetcher.getRecords(100); DataFetcherResult acceptResult = dataFetcher.getRecords(100);

View file

@ -20,6 +20,8 @@ import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy; import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
@ -31,19 +33,19 @@ import java.util.List;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
import com.amazonaws.services.kinesis.clientlibrary.types.ProcessRecordsInput;
import com.amazonaws.services.kinesis.metrics.impl.NullMetricsFactory;
import com.amazonaws.services.kinesis.model.Record;
import org.junit.After; import org.junit.After;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.runners.MockitoJUnitRunner; import org.mockito.runners.MockitoJUnitRunner;
import org.mockito.stubbing.Answer;
import com.amazonaws.services.kinesis.clientlibrary.proxies.IKinesisProxy; import com.amazonaws.services.kinesis.clientlibrary.proxies.IKinesisProxy;
import com.amazonaws.services.kinesis.clientlibrary.types.ProcessRecordsInput; import com.amazonaws.services.kinesis.clientlibrary.types.ProcessRecordsInput;
import com.amazonaws.services.kinesis.metrics.impl.NullMetricsFactory;
import com.amazonaws.services.kinesis.model.ExpiredIteratorException;
import com.amazonaws.services.kinesis.model.GetRecordsResult; import com.amazonaws.services.kinesis.model.GetRecordsResult;
import com.amazonaws.services.kinesis.model.Record; import com.amazonaws.services.kinesis.model.Record;
@ -70,14 +72,13 @@ public class PrefetchGetRecordsCacheIntegrationTest {
@Mock @Mock
private IKinesisProxy proxy; private IKinesisProxy proxy;
@Mock @Mock
private ShardInfo shardInfo; private ShardInfo shardInfo;
@Before @Before
public void setup() { public void setup() {
records = new ArrayList<>(); records = new ArrayList<>();
dataFetcher = new KinesisDataFetcherForTest(proxy, shardInfo); dataFetcher = spy(new KinesisDataFetcherForTest(proxy, shardInfo));
getRecordsRetrievalStrategy = spy(new SynchronousGetRecordsRetrievalStrategy(dataFetcher)); getRecordsRetrievalStrategy = spy(new SynchronousGetRecordsRetrievalStrategy(dataFetcher));
executorService = spy(Executors.newFixedThreadPool(1)); executorService = spy(Executors.newFixedThreadPool(1));
@ -89,7 +90,8 @@ public class PrefetchGetRecordsCacheIntegrationTest {
executorService, executorService,
IDLE_MILLIS_BETWEEN_CALLS, IDLE_MILLIS_BETWEEN_CALLS,
new NullMetricsFactory(), new NullMetricsFactory(),
operation); operation,
"test-shard");
} }
@Test @Test
@ -135,7 +137,8 @@ public class PrefetchGetRecordsCacheIntegrationTest {
executorService2, executorService2,
IDLE_MILLIS_BETWEEN_CALLS, IDLE_MILLIS_BETWEEN_CALLS,
new NullMetricsFactory(), new NullMetricsFactory(),
operation); operation,
"test-shard-2");
getRecordsCache.start(); getRecordsCache.start();
sleep(IDLE_MILLIS_BETWEEN_CALLS); sleep(IDLE_MILLIS_BETWEEN_CALLS);
@ -167,6 +170,26 @@ public class PrefetchGetRecordsCacheIntegrationTest {
verify(getRecordsRetrievalStrategy2).shutdown(); verify(getRecordsRetrievalStrategy2).shutdown();
} }
@Test
public void testExpiredIteratorException() {
when(dataFetcher.getRecords(eq(MAX_RECORDS_PER_CALL))).thenAnswer(new Answer<DataFetcherResult>() {
@Override
public DataFetcherResult answer(final InvocationOnMock invocationOnMock) throws Throwable {
throw new ExpiredIteratorException("ExpiredIterator");
}
}).thenCallRealMethod();
doNothing().when(dataFetcher).restartIterator();
getRecordsCache.start();
sleep(IDLE_MILLIS_BETWEEN_CALLS);
ProcessRecordsInput processRecordsInput = getRecordsCache.getNextResult();
assertNotNull(processRecordsInput);
assertTrue(processRecordsInput.getRecords().isEmpty());
verify(dataFetcher).restartIterator();
}
@After @After
public void shutdown() { public void shutdown() {
getRecordsCache.shutdown(); getRecordsCache.shutdown();

View file

@ -22,6 +22,7 @@ import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any; import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq; import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.atLeast; import static org.mockito.Mockito.atLeast;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.spy; import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
@ -36,7 +37,6 @@ import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.LinkedBlockingQueue;
import java.util.stream.IntStream; import java.util.stream.IntStream;
import com.amazonaws.services.kinesis.metrics.impl.NullMetricsFactory;
import org.junit.After; import org.junit.After;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
@ -45,6 +45,8 @@ import org.mockito.Mock;
import org.mockito.runners.MockitoJUnitRunner; import org.mockito.runners.MockitoJUnitRunner;
import com.amazonaws.services.kinesis.clientlibrary.types.ProcessRecordsInput; import com.amazonaws.services.kinesis.clientlibrary.types.ProcessRecordsInput;
import com.amazonaws.services.kinesis.metrics.impl.NullMetricsFactory;
import com.amazonaws.services.kinesis.model.ExpiredIteratorException;
import com.amazonaws.services.kinesis.model.GetRecordsResult; import com.amazonaws.services.kinesis.model.GetRecordsResult;
import com.amazonaws.services.kinesis.model.Record; import com.amazonaws.services.kinesis.model.Record;
@ -66,6 +68,8 @@ public class PrefetchGetRecordsCacheTest {
private GetRecordsResult getRecordsResult; private GetRecordsResult getRecordsResult;
@Mock @Mock
private Record record; private Record record;
@Mock
private KinesisDataFetcher dataFetcher;
private List<Record> records; private List<Record> records;
private ExecutorService executorService; private ExecutorService executorService;
@ -75,6 +79,8 @@ public class PrefetchGetRecordsCacheTest {
@Before @Before
public void setup() { public void setup() {
when(getRecordsRetrievalStrategy.getDataFetcher()).thenReturn(dataFetcher);
executorService = spy(Executors.newFixedThreadPool(1)); executorService = spy(Executors.newFixedThreadPool(1));
getRecordsCache = new PrefetchGetRecordsCache( getRecordsCache = new PrefetchGetRecordsCache(
MAX_SIZE, MAX_SIZE,
@ -85,7 +91,8 @@ public class PrefetchGetRecordsCacheTest {
executorService, executorService,
IDLE_MILLIS_BETWEEN_CALLS, IDLE_MILLIS_BETWEEN_CALLS,
new NullMetricsFactory(), new NullMetricsFactory(),
operation); operation,
"shardId");
spyQueue = spy(getRecordsCache.getRecordsResultQueue); spyQueue = spy(getRecordsCache.getRecordsResultQueue);
records = spy(new ArrayList<>()); records = spy(new ArrayList<>());
@ -195,6 +202,20 @@ public class PrefetchGetRecordsCacheTest {
getRecordsCache.getNextResult(); getRecordsCache.getNextResult();
} }
@Test
public void testExpiredIteratorException() {
getRecordsCache.start();
when(getRecordsRetrievalStrategy.getRecords(MAX_RECORDS_PER_CALL)).thenThrow(ExpiredIteratorException.class).thenReturn(getRecordsResult);
doNothing().when(dataFetcher).restartIterator();
getRecordsCache.getNextResult();
sleep(1000);
verify(dataFetcher).restartIterator();
}
@After @After
public void shutdown() { public void shutdown() {
getRecordsCache.shutdown(); getRecordsCache.shutdown();

View file

@ -1,7 +1,5 @@
package com.amazonaws.services.kinesis.clientlibrary.lib.worker; package com.amazonaws.services.kinesis.clientlibrary.lib.worker;
import com.amazonaws.services.kinesis.metrics.interfaces.IMetricsFactory;
import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.MatcherAssert.assertThat;
@ -10,13 +8,14 @@ import org.junit.Test;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.MockitoAnnotations; import org.mockito.MockitoAnnotations;
import com.amazonaws.services.kinesis.metrics.interfaces.IMetricsFactory;
public class RecordsFetcherFactoryTest { public class RecordsFetcherFactoryTest {
private String shardId = "TestShard"; private String shardId = "TestShard";
private RecordsFetcherFactory recordsFetcherFactory; private RecordsFetcherFactory recordsFetcherFactory;
@Mock @Mock
private GetRecordsRetrievalStrategy getRecordsRetrievalStrategy; private GetRecordsRetrievalStrategy getRecordsRetrievalStrategy;
@Mock @Mock
private IMetricsFactory metricsFactory; private IMetricsFactory metricsFactory;

View file

@ -343,7 +343,9 @@ public class ShardConsumerTest {
getRecordsCache = spy(new BlockingGetRecordsCache(maxRecords, getRecordsCache = spy(new BlockingGetRecordsCache(maxRecords,
new SynchronousGetRecordsRetrievalStrategy(dataFetcher))); new SynchronousGetRecordsRetrievalStrategy(dataFetcher)));
when(recordsFetcherFactory.createRecordsFetcher(any(), anyString(),any())).thenReturn(getRecordsCache); when(recordsFetcherFactory.createRecordsFetcher(any(GetRecordsRetrievalStrategy.class), anyString(),
any(IMetricsFactory.class)))
.thenReturn(getRecordsCache);
ShardConsumer consumer = ShardConsumer consumer =
new ShardConsumer(shardInfo, new ShardConsumer(shardInfo,
@ -472,7 +474,9 @@ public class ShardConsumerTest {
getRecordsCache = spy(new BlockingGetRecordsCache(maxRecords, getRecordsCache = spy(new BlockingGetRecordsCache(maxRecords,
new SynchronousGetRecordsRetrievalStrategy(dataFetcher))); new SynchronousGetRecordsRetrievalStrategy(dataFetcher)));
when(recordsFetcherFactory.createRecordsFetcher(any(), anyString(),any())).thenReturn(getRecordsCache); when(recordsFetcherFactory.createRecordsFetcher(any(GetRecordsRetrievalStrategy.class), anyString(),
any(IMetricsFactory.class)))
.thenReturn(getRecordsCache);
ShardConsumer consumer = ShardConsumer consumer =
new ShardConsumer(shardInfo, new ShardConsumer(shardInfo,

View file

@ -621,7 +621,9 @@ public class WorkerTest {
RecordsFetcherFactory recordsFetcherFactory = mock(RecordsFetcherFactory.class); RecordsFetcherFactory recordsFetcherFactory = mock(RecordsFetcherFactory.class);
GetRecordsCache getRecordsCache = mock(GetRecordsCache.class); GetRecordsCache getRecordsCache = mock(GetRecordsCache.class);
when(config.getRecordsFetcherFactory()).thenReturn(recordsFetcherFactory); when(config.getRecordsFetcherFactory()).thenReturn(recordsFetcherFactory);
when(recordsFetcherFactory.createRecordsFetcher(any(), anyString(),any())).thenReturn(getRecordsCache); when(recordsFetcherFactory.createRecordsFetcher(any(GetRecordsRetrievalStrategy.class), anyString(),
any(IMetricsFactory.class)))
.thenReturn(getRecordsCache);
when(getRecordsCache.getNextResult()).thenReturn(new ProcessRecordsInput().withRecords(Collections.emptyList()).withMillisBehindLatest(0L)); when(getRecordsCache.getNextResult()).thenReturn(new ProcessRecordsInput().withRecords(Collections.emptyList()).withMillisBehindLatest(0L));
WorkerThread workerThread = runWorker(shardList, WorkerThread workerThread = runWorker(shardList,