diff --git a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/AsynchronousGetRecordsRetrievalStrategy.java b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/AsynchronousGetRecordsRetrievalStrategy.java index 2e3cbd9e..2db74fba5 100644 --- a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/AsynchronousGetRecordsRetrievalStrategy.java +++ b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/AsynchronousGetRecordsRetrievalStrategy.java @@ -16,7 +16,6 @@ package com.amazonaws.services.kinesis.clientlibrary.lib.worker; import java.util.HashSet; import java.util.Set; -import java.util.concurrent.BlockingQueue; import java.util.concurrent.Callable; import java.util.concurrent.CompletionService; import java.util.concurrent.ExecutionException; @@ -31,6 +30,7 @@ import java.util.function.Supplier; import com.amazonaws.services.kinesis.metrics.impl.MetricsHelper; import com.amazonaws.services.kinesis.metrics.impl.ThreadSafeMetricsDelegatingScope; +import com.amazonaws.services.kinesis.model.ExpiredIteratorException; import com.amazonaws.services.kinesis.model.GetRecordsResult; import com.google.common.util.concurrent.ThreadFactoryBuilder; @@ -81,33 +81,39 @@ public class AsynchronousGetRecordsRetrievalStrategy implements GetRecordsRetrie CompletionService completionService = completionServiceSupplier.get(); Set> futures = new HashSet<>(); Callable retrieverCall = createRetrieverCallable(maxRecords); - while (true) { - try { - futures.add(completionService.submit(retrieverCall)); - } catch (RejectedExecutionException e) { - log.warn("Out of resources, unable to start additional requests."); - } + try { + while (true) { + try { + futures.add(completionService.submit(retrieverCall)); + } catch (RejectedExecutionException e) { + log.warn("Out of resources, unable to start additional requests."); + } - try { - Future resultFuture = completionService.poll(retryGetRecordsInSeconds, - TimeUnit.SECONDS); - if (resultFuture != null) { - // - // Fix to ensure that we only let the shard iterator advance when we intend to return the result - // to the caller. This ensures that the shard iterator is consistently advance in step with - // what the caller sees. - // - result = resultFuture.get().accept(); + try { + Future resultFuture = completionService.poll(retryGetRecordsInSeconds, + TimeUnit.SECONDS); + if (resultFuture != null) { + // + // Fix to ensure that we only let the shard iterator advance when we intend to return the result + // to the caller. This ensures that the shard iterator is consistently advance in step with + // what the caller sees. + // + result = resultFuture.get().accept(); + break; + } + } catch (ExecutionException e) { + if (e.getCause() instanceof ExpiredIteratorException) { + throw (ExpiredIteratorException) e.getCause(); + } + log.error("ExecutionException thrown while trying to get records", e); + } catch (InterruptedException e) { + log.error("Thread was interrupted", e); break; } - } catch (ExecutionException e) { - log.error("ExecutionException thrown while trying to get records", e); - } catch (InterruptedException e) { - log.error("Thread was interrupted", e); - break; } + } finally { + futures.forEach(f -> f.cancel(true)); } - futures.forEach(f -> f.cancel(true)); return result; } @@ -140,4 +146,9 @@ public class AsynchronousGetRecordsRetrievalStrategy implements GetRecordsRetrie new ThreadFactoryBuilder().setDaemon(true).setNameFormat(threadNameFormat).build(), new ThreadPoolExecutor.AbortPolicy()); } + + @Override + public KinesisDataFetcher getDataFetcher() { + return dataFetcher; + } } diff --git a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/GetRecordsRetrievalStrategy.java b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/GetRecordsRetrievalStrategy.java index 8f7afe25..4f474887 100644 --- a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/GetRecordsRetrievalStrategy.java +++ b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/GetRecordsRetrievalStrategy.java @@ -44,4 +44,11 @@ public interface GetRecordsRetrievalStrategy { * @return true if the strategy has been shutdown, false otherwise. */ boolean isShutdown(); + + /** + * Returns the KinesisDataFetcher used to getRecords from Kinesis. + * + * @return KinesisDataFetcher + */ + KinesisDataFetcher getDataFetcher(); } diff --git a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/KinesisDataFetcher.java b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/KinesisDataFetcher.java index c2ba9d15..0bd4bee3 100644 --- a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/KinesisDataFetcher.java +++ b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/KinesisDataFetcher.java @@ -17,6 +17,7 @@ package com.amazonaws.services.kinesis.clientlibrary.lib.worker; import java.util.Collections; import java.util.Date; +import org.apache.commons.lang.StringUtils; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -27,6 +28,8 @@ import com.amazonaws.services.kinesis.clientlibrary.types.ExtendedSequenceNumber import com.amazonaws.services.kinesis.model.GetRecordsResult; import com.amazonaws.services.kinesis.model.ResourceNotFoundException; import com.amazonaws.services.kinesis.model.ShardIteratorType; +import com.amazonaws.util.CollectionUtils; +import com.google.common.collect.Iterables; import lombok.Data; @@ -42,6 +45,8 @@ class KinesisDataFetcher { private final String shardId; private boolean isShardEndReached; private boolean isInitialized; + private String lastKnownSequenceNumber; + private InitialPositionInStreamExtended initialPositionInStream; /** * @@ -108,6 +113,9 @@ class KinesisDataFetcher { @Override public GetRecordsResult accept() { nextIterator = result.getNextShardIterator(); + if (!CollectionUtils.isNullOrEmpty(result.getRecords())) { + lastKnownSequenceNumber = Iterables.getLast(result.getRecords()).getSequenceNumber(); + } if (nextIterator == null) { isShardEndReached = true; } @@ -161,6 +169,8 @@ class KinesisDataFetcher { if (nextIterator == null) { isShardEndReached = true; } + this.lastKnownSequenceNumber = sequenceNumber; + this.initialPositionInStream = initialPositionInStream; } /** @@ -217,6 +227,17 @@ class KinesisDataFetcher { return iterator; } + /** + * Gets a new iterator from the last known sequence number i.e. the sequence number of the last record from the last + * getRecords call. + */ + public void restartIterator() { + if (StringUtils.isEmpty(lastKnownSequenceNumber) || initialPositionInStream == null) { + throw new IllegalStateException("Make sure to initialize the KinesisDataFetcher before restarting the iterator."); + } + advanceIteratorTo(lastKnownSequenceNumber, initialPositionInStream); + } + /** * @return the shardEndReached */ diff --git a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/PrefetchGetRecordsCache.java b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/PrefetchGetRecordsCache.java index 06e77c8c..982d70cc 100644 --- a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/PrefetchGetRecordsCache.java +++ b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/PrefetchGetRecordsCache.java @@ -23,10 +23,13 @@ import java.util.concurrent.LinkedBlockingQueue; import org.apache.commons.lang.Validate; import com.amazonaws.SdkClientException; +import com.amazonaws.services.cloudwatch.model.StandardUnit; import com.amazonaws.services.kinesis.clientlibrary.types.ProcessRecordsInput; import com.amazonaws.services.kinesis.metrics.impl.MetricsHelper; import com.amazonaws.services.kinesis.metrics.impl.ThreadSafeMetricsDelegatingFactory; import com.amazonaws.services.kinesis.metrics.interfaces.IMetricsFactory; +import com.amazonaws.services.kinesis.metrics.interfaces.MetricsLevel; +import com.amazonaws.services.kinesis.model.ExpiredIteratorException; import com.amazonaws.services.kinesis.model.GetRecordsResult; import lombok.NonNull; @@ -42,6 +45,7 @@ import lombok.extern.apachecommons.CommonsLog; */ @CommonsLog public class PrefetchGetRecordsCache implements GetRecordsCache { + private static final String EXPIRED_ITERATOR_METRIC = "ExpiredIterator"; LinkedBlockingQueue getRecordsResultQueue; private int maxPendingProcessRecordsInput; private int maxByteSize; @@ -56,6 +60,8 @@ public class PrefetchGetRecordsCache implements GetRecordsCache { private PrefetchCounters prefetchCounters; private boolean started = false; private final String operation; + private final KinesisDataFetcher dataFetcher; + private final String shardId; /** * Constructor for the PrefetchGetRecordsCache. This cache prefetches records from Kinesis and stores them in a @@ -76,9 +82,10 @@ public class PrefetchGetRecordsCache implements GetRecordsCache { final int maxRecordsPerCall, @NonNull final GetRecordsRetrievalStrategy getRecordsRetrievalStrategy, @NonNull final ExecutorService executorService, - long idleMillisBetweenCalls, + final long idleMillisBetweenCalls, @NonNull final IMetricsFactory metricsFactory, - @NonNull String operation) { + @NonNull final String operation, + @NonNull final String shardId) { this.getRecordsRetrievalStrategy = getRecordsRetrievalStrategy; this.maxRecordsPerCall = maxRecordsPerCall; this.maxPendingProcessRecordsInput = maxPendingProcessRecordsInput; @@ -92,6 +99,8 @@ public class PrefetchGetRecordsCache implements GetRecordsCache { this.defaultGetRecordsCacheDaemon = new DefaultGetRecordsCacheDaemon(); Validate.notEmpty(operation, "Operation cannot be empty"); this.operation = operation; + this.dataFetcher = this.getRecordsRetrievalStrategy.getDataFetcher(); + this.shardId = shardId; } @Override @@ -162,6 +171,14 @@ public class PrefetchGetRecordsCache implements GetRecordsCache { prefetchCounters.added(processRecordsInput); } catch (InterruptedException e) { log.info("Thread was interrupted, indicating shutdown was called on the cache."); + } catch (ExpiredIteratorException e) { + log.info(String.format("ShardId %s: getRecords threw ExpiredIteratorException - restarting" + + " after greatest seqNum passed to customer", shardId), e); + + MetricsHelper.getMetricsScope().addData(EXPIRED_ITERATOR_METRIC, 1, StandardUnit.Count, + MetricsLevel.SUMMARY); + + dataFetcher.restartIterator(); } catch (SdkClientException e) { log.error("Exception thrown while fetching records from Kinesis", e); } catch (Throwable e) { diff --git a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/RecordsFetcherFactory.java b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/RecordsFetcherFactory.java index be8316d7..afc6c4f2 100644 --- a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/RecordsFetcherFactory.java +++ b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/RecordsFetcherFactory.java @@ -29,7 +29,8 @@ public interface RecordsFetcherFactory { * * @return GetRecordsCache used to get records from Kinesis. */ - GetRecordsCache createRecordsFetcher(GetRecordsRetrievalStrategy getRecordsRetrievalStrategy, String shardId, IMetricsFactory metricsFactory); + GetRecordsCache createRecordsFetcher(GetRecordsRetrievalStrategy getRecordsRetrievalStrategy, String shardId, + IMetricsFactory metricsFactory); /** * Sets the maximum number of ProcessRecordsInput objects the GetRecordsCache can hold, before further requests are diff --git a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/SimpleRecordsFetcherFactory.java b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/SimpleRecordsFetcherFactory.java index 44c93e7b..bd33fd98 100644 --- a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/SimpleRecordsFetcherFactory.java +++ b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/SimpleRecordsFetcherFactory.java @@ -18,6 +18,7 @@ import java.util.concurrent.Executors; import com.amazonaws.services.kinesis.metrics.interfaces.IMetricsFactory; import com.google.common.util.concurrent.ThreadFactoryBuilder; + import lombok.extern.apachecommons.CommonsLog; @CommonsLog @@ -34,7 +35,8 @@ public class SimpleRecordsFetcherFactory implements RecordsFetcherFactory { } @Override - public GetRecordsCache createRecordsFetcher(GetRecordsRetrievalStrategy getRecordsRetrievalStrategy, String shardId, IMetricsFactory metricsFactory) { + public GetRecordsCache createRecordsFetcher(GetRecordsRetrievalStrategy getRecordsRetrievalStrategy, String shardId, + IMetricsFactory metricsFactory) { if(dataFetchingStrategy.equals(DataFetchingStrategy.DEFAULT)) { return new BlockingGetRecordsCache(maxRecords, getRecordsRetrievalStrategy); } else { @@ -46,7 +48,8 @@ public class SimpleRecordsFetcherFactory implements RecordsFetcherFactory { .build()), idleMillisBetweenCalls, metricsFactory, - "ProcessTask"); + "ProcessTask", + shardId); } } diff --git a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/SynchronousGetRecordsRetrievalStrategy.java b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/SynchronousGetRecordsRetrievalStrategy.java index c862c348..f4209189 100644 --- a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/SynchronousGetRecordsRetrievalStrategy.java +++ b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/SynchronousGetRecordsRetrievalStrategy.java @@ -42,4 +42,9 @@ public class SynchronousGetRecordsRetrievalStrategy implements GetRecordsRetriev public boolean isShutdown() { return false; } + + @Override + public KinesisDataFetcher getDataFetcher() { + return dataFetcher; + } } diff --git a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/AsynchronousGetRecordsRetrievalStrategyIntegrationTest.java b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/AsynchronousGetRecordsRetrievalStrategyIntegrationTest.java index 30b877e8..37f58c1c 100644 --- a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/AsynchronousGetRecordsRetrievalStrategyIntegrationTest.java +++ b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/AsynchronousGetRecordsRetrievalStrategyIntegrationTest.java @@ -38,17 +38,21 @@ import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; import java.util.function.Supplier; +import com.amazonaws.services.kinesis.model.ExpiredIteratorException; import org.junit.After; +import org.junit.Assert; import org.junit.Before; import org.junit.Ignore; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; +import org.mockito.invocation.InvocationOnMock; import org.mockito.runners.MockitoJUnitRunner; import com.amazonaws.services.kinesis.clientlibrary.proxies.IKinesisProxy; import com.amazonaws.services.kinesis.model.GetRecordsResult; import com.google.common.util.concurrent.ThreadFactoryBuilder; +import org.mockito.stubbing.Answer; @RunWith(MockitoJUnitRunner.class) public class AsynchronousGetRecordsRetrievalStrategyIntegrationTest { @@ -133,6 +137,24 @@ public class AsynchronousGetRecordsRetrievalStrategyIntegrationTest { verify(mockFuture).get(); assertNull(getRecordsResult); } + + @Test (expected = ExpiredIteratorException.class) + public void testExpiredIteratorExcpetion() throws InterruptedException { + when(dataFetcher.getRecords(eq(numberOfRecords))).thenAnswer(new Answer() { + @Override + public DataFetcherResult answer(final InvocationOnMock invocationOnMock) throws Throwable { + Thread.sleep(SLEEP_GET_RECORDS_IN_SECONDS * 1000); + throw new ExpiredIteratorException("ExpiredIterator"); + } + }); + + try { + getRecordsRetrivalStrategy.getRecords(numberOfRecords); + } finally { + verify(dataFetcher, atLeast(getLeastNumberOfCalls())).getRecords(eq(numberOfRecords)); + verify(executorService, atLeast(getLeastNumberOfCalls())).execute(any()); + } + } private int getLeastNumberOfCalls() { int leastNumberOfCalls = 0; @@ -163,6 +185,7 @@ public class AsynchronousGetRecordsRetrievalStrategyIntegrationTest { } catch (InterruptedException e) { // Do nothing } + return result; } } diff --git a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/AsynchronousGetRecordsRetrievalStrategyTest.java b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/AsynchronousGetRecordsRetrievalStrategyTest.java index aa9e9a24..151300de 100644 --- a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/AsynchronousGetRecordsRetrievalStrategyTest.java +++ b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/AsynchronousGetRecordsRetrievalStrategyTest.java @@ -20,20 +20,25 @@ import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyBoolean; import static org.mockito.Matchers.anyLong; import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import java.util.concurrent.CompletionService; +import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.TimeUnit; import java.util.function.Supplier; +import com.amazonaws.services.kinesis.model.ExpiredIteratorException; +import org.junit.Assert; import org.junit.Before; import org.junit.Test; +import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.runners.MockitoJUnitRunner; @@ -153,5 +158,27 @@ public class AsynchronousGetRecordsRetrievalStrategyTest { assertThat(actualResult, equalTo(expectedResults)); } + + @Test (expected = ExpiredIteratorException.class) + public void testExpiredIteratorExceptionCase() throws Exception { + AsynchronousGetRecordsRetrievalStrategy strategy = new AsynchronousGetRecordsRetrievalStrategy(dataFetcher, + executorService, (int) RETRY_GET_RECORDS_IN_SECONDS, completionServiceSupplier, SHARD_ID); + Future successfulFuture2 = mock(Future.class); + + when(executorService.isShutdown()).thenReturn(false); + when(completionService.submit(any())).thenReturn(successfulFuture, successfulFuture2); + when(completionService.poll(anyLong(), any())).thenReturn(null).thenReturn(successfulFuture); + when(successfulFuture.get()).thenThrow(new ExecutionException(new ExpiredIteratorException("ExpiredException"))); + + try { + strategy.getRecords(10); + } finally { + verify(executorService).isShutdown(); + verify(completionService, times(2)).submit(any()); + verify(completionService, times(2)).poll(eq(RETRY_GET_RECORDS_IN_SECONDS), eq(TimeUnit.SECONDS)); + verify(successfulFuture).cancel(eq(true)); + verify(successfulFuture2).cancel(eq(true)); + } + } } diff --git a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/KinesisDataFetcherTest.java b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/KinesisDataFetcherTest.java index 6648b919..fbe720ae 100644 --- a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/KinesisDataFetcherTest.java +++ b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/KinesisDataFetcherTest.java @@ -32,6 +32,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import java.util.ArrayList; +import java.util.Collections; import java.util.Date; import java.util.List; @@ -273,6 +274,45 @@ public class KinesisDataFetcherTest { verify(kinesisProxy, never()).get(anyString(), anyInt()); } + + @Test + public void testRestartIterator() { + GetRecordsResult getRecordsResult = mock(GetRecordsResult.class); + GetRecordsResult restartGetRecordsResult = new GetRecordsResult(); + Record record = mock(Record.class); + final String initialIterator = "InitialIterator"; + final String nextShardIterator = "NextShardIterator"; + final String restartShardIterator = "RestartIterator"; + final String sequenceNumber = "SequenceNumber"; + final String iteratorType = "AT_SEQUENCE_NUMBER"; + KinesisProxy kinesisProxy = mock(KinesisProxy.class); + KinesisDataFetcher fetcher = new KinesisDataFetcher(kinesisProxy, SHARD_INFO); + + when(kinesisProxy.getIterator(eq(SHARD_ID), eq(InitialPositionInStream.LATEST.toString()))).thenReturn(initialIterator); + when(kinesisProxy.get(eq(initialIterator), eq(10))).thenReturn(getRecordsResult); + when(getRecordsResult.getRecords()).thenReturn(Collections.singletonList(record)); + when(getRecordsResult.getNextShardIterator()).thenReturn(nextShardIterator); + when(record.getSequenceNumber()).thenReturn(sequenceNumber); + + fetcher.initialize(InitialPositionInStream.LATEST.toString(), INITIAL_POSITION_LATEST); + verify(kinesisProxy).getIterator(eq(SHARD_ID), eq(InitialPositionInStream.LATEST.toString())); + Assert.assertEquals(getRecordsResult, fetcher.getRecords(10).accept()); + verify(kinesisProxy).get(eq(initialIterator), eq(10)); + + when(kinesisProxy.getIterator(eq(SHARD_ID), eq(iteratorType), eq(sequenceNumber))).thenReturn(restartShardIterator); + when(kinesisProxy.get(eq(restartShardIterator), eq(10))).thenReturn(restartGetRecordsResult); + + fetcher.restartIterator(); + Assert.assertEquals(restartGetRecordsResult, fetcher.getRecords(10).accept()); + verify(kinesisProxy).getIterator(eq(SHARD_ID), eq(iteratorType), eq(sequenceNumber)); + verify(kinesisProxy).get(eq(restartShardIterator), eq(10)); + } + + @Test (expected = IllegalStateException.class) + public void testRestartIteratorNotInitialized() { + KinesisDataFetcher dataFetcher = new KinesisDataFetcher(kinesisProxy, SHARD_INFO); + dataFetcher.restartIterator(); + } private DataFetcherResult assertAdvanced(KinesisDataFetcher dataFetcher, GetRecordsResult expectedResult, String previousValue, String nextValue) { diff --git a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/PrefetchGetRecordsCacheIntegrationTest.java b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/PrefetchGetRecordsCacheIntegrationTest.java index 37d0e446..e24d5bb0 100644 --- a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/PrefetchGetRecordsCacheIntegrationTest.java +++ b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/PrefetchGetRecordsCacheIntegrationTest.java @@ -20,6 +20,8 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; @@ -31,19 +33,19 @@ import java.util.List; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; -import com.amazonaws.services.kinesis.clientlibrary.types.ProcessRecordsInput; -import com.amazonaws.services.kinesis.metrics.impl.NullMetricsFactory; -import com.amazonaws.services.kinesis.model.Record; - import org.junit.After; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; +import org.mockito.invocation.InvocationOnMock; import org.mockito.runners.MockitoJUnitRunner; +import org.mockito.stubbing.Answer; import com.amazonaws.services.kinesis.clientlibrary.proxies.IKinesisProxy; import com.amazonaws.services.kinesis.clientlibrary.types.ProcessRecordsInput; +import com.amazonaws.services.kinesis.metrics.impl.NullMetricsFactory; +import com.amazonaws.services.kinesis.model.ExpiredIteratorException; import com.amazonaws.services.kinesis.model.GetRecordsResult; import com.amazonaws.services.kinesis.model.Record; @@ -70,14 +72,13 @@ public class PrefetchGetRecordsCacheIntegrationTest { @Mock private IKinesisProxy proxy; - @Mock private ShardInfo shardInfo; @Before public void setup() { records = new ArrayList<>(); - dataFetcher = new KinesisDataFetcherForTest(proxy, shardInfo); + dataFetcher = spy(new KinesisDataFetcherForTest(proxy, shardInfo)); getRecordsRetrievalStrategy = spy(new SynchronousGetRecordsRetrievalStrategy(dataFetcher)); executorService = spy(Executors.newFixedThreadPool(1)); @@ -89,7 +90,8 @@ public class PrefetchGetRecordsCacheIntegrationTest { executorService, IDLE_MILLIS_BETWEEN_CALLS, new NullMetricsFactory(), - operation); + operation, + "test-shard"); } @Test @@ -135,7 +137,8 @@ public class PrefetchGetRecordsCacheIntegrationTest { executorService2, IDLE_MILLIS_BETWEEN_CALLS, new NullMetricsFactory(), - operation); + operation, + "test-shard-2"); getRecordsCache.start(); sleep(IDLE_MILLIS_BETWEEN_CALLS); @@ -167,6 +170,26 @@ public class PrefetchGetRecordsCacheIntegrationTest { verify(getRecordsRetrievalStrategy2).shutdown(); } + @Test + public void testExpiredIteratorException() { + when(dataFetcher.getRecords(eq(MAX_RECORDS_PER_CALL))).thenAnswer(new Answer() { + @Override + public DataFetcherResult answer(final InvocationOnMock invocationOnMock) throws Throwable { + throw new ExpiredIteratorException("ExpiredIterator"); + } + }).thenCallRealMethod(); + doNothing().when(dataFetcher).restartIterator(); + + getRecordsCache.start(); + sleep(IDLE_MILLIS_BETWEEN_CALLS); + + ProcessRecordsInput processRecordsInput = getRecordsCache.getNextResult(); + + assertNotNull(processRecordsInput); + assertTrue(processRecordsInput.getRecords().isEmpty()); + verify(dataFetcher).restartIterator(); + } + @After public void shutdown() { getRecordsCache.shutdown(); diff --git a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/PrefetchGetRecordsCacheTest.java b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/PrefetchGetRecordsCacheTest.java index 6091baa9..2b650866 100644 --- a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/PrefetchGetRecordsCacheTest.java +++ b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/PrefetchGetRecordsCacheTest.java @@ -22,6 +22,7 @@ import static org.junit.Assert.assertTrue; import static org.mockito.Matchers.any; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.atLeast; +import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -36,7 +37,6 @@ import java.util.concurrent.Executors; import java.util.concurrent.LinkedBlockingQueue; import java.util.stream.IntStream; -import com.amazonaws.services.kinesis.metrics.impl.NullMetricsFactory; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -45,6 +45,8 @@ import org.mockito.Mock; import org.mockito.runners.MockitoJUnitRunner; import com.amazonaws.services.kinesis.clientlibrary.types.ProcessRecordsInput; +import com.amazonaws.services.kinesis.metrics.impl.NullMetricsFactory; +import com.amazonaws.services.kinesis.model.ExpiredIteratorException; import com.amazonaws.services.kinesis.model.GetRecordsResult; import com.amazonaws.services.kinesis.model.Record; @@ -66,6 +68,8 @@ public class PrefetchGetRecordsCacheTest { private GetRecordsResult getRecordsResult; @Mock private Record record; + @Mock + private KinesisDataFetcher dataFetcher; private List records; private ExecutorService executorService; @@ -75,6 +79,8 @@ public class PrefetchGetRecordsCacheTest { @Before public void setup() { + when(getRecordsRetrievalStrategy.getDataFetcher()).thenReturn(dataFetcher); + executorService = spy(Executors.newFixedThreadPool(1)); getRecordsCache = new PrefetchGetRecordsCache( MAX_SIZE, @@ -85,7 +91,8 @@ public class PrefetchGetRecordsCacheTest { executorService, IDLE_MILLIS_BETWEEN_CALLS, new NullMetricsFactory(), - operation); + operation, + "shardId"); spyQueue = spy(getRecordsCache.getRecordsResultQueue); records = spy(new ArrayList<>()); @@ -194,6 +201,20 @@ public class PrefetchGetRecordsCacheTest { when(executorService.isShutdown()).thenReturn(true); getRecordsCache.getNextResult(); } + + @Test + public void testExpiredIteratorException() { + getRecordsCache.start(); + + when(getRecordsRetrievalStrategy.getRecords(MAX_RECORDS_PER_CALL)).thenThrow(ExpiredIteratorException.class).thenReturn(getRecordsResult); + doNothing().when(dataFetcher).restartIterator(); + + getRecordsCache.getNextResult(); + + sleep(1000); + + verify(dataFetcher).restartIterator(); + } @After public void shutdown() { diff --git a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/RecordsFetcherFactoryTest.java b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/RecordsFetcherFactoryTest.java index 912804da..7107d0fd 100644 --- a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/RecordsFetcherFactoryTest.java +++ b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/RecordsFetcherFactoryTest.java @@ -1,7 +1,5 @@ package com.amazonaws.services.kinesis.clientlibrary.lib.worker; -import com.amazonaws.services.kinesis.metrics.interfaces.IMetricsFactory; - import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.MatcherAssert.assertThat; @@ -10,13 +8,14 @@ import org.junit.Test; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import com.amazonaws.services.kinesis.metrics.interfaces.IMetricsFactory; + public class RecordsFetcherFactoryTest { private String shardId = "TestShard"; private RecordsFetcherFactory recordsFetcherFactory; @Mock private GetRecordsRetrievalStrategy getRecordsRetrievalStrategy; - @Mock private IMetricsFactory metricsFactory; diff --git a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShardConsumerTest.java b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShardConsumerTest.java index 9a7f2234..f235ca93 100644 --- a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShardConsumerTest.java +++ b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShardConsumerTest.java @@ -343,7 +343,9 @@ public class ShardConsumerTest { getRecordsCache = spy(new BlockingGetRecordsCache(maxRecords, new SynchronousGetRecordsRetrievalStrategy(dataFetcher))); - when(recordsFetcherFactory.createRecordsFetcher(any(), anyString(),any())).thenReturn(getRecordsCache); + when(recordsFetcherFactory.createRecordsFetcher(any(GetRecordsRetrievalStrategy.class), anyString(), + any(IMetricsFactory.class))) + .thenReturn(getRecordsCache); ShardConsumer consumer = new ShardConsumer(shardInfo, @@ -472,7 +474,9 @@ public class ShardConsumerTest { getRecordsCache = spy(new BlockingGetRecordsCache(maxRecords, new SynchronousGetRecordsRetrievalStrategy(dataFetcher))); - when(recordsFetcherFactory.createRecordsFetcher(any(), anyString(),any())).thenReturn(getRecordsCache); + when(recordsFetcherFactory.createRecordsFetcher(any(GetRecordsRetrievalStrategy.class), anyString(), + any(IMetricsFactory.class))) + .thenReturn(getRecordsCache); ShardConsumer consumer = new ShardConsumer(shardInfo, diff --git a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/WorkerTest.java b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/WorkerTest.java index a8856a0b..fd3382a3 100644 --- a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/WorkerTest.java +++ b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/WorkerTest.java @@ -621,7 +621,9 @@ public class WorkerTest { RecordsFetcherFactory recordsFetcherFactory = mock(RecordsFetcherFactory.class); GetRecordsCache getRecordsCache = mock(GetRecordsCache.class); when(config.getRecordsFetcherFactory()).thenReturn(recordsFetcherFactory); - when(recordsFetcherFactory.createRecordsFetcher(any(), anyString(),any())).thenReturn(getRecordsCache); + when(recordsFetcherFactory.createRecordsFetcher(any(GetRecordsRetrievalStrategy.class), anyString(), + any(IMetricsFactory.class))) + .thenReturn(getRecordsCache); when(getRecordsCache.getNextResult()).thenReturn(new ProcessRecordsInput().withRecords(Collections.emptyList()).withMillisBehindLatest(0L)); WorkerThread workerThread = runWorker(shardList,