Correctly Send MaxRecords to SingleRecordsFetcherFactory Fixed #262

Fixes #262 

Changing the signture of SingleRecordsFetcherFactory to no longer take maxRecords as the parameter to the constructor. Changed the createRecordsFetcher signature to take maxRecords as a parameter. (#264)
This commit is contained in:
Sahil Palvia 2017-11-10 06:32:16 -08:00 committed by Justin Pfifer
parent 5c3ff2b31e
commit 1abb41dbdb
7 changed files with 18 additions and 21 deletions

View file

@ -477,7 +477,7 @@ public class KinesisClientLibConfiguration {
InitialPositionInStreamExtended.newInitialPosition(initialPositionInStream); InitialPositionInStreamExtended.newInitialPosition(initialPositionInStream);
this.skipShardSyncAtWorkerInitializationIfLeasesExist = DEFAULT_SKIP_SHARD_SYNC_AT_STARTUP_IF_LEASES_EXIST; this.skipShardSyncAtWorkerInitializationIfLeasesExist = DEFAULT_SKIP_SHARD_SYNC_AT_STARTUP_IF_LEASES_EXIST;
this.shardPrioritization = DEFAULT_SHARD_PRIORITIZATION; this.shardPrioritization = DEFAULT_SHARD_PRIORITIZATION;
this.recordsFetcherFactory = new SimpleRecordsFetcherFactory(this.maxRecords); this.recordsFetcherFactory = new SimpleRecordsFetcherFactory();
} }
/** /**

View file

@ -26,11 +26,12 @@ public interface RecordsFetcherFactory {
* @param getRecordsRetrievalStrategy GetRecordsRetrievalStrategy to be used with the GetRecordsCache * @param getRecordsRetrievalStrategy GetRecordsRetrievalStrategy to be used with the GetRecordsCache
* @param shardId ShardId of the shard that the fetcher will retrieve records for * @param shardId ShardId of the shard that the fetcher will retrieve records for
* @param metricsFactory MetricsFactory used to create metricScope * @param metricsFactory MetricsFactory used to create metricScope
* @param maxRecords Max number of records to be returned in a single get call
* *
* @return GetRecordsCache used to get records from Kinesis. * @return GetRecordsCache used to get records from Kinesis.
*/ */
GetRecordsCache createRecordsFetcher(GetRecordsRetrievalStrategy getRecordsRetrievalStrategy, String shardId, GetRecordsCache createRecordsFetcher(GetRecordsRetrievalStrategy getRecordsRetrievalStrategy, String shardId,
IMetricsFactory metricsFactory); IMetricsFactory metricsFactory, int maxRecords);
/** /**
* Sets the maximum number of ProcessRecordsInput objects the GetRecordsCache can hold, before further requests are * Sets the maximum number of ProcessRecordsInput objects the GetRecordsCache can hold, before further requests are

View file

@ -235,7 +235,7 @@ class ShardConsumer {
this.dataFetcher = kinesisDataFetcher; this.dataFetcher = kinesisDataFetcher;
this.getRecordsCache = config.getRecordsFetcherFactory().createRecordsFetcher( this.getRecordsCache = config.getRecordsFetcherFactory().createRecordsFetcher(
makeStrategy(this.dataFetcher, retryGetRecordsInSeconds, maxGetRecordsThreadPool, this.shardInfo), makeStrategy(this.dataFetcher, retryGetRecordsInSeconds, maxGetRecordsThreadPool, this.shardInfo),
this.getShardInfo().getShardId(), this.metricsFactory); this.getShardInfo().getShardId(), this.metricsFactory, this.config.getMaxRecords());
} }
/** /**

View file

@ -23,20 +23,15 @@ import lombok.extern.apachecommons.CommonsLog;
@CommonsLog @CommonsLog
public class SimpleRecordsFetcherFactory implements RecordsFetcherFactory { public class SimpleRecordsFetcherFactory implements RecordsFetcherFactory {
private final int maxRecords;
private int maxPendingProcessRecordsInput = 3; private int maxPendingProcessRecordsInput = 3;
private int maxByteSize = 8 * 1024 * 1024; private int maxByteSize = 8 * 1024 * 1024;
private int maxRecordsCount = 30000; private int maxRecordsCount = 30000;
private long idleMillisBetweenCalls = 1500L; private long idleMillisBetweenCalls = 1500L;
private DataFetchingStrategy dataFetchingStrategy = DataFetchingStrategy.DEFAULT; private DataFetchingStrategy dataFetchingStrategy = DataFetchingStrategy.DEFAULT;
public SimpleRecordsFetcherFactory(int maxRecords) {
this.maxRecords = maxRecords;
}
@Override @Override
public GetRecordsCache createRecordsFetcher(GetRecordsRetrievalStrategy getRecordsRetrievalStrategy, String shardId, public GetRecordsCache createRecordsFetcher(GetRecordsRetrievalStrategy getRecordsRetrievalStrategy, String shardId,
IMetricsFactory metricsFactory) { IMetricsFactory metricsFactory, int maxRecords) {
if(dataFetchingStrategy.equals(DataFetchingStrategy.DEFAULT)) { if(dataFetchingStrategy.equals(DataFetchingStrategy.DEFAULT)) {
return new BlockingGetRecordsCache(maxRecords, getRecordsRetrievalStrategy); return new BlockingGetRecordsCache(maxRecords, getRecordsRetrievalStrategy);
} else { } else {

View file

@ -22,13 +22,13 @@ public class RecordsFetcherFactoryTest {
@Before @Before
public void setUp() { public void setUp() {
MockitoAnnotations.initMocks(this); MockitoAnnotations.initMocks(this);
recordsFetcherFactory = new SimpleRecordsFetcherFactory(1); recordsFetcherFactory = new SimpleRecordsFetcherFactory();
} }
@Test @Test
public void createDefaultRecordsFetcherTest() { public void createDefaultRecordsFetcherTest() {
GetRecordsCache recordsCache = recordsFetcherFactory.createRecordsFetcher(getRecordsRetrievalStrategy, shardId, GetRecordsCache recordsCache = recordsFetcherFactory.createRecordsFetcher(getRecordsRetrievalStrategy, shardId,
metricsFactory); metricsFactory, 1);
assertThat(recordsCache, instanceOf(BlockingGetRecordsCache.class)); assertThat(recordsCache, instanceOf(BlockingGetRecordsCache.class));
} }
@ -36,7 +36,7 @@ public class RecordsFetcherFactoryTest {
public void createPrefetchRecordsFetcherTest() { public void createPrefetchRecordsFetcherTest() {
recordsFetcherFactory.setDataFetchingStrategy(DataFetchingStrategy.PREFETCH_CACHED); recordsFetcherFactory.setDataFetchingStrategy(DataFetchingStrategy.PREFETCH_CACHED);
GetRecordsCache recordsCache = recordsFetcherFactory.createRecordsFetcher(getRecordsRetrievalStrategy, shardId, GetRecordsCache recordsCache = recordsFetcherFactory.createRecordsFetcher(getRecordsRetrievalStrategy, shardId,
metricsFactory); metricsFactory, 1);
assertThat(recordsCache, instanceOf(PrefetchGetRecordsCache.class)); assertThat(recordsCache, instanceOf(PrefetchGetRecordsCache.class));
} }

View file

@ -22,6 +22,7 @@ import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat; import static org.junit.Assert.assertThat;
import static org.junit.Assert.fail; import static org.junit.Assert.fail;
import static org.mockito.Matchers.any; import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyInt;
import static org.mockito.Matchers.anyString; import static org.mockito.Matchers.anyString;
import static org.mockito.Matchers.argThat; import static org.mockito.Matchers.argThat;
import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.atLeastOnce;
@ -97,7 +98,6 @@ public class ShardConsumerTest {
// Use Executors.newFixedThreadPool since it returns ThreadPoolExecutor, which is // Use Executors.newFixedThreadPool since it returns ThreadPoolExecutor, which is
// ... a non-final public class, and so can be mocked and spied. // ... a non-final public class, and so can be mocked and spied.
private final ExecutorService executorService = Executors.newFixedThreadPool(1); private final ExecutorService executorService = Executors.newFixedThreadPool(1);
private final int maxRecords = 500;
private RecordsFetcherFactory recordsFetcherFactory; private RecordsFetcherFactory recordsFetcherFactory;
private GetRecordsCache getRecordsCache; private GetRecordsCache getRecordsCache;
@ -119,7 +119,7 @@ public class ShardConsumerTest {
public void setup() { public void setup() {
getRecordsCache = null; getRecordsCache = null;
recordsFetcherFactory = spy(new SimpleRecordsFetcherFactory(maxRecords)); recordsFetcherFactory = spy(new SimpleRecordsFetcherFactory());
when(config.getRecordsFetcherFactory()).thenReturn(recordsFetcherFactory); when(config.getRecordsFetcherFactory()).thenReturn(recordsFetcherFactory);
when(config.getLogWarningForTaskAfterMillis()).thenReturn(Optional.empty()); when(config.getLogWarningForTaskAfterMillis()).thenReturn(Optional.empty());
} }
@ -344,7 +344,7 @@ public class ShardConsumerTest {
getRecordsCache = spy(new BlockingGetRecordsCache(maxRecords, getRecordsCache = spy(new BlockingGetRecordsCache(maxRecords,
new SynchronousGetRecordsRetrievalStrategy(dataFetcher))); new SynchronousGetRecordsRetrievalStrategy(dataFetcher)));
when(recordsFetcherFactory.createRecordsFetcher(any(GetRecordsRetrievalStrategy.class), anyString(), when(recordsFetcherFactory.createRecordsFetcher(any(GetRecordsRetrievalStrategy.class), anyString(),
any(IMetricsFactory.class))) any(IMetricsFactory.class), anyInt()))
.thenReturn(getRecordsCache); .thenReturn(getRecordsCache);
ShardConsumer consumer = ShardConsumer consumer =
@ -475,7 +475,7 @@ public class ShardConsumerTest {
getRecordsCache = spy(new BlockingGetRecordsCache(maxRecords, getRecordsCache = spy(new BlockingGetRecordsCache(maxRecords,
new SynchronousGetRecordsRetrievalStrategy(dataFetcher))); new SynchronousGetRecordsRetrievalStrategy(dataFetcher)));
when(recordsFetcherFactory.createRecordsFetcher(any(GetRecordsRetrievalStrategy.class), anyString(), when(recordsFetcherFactory.createRecordsFetcher(any(GetRecordsRetrievalStrategy.class), anyString(),
any(IMetricsFactory.class))) any(IMetricsFactory.class), anyInt()))
.thenReturn(getRecordsCache); .thenReturn(getRecordsCache);
ShardConsumer consumer = ShardConsumer consumer =
@ -571,7 +571,7 @@ public class ShardConsumerTest {
final ExtendedSequenceNumber checkpointSequenceNumber = new ExtendedSequenceNumber("123"); final ExtendedSequenceNumber checkpointSequenceNumber = new ExtendedSequenceNumber("123");
final ExtendedSequenceNumber pendingCheckpointSequenceNumber = new ExtendedSequenceNumber("999"); final ExtendedSequenceNumber pendingCheckpointSequenceNumber = new ExtendedSequenceNumber("999");
when(leaseManager.getLease(anyString())).thenReturn(null); when(leaseManager.getLease(anyString())).thenReturn(null);
when(config.getRecordsFetcherFactory()).thenReturn(new SimpleRecordsFetcherFactory(2)); when(config.getRecordsFetcherFactory()).thenReturn(new SimpleRecordsFetcherFactory());
when(checkpoint.getCheckpointObject(anyString())).thenReturn( when(checkpoint.getCheckpointObject(anyString())).thenReturn(
new Checkpoint(checkpointSequenceNumber, pendingCheckpointSequenceNumber)); new Checkpoint(checkpointSequenceNumber, pendingCheckpointSequenceNumber));

View file

@ -21,6 +21,7 @@ import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.junit.Assert.fail; import static org.junit.Assert.fail;
import static org.mockito.Matchers.any; import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyInt;
import static org.mockito.Matchers.anyString; import static org.mockito.Matchers.anyString;
import static org.mockito.Matchers.argThat; import static org.mockito.Matchers.argThat;
import static org.mockito.Matchers.eq; import static org.mockito.Matchers.eq;
@ -172,7 +173,7 @@ public class WorkerTest {
@Before @Before
public void setup() { public void setup() {
config = spy(new KinesisClientLibConfiguration("app", null, null, null)); config = spy(new KinesisClientLibConfiguration("app", null, null, null));
recordsFetcherFactory = spy(new SimpleRecordsFetcherFactory(500)); recordsFetcherFactory = spy(new SimpleRecordsFetcherFactory());
when(config.getRecordsFetcherFactory()).thenReturn(recordsFetcherFactory); when(config.getRecordsFetcherFactory()).thenReturn(recordsFetcherFactory);
} }
@ -505,7 +506,7 @@ public class WorkerTest {
lease.setCheckpoint(new ExtendedSequenceNumber("2")); lease.setCheckpoint(new ExtendedSequenceNumber("2"));
initialLeases.add(lease); initialLeases.add(lease);
boolean callProcessRecordsForEmptyRecordList = true; boolean callProcessRecordsForEmptyRecordList = true;
RecordsFetcherFactory recordsFetcherFactory = new SimpleRecordsFetcherFactory(500); RecordsFetcherFactory recordsFetcherFactory = new SimpleRecordsFetcherFactory();
recordsFetcherFactory.setIdleMillisBetweenCalls(0L); recordsFetcherFactory.setIdleMillisBetweenCalls(0L);
when(config.getRecordsFetcherFactory()).thenReturn(recordsFetcherFactory); when(config.getRecordsFetcherFactory()).thenReturn(recordsFetcherFactory);
runAndTestWorker(shardList, threadPoolSize, initialLeases, callProcessRecordsForEmptyRecordList, numberOfRecordsPerShard, config); runAndTestWorker(shardList, threadPoolSize, initialLeases, callProcessRecordsForEmptyRecordList, numberOfRecordsPerShard, config);
@ -622,7 +623,7 @@ public class WorkerTest {
GetRecordsCache getRecordsCache = mock(GetRecordsCache.class); GetRecordsCache getRecordsCache = mock(GetRecordsCache.class);
when(config.getRecordsFetcherFactory()).thenReturn(recordsFetcherFactory); when(config.getRecordsFetcherFactory()).thenReturn(recordsFetcherFactory);
when(recordsFetcherFactory.createRecordsFetcher(any(GetRecordsRetrievalStrategy.class), anyString(), when(recordsFetcherFactory.createRecordsFetcher(any(GetRecordsRetrievalStrategy.class), anyString(),
any(IMetricsFactory.class))) any(IMetricsFactory.class), anyInt()))
.thenReturn(getRecordsCache); .thenReturn(getRecordsCache);
when(getRecordsCache.getNextResult()).thenReturn(new ProcessRecordsInput().withRecords(Collections.emptyList()).withMillisBehindLatest(0L)); when(getRecordsCache.getNextResult()).thenReturn(new ProcessRecordsInput().withRecords(Collections.emptyList()).withMillisBehindLatest(0L));