diff --git a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/KinesisClientLibConfiguration.java b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/KinesisClientLibConfiguration.java index fae780f5..f24fc574 100644 --- a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/KinesisClientLibConfiguration.java +++ b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/KinesisClientLibConfiguration.java @@ -477,7 +477,7 @@ public class KinesisClientLibConfiguration { InitialPositionInStreamExtended.newInitialPosition(initialPositionInStream); this.skipShardSyncAtWorkerInitializationIfLeasesExist = DEFAULT_SKIP_SHARD_SYNC_AT_STARTUP_IF_LEASES_EXIST; this.shardPrioritization = DEFAULT_SHARD_PRIORITIZATION; - this.recordsFetcherFactory = new SimpleRecordsFetcherFactory(this.maxRecords); + this.recordsFetcherFactory = new SimpleRecordsFetcherFactory(); } /** diff --git a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/RecordsFetcherFactory.java b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/RecordsFetcherFactory.java index afc6c4f2..c1a513a9 100644 --- a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/RecordsFetcherFactory.java +++ b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/RecordsFetcherFactory.java @@ -26,11 +26,12 @@ public interface RecordsFetcherFactory { * @param getRecordsRetrievalStrategy GetRecordsRetrievalStrategy to be used with the GetRecordsCache * @param shardId ShardId of the shard that the fetcher will retrieve records for * @param metricsFactory MetricsFactory used to create metricScope + * @param maxRecords Max number of records to be returned in a single get call * * @return GetRecordsCache used to get records from Kinesis. */ GetRecordsCache createRecordsFetcher(GetRecordsRetrievalStrategy getRecordsRetrievalStrategy, String shardId, - IMetricsFactory metricsFactory); + IMetricsFactory metricsFactory, int maxRecords); /** * Sets the maximum number of ProcessRecordsInput objects the GetRecordsCache can hold, before further requests are diff --git a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShardConsumer.java b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShardConsumer.java index 14a1d08c..d8aa88d1 100644 --- a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShardConsumer.java +++ b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShardConsumer.java @@ -235,7 +235,7 @@ class ShardConsumer { this.dataFetcher = kinesisDataFetcher; this.getRecordsCache = config.getRecordsFetcherFactory().createRecordsFetcher( makeStrategy(this.dataFetcher, retryGetRecordsInSeconds, maxGetRecordsThreadPool, this.shardInfo), - this.getShardInfo().getShardId(), this.metricsFactory); + this.getShardInfo().getShardId(), this.metricsFactory, this.config.getMaxRecords()); } /** diff --git a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/SimpleRecordsFetcherFactory.java b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/SimpleRecordsFetcherFactory.java index bd33fd98..79ad9f55 100644 --- a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/SimpleRecordsFetcherFactory.java +++ b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/SimpleRecordsFetcherFactory.java @@ -23,20 +23,15 @@ import lombok.extern.apachecommons.CommonsLog; @CommonsLog public class SimpleRecordsFetcherFactory implements RecordsFetcherFactory { - private final int maxRecords; private int maxPendingProcessRecordsInput = 3; private int maxByteSize = 8 * 1024 * 1024; private int maxRecordsCount = 30000; private long idleMillisBetweenCalls = 1500L; private DataFetchingStrategy dataFetchingStrategy = DataFetchingStrategy.DEFAULT; - - public SimpleRecordsFetcherFactory(int maxRecords) { - this.maxRecords = maxRecords; - } - + @Override public GetRecordsCache createRecordsFetcher(GetRecordsRetrievalStrategy getRecordsRetrievalStrategy, String shardId, - IMetricsFactory metricsFactory) { + IMetricsFactory metricsFactory, int maxRecords) { if(dataFetchingStrategy.equals(DataFetchingStrategy.DEFAULT)) { return new BlockingGetRecordsCache(maxRecords, getRecordsRetrievalStrategy); } else { diff --git a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/RecordsFetcherFactoryTest.java b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/RecordsFetcherFactoryTest.java index 7107d0fd..d686c914 100644 --- a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/RecordsFetcherFactoryTest.java +++ b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/RecordsFetcherFactoryTest.java @@ -22,13 +22,13 @@ public class RecordsFetcherFactoryTest { @Before public void setUp() { MockitoAnnotations.initMocks(this); - recordsFetcherFactory = new SimpleRecordsFetcherFactory(1); + recordsFetcherFactory = new SimpleRecordsFetcherFactory(); } @Test public void createDefaultRecordsFetcherTest() { GetRecordsCache recordsCache = recordsFetcherFactory.createRecordsFetcher(getRecordsRetrievalStrategy, shardId, - metricsFactory); + metricsFactory, 1); assertThat(recordsCache, instanceOf(BlockingGetRecordsCache.class)); } @@ -36,7 +36,7 @@ public class RecordsFetcherFactoryTest { public void createPrefetchRecordsFetcherTest() { recordsFetcherFactory.setDataFetchingStrategy(DataFetchingStrategy.PREFETCH_CACHED); GetRecordsCache recordsCache = recordsFetcherFactory.createRecordsFetcher(getRecordsRetrievalStrategy, shardId, - metricsFactory); + metricsFactory, 1); assertThat(recordsCache, instanceOf(PrefetchGetRecordsCache.class)); } diff --git a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShardConsumerTest.java b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShardConsumerTest.java index f235ca93..516788c7 100644 --- a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShardConsumerTest.java +++ b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShardConsumerTest.java @@ -22,6 +22,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThat; import static org.junit.Assert.fail; import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyInt; import static org.mockito.Matchers.anyString; import static org.mockito.Matchers.argThat; import static org.mockito.Mockito.atLeastOnce; @@ -97,7 +98,6 @@ public class ShardConsumerTest { // Use Executors.newFixedThreadPool since it returns ThreadPoolExecutor, which is // ... a non-final public class, and so can be mocked and spied. private final ExecutorService executorService = Executors.newFixedThreadPool(1); - private final int maxRecords = 500; private RecordsFetcherFactory recordsFetcherFactory; private GetRecordsCache getRecordsCache; @@ -119,7 +119,7 @@ public class ShardConsumerTest { public void setup() { getRecordsCache = null; - recordsFetcherFactory = spy(new SimpleRecordsFetcherFactory(maxRecords)); + recordsFetcherFactory = spy(new SimpleRecordsFetcherFactory()); when(config.getRecordsFetcherFactory()).thenReturn(recordsFetcherFactory); when(config.getLogWarningForTaskAfterMillis()).thenReturn(Optional.empty()); } @@ -344,7 +344,7 @@ public class ShardConsumerTest { getRecordsCache = spy(new BlockingGetRecordsCache(maxRecords, new SynchronousGetRecordsRetrievalStrategy(dataFetcher))); when(recordsFetcherFactory.createRecordsFetcher(any(GetRecordsRetrievalStrategy.class), anyString(), - any(IMetricsFactory.class))) + any(IMetricsFactory.class), anyInt())) .thenReturn(getRecordsCache); ShardConsumer consumer = @@ -475,7 +475,7 @@ public class ShardConsumerTest { getRecordsCache = spy(new BlockingGetRecordsCache(maxRecords, new SynchronousGetRecordsRetrievalStrategy(dataFetcher))); when(recordsFetcherFactory.createRecordsFetcher(any(GetRecordsRetrievalStrategy.class), anyString(), - any(IMetricsFactory.class))) + any(IMetricsFactory.class), anyInt())) .thenReturn(getRecordsCache); ShardConsumer consumer = @@ -571,7 +571,7 @@ public class ShardConsumerTest { final ExtendedSequenceNumber checkpointSequenceNumber = new ExtendedSequenceNumber("123"); final ExtendedSequenceNumber pendingCheckpointSequenceNumber = new ExtendedSequenceNumber("999"); when(leaseManager.getLease(anyString())).thenReturn(null); - when(config.getRecordsFetcherFactory()).thenReturn(new SimpleRecordsFetcherFactory(2)); + when(config.getRecordsFetcherFactory()).thenReturn(new SimpleRecordsFetcherFactory()); when(checkpoint.getCheckpointObject(anyString())).thenReturn( new Checkpoint(checkpointSequenceNumber, pendingCheckpointSequenceNumber)); diff --git a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/WorkerTest.java b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/WorkerTest.java index fd3382a3..ce406dce 100644 --- a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/WorkerTest.java +++ b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/WorkerTest.java @@ -21,6 +21,7 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.junit.Assert.fail; import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyInt; import static org.mockito.Matchers.anyString; import static org.mockito.Matchers.argThat; import static org.mockito.Matchers.eq; @@ -172,7 +173,7 @@ public class WorkerTest { @Before public void setup() { config = spy(new KinesisClientLibConfiguration("app", null, null, null)); - recordsFetcherFactory = spy(new SimpleRecordsFetcherFactory(500)); + recordsFetcherFactory = spy(new SimpleRecordsFetcherFactory()); when(config.getRecordsFetcherFactory()).thenReturn(recordsFetcherFactory); } @@ -505,7 +506,7 @@ public class WorkerTest { lease.setCheckpoint(new ExtendedSequenceNumber("2")); initialLeases.add(lease); boolean callProcessRecordsForEmptyRecordList = true; - RecordsFetcherFactory recordsFetcherFactory = new SimpleRecordsFetcherFactory(500); + RecordsFetcherFactory recordsFetcherFactory = new SimpleRecordsFetcherFactory(); recordsFetcherFactory.setIdleMillisBetweenCalls(0L); when(config.getRecordsFetcherFactory()).thenReturn(recordsFetcherFactory); runAndTestWorker(shardList, threadPoolSize, initialLeases, callProcessRecordsForEmptyRecordList, numberOfRecordsPerShard, config); @@ -622,7 +623,7 @@ public class WorkerTest { GetRecordsCache getRecordsCache = mock(GetRecordsCache.class); when(config.getRecordsFetcherFactory()).thenReturn(recordsFetcherFactory); when(recordsFetcherFactory.createRecordsFetcher(any(GetRecordsRetrievalStrategy.class), anyString(), - any(IMetricsFactory.class))) + any(IMetricsFactory.class), anyInt())) .thenReturn(getRecordsCache); when(getRecordsCache.getNextResult()).thenReturn(new ProcessRecordsInput().withRecords(Collections.emptyList()).withMillisBehindLatest(0L));