fixed tests

This commit is contained in:
Wei 2017-09-20 14:25:23 -07:00
parent 1c07b45166
commit 024d86da76
3 changed files with 25 additions and 77 deletions

View file

@ -519,7 +519,7 @@ public class Worker implements Runnable {
boolean foundCompletedShard = false; boolean foundCompletedShard = false;
Set<ShardInfo> assignedShards = new HashSet<>(); Set<ShardInfo> assignedShards = new HashSet<>();
for (ShardInfo shardInfo : getShardInfoForAssignments()) { for (ShardInfo shardInfo : getShardInfoForAssignments()) {
ShardConsumer shardConsumer = createOrGetShardConsumer(shardInfo, recordProcessorFactory, recordsFetcherFactory); ShardConsumer shardConsumer = createOrGetShardConsumer(shardInfo, recordProcessorFactory);
if (shardConsumer.isShutdown() && shardConsumer.getShutdownReason().equals(ShutdownReason.TERMINATE)) { if (shardConsumer.isShutdown() && shardConsumer.getShutdownReason().equals(ShutdownReason.TERMINATE)) {
foundCompletedShard = true; foundCompletedShard = true;
} else { } else {
@ -891,11 +891,9 @@ public class Worker implements Runnable {
* Kinesis shard info * Kinesis shard info
* @param processorFactory * @param processorFactory
* RecordProcessor factory * RecordProcessor factory
* @param fetcherFactory
* RecordFetcher factory
* @return ShardConsumer for the shard * @return ShardConsumer for the shard
*/ */
ShardConsumer createOrGetShardConsumer(ShardInfo shardInfo, IRecordProcessorFactory processorFactory, RecordsFetcherFactory fetcherFactory) { ShardConsumer createOrGetShardConsumer(ShardInfo shardInfo, IRecordProcessorFactory processorFactory) {
ShardConsumer consumer = shardInfoShardConsumerMap.get(shardInfo); ShardConsumer consumer = shardInfoShardConsumerMap.get(shardInfo);
// Instantiate a new consumer if we don't have one, or the one we // Instantiate a new consumer if we don't have one, or the one we
// had was from an earlier // had was from an earlier
@ -904,17 +902,17 @@ public class Worker implements Runnable {
// completely processed (shutdown reason terminate). // completely processed (shutdown reason terminate).
if ((consumer == null) if ((consumer == null)
|| (consumer.isShutdown() && consumer.getShutdownReason().equals(ShutdownReason.ZOMBIE))) { || (consumer.isShutdown() && consumer.getShutdownReason().equals(ShutdownReason.ZOMBIE))) {
consumer = buildConsumer(shardInfo, processorFactory, fetcherFactory); consumer = buildConsumer(shardInfo, processorFactory);
shardInfoShardConsumerMap.put(shardInfo, consumer); shardInfoShardConsumerMap.put(shardInfo, consumer);
wlog.infoForce("Created new shardConsumer for : " + shardInfo); wlog.infoForce("Created new shardConsumer for : " + shardInfo);
} }
return consumer; return consumer;
} }
protected ShardConsumer buildConsumer(ShardInfo shardInfo, IRecordProcessorFactory processorFactory, RecordsFetcherFactory fetcherFactory) { protected ShardConsumer buildConsumer(ShardInfo shardInfo, IRecordProcessorFactory processorFactory) {
IRecordProcessor recordProcessor = processorFactory.createProcessor(); IRecordProcessor recordProcessor = processorFactory.createProcessor();
return new ShardConsumer(shardInfo, streamConfig, checkpointTracker, recordProcessor, fetcherFactory, return new ShardConsumer(shardInfo, streamConfig, checkpointTracker, recordProcessor, recordsFetcherFactory,
leaseCoordinator.getLeaseManager(), parentShardPollIntervalMillis, cleanupLeasesUponShardCompletion, leaseCoordinator.getLeaseManager(), parentShardPollIntervalMillis, cleanupLeasesUponShardCompletion,
executorService, metricsFactory, taskBackoffTimeMillis, executorService, metricsFactory, taskBackoffTimeMillis,
skipShardSyncAtWorkerInitializationIfLeasesExist, retryGetRecordsInSeconds, maxGetRecordsThreadPool); skipShardSyncAtWorkerInitializationIfLeasesExist, retryGetRecordsInSeconds, maxGetRecordsThreadPool);

View file

@ -57,6 +57,8 @@ public class ConsumerStatesTest {
@Mock @Mock
private IRecordProcessor recordProcessor; private IRecordProcessor recordProcessor;
@Mock @Mock
private RecordsFetcherFactory recordsFetcherFactory;
@Mock
private RecordProcessorCheckpointer recordProcessorCheckpointer; private RecordProcessorCheckpointer recordProcessorCheckpointer;
@Mock @Mock
private ExecutorService executorService; private ExecutorService executorService;
@ -76,6 +78,10 @@ public class ConsumerStatesTest {
private IKinesisProxy kinesisProxy; private IKinesisProxy kinesisProxy;
@Mock @Mock
private InitialPositionInStreamExtended initialPositionInStream; private InitialPositionInStreamExtended initialPositionInStream;
@Mock
private SynchronousGetRecordsRetrievalStrategy getRecordsRetrievalStrategy;
@Mock
private GetRecordsCache recordsFetcher;
private long parentShardPollIntervalMillis = 0xCAFE; private long parentShardPollIntervalMillis = 0xCAFE;
private boolean cleanupLeasesOfCompletedShards = true; private boolean cleanupLeasesOfCompletedShards = true;
@ -86,6 +92,7 @@ public class ConsumerStatesTest {
public void setup() { public void setup() {
when(consumer.getStreamConfig()).thenReturn(streamConfig); when(consumer.getStreamConfig()).thenReturn(streamConfig);
when(consumer.getRecordProcessor()).thenReturn(recordProcessor); when(consumer.getRecordProcessor()).thenReturn(recordProcessor);
when(consumer.getRecordsFetcherFactory()).thenReturn(recordsFetcherFactory);
when(consumer.getRecordProcessorCheckpointer()).thenReturn(recordProcessorCheckpointer); when(consumer.getRecordProcessorCheckpointer()).thenReturn(recordProcessorCheckpointer);
when(consumer.getExecutorService()).thenReturn(executorService); when(consumer.getExecutorService()).thenReturn(executorService);
when(consumer.getShardInfo()).thenReturn(shardInfo); when(consumer.getShardInfo()).thenReturn(shardInfo);
@ -153,68 +160,6 @@ public class ConsumerStatesTest {
assertThat(state.getTaskType(), equalTo(TaskType.INITIALIZE)); assertThat(state.getTaskType(), equalTo(TaskType.INITIALIZE));
} }
@Test
public void processingStateTestSynchronous() {
when(consumer.getMaxGetRecordsThreadPool()).thenReturn(Optional.empty());
when(consumer.getRetryGetRecordsInSeconds()).thenReturn(Optional.empty());
ConsumerState state = ShardConsumerState.PROCESSING.getConsumerState();
ITask task = state.createTask(consumer);
assertThat(task, procTask(ShardInfo.class, "shardInfo", equalTo(shardInfo)));
assertThat(task, procTask(IRecordProcessor.class, "recordProcessor", equalTo(recordProcessor)));
assertThat(task, procTask(RecordProcessorCheckpointer.class, "recordProcessorCheckpointer",
equalTo(recordProcessorCheckpointer)));
assertThat(task, procTask(KinesisDataFetcher.class, "dataFetcher", equalTo(dataFetcher)));
assertThat(task, procTask(StreamConfig.class, "streamConfig", equalTo(streamConfig)));
assertThat(task, procTask(Long.class, "backoffTimeMillis", equalTo(taskBackoffTimeMillis)));
assertThat(task, procTask(GetRecordsRetrievalStrategy.class, "getRecordsRetrievalStrategy", instanceOf(SynchronousGetRecordsRetrievalStrategy.class) ));
assertThat(state.successTransition(), equalTo(ShardConsumerState.PROCESSING.getConsumerState()));
assertThat(state.shutdownTransition(ShutdownReason.ZOMBIE),
equalTo(ShardConsumerState.SHUTTING_DOWN.getConsumerState()));
assertThat(state.shutdownTransition(ShutdownReason.TERMINATE),
equalTo(ShardConsumerState.SHUTTING_DOWN.getConsumerState()));
assertThat(state.shutdownTransition(ShutdownReason.REQUESTED),
equalTo(ShardConsumerState.SHUTDOWN_REQUESTED.getConsumerState()));
assertThat(state.getState(), equalTo(ShardConsumerState.PROCESSING));
assertThat(state.getTaskType(), equalTo(TaskType.PROCESS));
}
@Test
public void processingStateTestAsynchronous() {
when(consumer.getMaxGetRecordsThreadPool()).thenReturn(Optional.of(1));
when(consumer.getRetryGetRecordsInSeconds()).thenReturn(Optional.of(2));
ConsumerState state = ShardConsumerState.PROCESSING.getConsumerState();
ITask task = state.createTask(consumer);
assertThat(task, procTask(ShardInfo.class, "shardInfo", equalTo(shardInfo)));
assertThat(task, procTask(IRecordProcessor.class, "recordProcessor", equalTo(recordProcessor)));
assertThat(task, procTask(RecordProcessorCheckpointer.class, "recordProcessorCheckpointer",
equalTo(recordProcessorCheckpointer)));
assertThat(task, procTask(KinesisDataFetcher.class, "dataFetcher", equalTo(dataFetcher)));
assertThat(task, procTask(StreamConfig.class, "streamConfig", equalTo(streamConfig)));
assertThat(task, procTask(Long.class, "backoffTimeMillis", equalTo(taskBackoffTimeMillis)));
assertThat(task, procTask(GetRecordsRetrievalStrategy.class, "getRecordsRetrievalStrategy", instanceOf(AsynchronousGetRecordsRetrievalStrategy.class) ));
assertThat(state.successTransition(), equalTo(ShardConsumerState.PROCESSING.getConsumerState()));
assertThat(state.shutdownTransition(ShutdownReason.ZOMBIE),
equalTo(ShardConsumerState.SHUTTING_DOWN.getConsumerState()));
assertThat(state.shutdownTransition(ShutdownReason.TERMINATE),
equalTo(ShardConsumerState.SHUTTING_DOWN.getConsumerState()));
assertThat(state.shutdownTransition(ShutdownReason.REQUESTED),
equalTo(ShardConsumerState.SHUTDOWN_REQUESTED.getConsumerState()));
assertThat(state.getState(), equalTo(ShardConsumerState.PROCESSING));
assertThat(state.getTaskType(), equalTo(TaskType.PROCESS));
}
@Test @Test
public void shutdownRequestState() { public void shutdownRequestState() {
ConsumerState state = ShardConsumerState.SHUTDOWN_REQUESTED.getConsumerState(); ConsumerState state = ShardConsumerState.SHUTDOWN_REQUESTED.getConsumerState();

View file

@ -78,6 +78,10 @@ public class ProcessTaskTest {
private ThrottlingReporter throttlingReporter; private ThrottlingReporter throttlingReporter;
@Mock @Mock
private GetRecordsRetrievalStrategy mockGetRecordsRetrievalStrategy; private GetRecordsRetrievalStrategy mockGetRecordsRetrievalStrategy;
@Mock
private RecordsFetcherFactory mockRecordsFetcherFactory;
@Mock
private GetRecordsCache mockRecordsFetcher;
private List<Record> processedRecords; private List<Record> processedRecords;
private ExtendedSequenceNumber newLargestPermittedCheckpointValue; private ExtendedSequenceNumber newLargestPermittedCheckpointValue;
@ -94,8 +98,9 @@ public class ProcessTaskTest {
skipCheckpointValidationValue, skipCheckpointValidationValue,
INITIAL_POSITION_LATEST); INITIAL_POSITION_LATEST);
final ShardInfo shardInfo = new ShardInfo(shardId, null, null, null); final ShardInfo shardInfo = new ShardInfo(shardId, null, null, null);
when(mockRecordsFetcherFactory.createRecordsFetcher(mockGetRecordsRetrievalStrategy)).thenReturn(mockRecordsFetcher);
processTask = new ProcessTask( processTask = new ProcessTask(
shardInfo, config, mockRecordProcessor, mockCheckpointer, mockDataFetcher, taskBackoffTimeMillis, shardInfo, config, mockRecordProcessor, mockRecordsFetcherFactory, mockCheckpointer, mockDataFetcher, taskBackoffTimeMillis,
KinesisClientLibConfiguration.DEFAULT_SKIP_SHARD_SYNC_AT_STARTUP_IF_LEASES_EXIST, throttlingReporter, mockGetRecordsRetrievalStrategy); KinesisClientLibConfiguration.DEFAULT_SKIP_SHARD_SYNC_AT_STARTUP_IF_LEASES_EXIST, throttlingReporter, mockGetRecordsRetrievalStrategy);
} }
@ -103,13 +108,13 @@ public class ProcessTaskTest {
public void testProcessTaskWithProvisionedThroughputExceededException() { public void testProcessTaskWithProvisionedThroughputExceededException() {
// Set data fetcher to throw exception // Set data fetcher to throw exception
doReturn(false).when(mockDataFetcher).isShardEndReached(); doReturn(false).when(mockDataFetcher).isShardEndReached();
doThrow(new ProvisionedThroughputExceededException("Test Exception")).when(mockGetRecordsRetrievalStrategy) doThrow(new ProvisionedThroughputExceededException("Test Exception")).when(mockRecordsFetcher)
.getRecords(maxRecords); .getNextResult();
TaskResult result = processTask.call(); TaskResult result = processTask.call();
verify(throttlingReporter).throttled(); verify(throttlingReporter).throttled();
verify(throttlingReporter, never()).success(); verify(throttlingReporter, never()).success();
verify(mockGetRecordsRetrievalStrategy).getRecords(eq(maxRecords)); verify(mockRecordsFetcher).getNextResult();
assertTrue("Result should contain ProvisionedThroughputExceededException", assertTrue("Result should contain ProvisionedThroughputExceededException",
result.getException() instanceof ProvisionedThroughputExceededException); result.getException() instanceof ProvisionedThroughputExceededException);
} }
@ -117,10 +122,10 @@ public class ProcessTaskTest {
@Test @Test
public void testProcessTaskWithNonExistentStream() { public void testProcessTaskWithNonExistentStream() {
// Data fetcher returns a null Result when the stream does not exist // Data fetcher returns a null Result when the stream does not exist
doReturn(null).when(mockGetRecordsRetrievalStrategy).getRecords(maxRecords); doReturn(new GetRecordsResult().withRecords(Collections.emptyList())).when(mockRecordsFetcher).getNextResult();
TaskResult result = processTask.call(); TaskResult result = processTask.call();
verify(mockGetRecordsRetrievalStrategy).getRecords(eq(maxRecords)); verify(mockRecordsFetcher).getNextResult();
assertNull("Task should not throw an exception", result.getException()); assertNull("Task should not throw an exception", result.getException());
} }
@ -304,14 +309,14 @@ public class ProcessTaskTest {
private void testWithRecords(List<Record> records, private void testWithRecords(List<Record> records,
ExtendedSequenceNumber lastCheckpointValue, ExtendedSequenceNumber lastCheckpointValue,
ExtendedSequenceNumber largestPermittedCheckpointValue) { ExtendedSequenceNumber largestPermittedCheckpointValue) {
when(mockGetRecordsRetrievalStrategy.getRecords(anyInt())).thenReturn( when(mockRecordsFetcher.getNextResult()).thenReturn(
new GetRecordsResult().withRecords(records)); new GetRecordsResult().withRecords(records));
when(mockCheckpointer.getLastCheckpointValue()).thenReturn(lastCheckpointValue); when(mockCheckpointer.getLastCheckpointValue()).thenReturn(lastCheckpointValue);
when(mockCheckpointer.getLargestPermittedCheckpointValue()).thenReturn(largestPermittedCheckpointValue); when(mockCheckpointer.getLargestPermittedCheckpointValue()).thenReturn(largestPermittedCheckpointValue);
processTask.call(); processTask.call();
verify(throttlingReporter).success(); verify(throttlingReporter).success();
verify(throttlingReporter, never()).throttled(); verify(throttlingReporter, never()).throttled();
verify(mockGetRecordsRetrievalStrategy).getRecords(anyInt()); verify(mockRecordsFetcher).getNextResult();
ArgumentCaptor<ProcessRecordsInput> priCaptor = ArgumentCaptor.forClass(ProcessRecordsInput.class); ArgumentCaptor<ProcessRecordsInput> priCaptor = ArgumentCaptor.forClass(ProcessRecordsInput.class);
verify(mockRecordProcessor).processRecords(priCaptor.capture()); verify(mockRecordProcessor).processRecords(priCaptor.capture());
processedRecords = priCaptor.getValue().getRecords(); processedRecords = priCaptor.getValue().getRecords();