fixed tests

This commit is contained in:
Wei 2017-09-20 14:25:23 -07:00
parent 1c07b45166
commit 024d86da76
3 changed files with 25 additions and 77 deletions

View file

@ -519,7 +519,7 @@ public class Worker implements Runnable {
boolean foundCompletedShard = false;
Set<ShardInfo> assignedShards = new HashSet<>();
for (ShardInfo shardInfo : getShardInfoForAssignments()) {
ShardConsumer shardConsumer = createOrGetShardConsumer(shardInfo, recordProcessorFactory, recordsFetcherFactory);
ShardConsumer shardConsumer = createOrGetShardConsumer(shardInfo, recordProcessorFactory);
if (shardConsumer.isShutdown() && shardConsumer.getShutdownReason().equals(ShutdownReason.TERMINATE)) {
foundCompletedShard = true;
} else {
@ -891,11 +891,9 @@ public class Worker implements Runnable {
* Kinesis shard info
* @param processorFactory
* RecordProcessor factory
* @param fetcherFactory
* RecordFetcher factory
* @return ShardConsumer for the shard
*/
ShardConsumer createOrGetShardConsumer(ShardInfo shardInfo, IRecordProcessorFactory processorFactory, RecordsFetcherFactory fetcherFactory) {
ShardConsumer createOrGetShardConsumer(ShardInfo shardInfo, IRecordProcessorFactory processorFactory) {
ShardConsumer consumer = shardInfoShardConsumerMap.get(shardInfo);
// Instantiate a new consumer if we don't have one, or the one we
// had was from an earlier
@ -904,17 +902,17 @@ public class Worker implements Runnable {
// completely processed (shutdown reason terminate).
if ((consumer == null)
|| (consumer.isShutdown() && consumer.getShutdownReason().equals(ShutdownReason.ZOMBIE))) {
consumer = buildConsumer(shardInfo, processorFactory, fetcherFactory);
consumer = buildConsumer(shardInfo, processorFactory);
shardInfoShardConsumerMap.put(shardInfo, consumer);
wlog.infoForce("Created new shardConsumer for : " + shardInfo);
}
return consumer;
}
protected ShardConsumer buildConsumer(ShardInfo shardInfo, IRecordProcessorFactory processorFactory, RecordsFetcherFactory fetcherFactory) {
protected ShardConsumer buildConsumer(ShardInfo shardInfo, IRecordProcessorFactory processorFactory) {
IRecordProcessor recordProcessor = processorFactory.createProcessor();
return new ShardConsumer(shardInfo, streamConfig, checkpointTracker, recordProcessor, fetcherFactory,
return new ShardConsumer(shardInfo, streamConfig, checkpointTracker, recordProcessor, recordsFetcherFactory,
leaseCoordinator.getLeaseManager(), parentShardPollIntervalMillis, cleanupLeasesUponShardCompletion,
executorService, metricsFactory, taskBackoffTimeMillis,
skipShardSyncAtWorkerInitializationIfLeasesExist, retryGetRecordsInSeconds, maxGetRecordsThreadPool);

View file

@ -57,6 +57,8 @@ public class ConsumerStatesTest {
@Mock
private IRecordProcessor recordProcessor;
@Mock
private RecordsFetcherFactory recordsFetcherFactory;
@Mock
private RecordProcessorCheckpointer recordProcessorCheckpointer;
@Mock
private ExecutorService executorService;
@ -76,6 +78,10 @@ public class ConsumerStatesTest {
private IKinesisProxy kinesisProxy;
@Mock
private InitialPositionInStreamExtended initialPositionInStream;
@Mock
private SynchronousGetRecordsRetrievalStrategy getRecordsRetrievalStrategy;
@Mock
private GetRecordsCache recordsFetcher;
private long parentShardPollIntervalMillis = 0xCAFE;
private boolean cleanupLeasesOfCompletedShards = true;
@ -86,6 +92,7 @@ public class ConsumerStatesTest {
public void setup() {
when(consumer.getStreamConfig()).thenReturn(streamConfig);
when(consumer.getRecordProcessor()).thenReturn(recordProcessor);
when(consumer.getRecordsFetcherFactory()).thenReturn(recordsFetcherFactory);
when(consumer.getRecordProcessorCheckpointer()).thenReturn(recordProcessorCheckpointer);
when(consumer.getExecutorService()).thenReturn(executorService);
when(consumer.getShardInfo()).thenReturn(shardInfo);
@ -153,68 +160,6 @@ public class ConsumerStatesTest {
assertThat(state.getTaskType(), equalTo(TaskType.INITIALIZE));
}
@Test
public void processingStateTestSynchronous() {
when(consumer.getMaxGetRecordsThreadPool()).thenReturn(Optional.empty());
when(consumer.getRetryGetRecordsInSeconds()).thenReturn(Optional.empty());
ConsumerState state = ShardConsumerState.PROCESSING.getConsumerState();
ITask task = state.createTask(consumer);
assertThat(task, procTask(ShardInfo.class, "shardInfo", equalTo(shardInfo)));
assertThat(task, procTask(IRecordProcessor.class, "recordProcessor", equalTo(recordProcessor)));
assertThat(task, procTask(RecordProcessorCheckpointer.class, "recordProcessorCheckpointer",
equalTo(recordProcessorCheckpointer)));
assertThat(task, procTask(KinesisDataFetcher.class, "dataFetcher", equalTo(dataFetcher)));
assertThat(task, procTask(StreamConfig.class, "streamConfig", equalTo(streamConfig)));
assertThat(task, procTask(Long.class, "backoffTimeMillis", equalTo(taskBackoffTimeMillis)));
assertThat(task, procTask(GetRecordsRetrievalStrategy.class, "getRecordsRetrievalStrategy", instanceOf(SynchronousGetRecordsRetrievalStrategy.class) ));
assertThat(state.successTransition(), equalTo(ShardConsumerState.PROCESSING.getConsumerState()));
assertThat(state.shutdownTransition(ShutdownReason.ZOMBIE),
equalTo(ShardConsumerState.SHUTTING_DOWN.getConsumerState()));
assertThat(state.shutdownTransition(ShutdownReason.TERMINATE),
equalTo(ShardConsumerState.SHUTTING_DOWN.getConsumerState()));
assertThat(state.shutdownTransition(ShutdownReason.REQUESTED),
equalTo(ShardConsumerState.SHUTDOWN_REQUESTED.getConsumerState()));
assertThat(state.getState(), equalTo(ShardConsumerState.PROCESSING));
assertThat(state.getTaskType(), equalTo(TaskType.PROCESS));
}
@Test
public void processingStateTestAsynchronous() {
when(consumer.getMaxGetRecordsThreadPool()).thenReturn(Optional.of(1));
when(consumer.getRetryGetRecordsInSeconds()).thenReturn(Optional.of(2));
ConsumerState state = ShardConsumerState.PROCESSING.getConsumerState();
ITask task = state.createTask(consumer);
assertThat(task, procTask(ShardInfo.class, "shardInfo", equalTo(shardInfo)));
assertThat(task, procTask(IRecordProcessor.class, "recordProcessor", equalTo(recordProcessor)));
assertThat(task, procTask(RecordProcessorCheckpointer.class, "recordProcessorCheckpointer",
equalTo(recordProcessorCheckpointer)));
assertThat(task, procTask(KinesisDataFetcher.class, "dataFetcher", equalTo(dataFetcher)));
assertThat(task, procTask(StreamConfig.class, "streamConfig", equalTo(streamConfig)));
assertThat(task, procTask(Long.class, "backoffTimeMillis", equalTo(taskBackoffTimeMillis)));
assertThat(task, procTask(GetRecordsRetrievalStrategy.class, "getRecordsRetrievalStrategy", instanceOf(AsynchronousGetRecordsRetrievalStrategy.class) ));
assertThat(state.successTransition(), equalTo(ShardConsumerState.PROCESSING.getConsumerState()));
assertThat(state.shutdownTransition(ShutdownReason.ZOMBIE),
equalTo(ShardConsumerState.SHUTTING_DOWN.getConsumerState()));
assertThat(state.shutdownTransition(ShutdownReason.TERMINATE),
equalTo(ShardConsumerState.SHUTTING_DOWN.getConsumerState()));
assertThat(state.shutdownTransition(ShutdownReason.REQUESTED),
equalTo(ShardConsumerState.SHUTDOWN_REQUESTED.getConsumerState()));
assertThat(state.getState(), equalTo(ShardConsumerState.PROCESSING));
assertThat(state.getTaskType(), equalTo(TaskType.PROCESS));
}
@Test
public void shutdownRequestState() {
ConsumerState state = ShardConsumerState.SHUTDOWN_REQUESTED.getConsumerState();

View file

@ -78,6 +78,10 @@ public class ProcessTaskTest {
private ThrottlingReporter throttlingReporter;
@Mock
private GetRecordsRetrievalStrategy mockGetRecordsRetrievalStrategy;
@Mock
private RecordsFetcherFactory mockRecordsFetcherFactory;
@Mock
private GetRecordsCache mockRecordsFetcher;
private List<Record> processedRecords;
private ExtendedSequenceNumber newLargestPermittedCheckpointValue;
@ -94,8 +98,9 @@ public class ProcessTaskTest {
skipCheckpointValidationValue,
INITIAL_POSITION_LATEST);
final ShardInfo shardInfo = new ShardInfo(shardId, null, null, null);
when(mockRecordsFetcherFactory.createRecordsFetcher(mockGetRecordsRetrievalStrategy)).thenReturn(mockRecordsFetcher);
processTask = new ProcessTask(
shardInfo, config, mockRecordProcessor, mockCheckpointer, mockDataFetcher, taskBackoffTimeMillis,
shardInfo, config, mockRecordProcessor, mockRecordsFetcherFactory, mockCheckpointer, mockDataFetcher, taskBackoffTimeMillis,
KinesisClientLibConfiguration.DEFAULT_SKIP_SHARD_SYNC_AT_STARTUP_IF_LEASES_EXIST, throttlingReporter, mockGetRecordsRetrievalStrategy);
}
@ -103,13 +108,13 @@ public class ProcessTaskTest {
public void testProcessTaskWithProvisionedThroughputExceededException() {
// Set data fetcher to throw exception
doReturn(false).when(mockDataFetcher).isShardEndReached();
doThrow(new ProvisionedThroughputExceededException("Test Exception")).when(mockGetRecordsRetrievalStrategy)
.getRecords(maxRecords);
doThrow(new ProvisionedThroughputExceededException("Test Exception")).when(mockRecordsFetcher)
.getNextResult();
TaskResult result = processTask.call();
verify(throttlingReporter).throttled();
verify(throttlingReporter, never()).success();
verify(mockGetRecordsRetrievalStrategy).getRecords(eq(maxRecords));
verify(mockRecordsFetcher).getNextResult();
assertTrue("Result should contain ProvisionedThroughputExceededException",
result.getException() instanceof ProvisionedThroughputExceededException);
}
@ -117,10 +122,10 @@ public class ProcessTaskTest {
@Test
public void testProcessTaskWithNonExistentStream() {
// Data fetcher returns a null Result when the stream does not exist
doReturn(null).when(mockGetRecordsRetrievalStrategy).getRecords(maxRecords);
doReturn(new GetRecordsResult().withRecords(Collections.emptyList())).when(mockRecordsFetcher).getNextResult();
TaskResult result = processTask.call();
verify(mockGetRecordsRetrievalStrategy).getRecords(eq(maxRecords));
verify(mockRecordsFetcher).getNextResult();
assertNull("Task should not throw an exception", result.getException());
}
@ -304,14 +309,14 @@ public class ProcessTaskTest {
private void testWithRecords(List<Record> records,
ExtendedSequenceNumber lastCheckpointValue,
ExtendedSequenceNumber largestPermittedCheckpointValue) {
when(mockGetRecordsRetrievalStrategy.getRecords(anyInt())).thenReturn(
when(mockRecordsFetcher.getNextResult()).thenReturn(
new GetRecordsResult().withRecords(records));
when(mockCheckpointer.getLastCheckpointValue()).thenReturn(lastCheckpointValue);
when(mockCheckpointer.getLargestPermittedCheckpointValue()).thenReturn(largestPermittedCheckpointValue);
processTask.call();
verify(throttlingReporter).success();
verify(throttlingReporter, never()).throttled();
verify(mockGetRecordsRetrievalStrategy).getRecords(anyInt());
verify(mockRecordsFetcher).getNextResult();
ArgumentCaptor<ProcessRecordsInput> priCaptor = ArgumentCaptor.forClass(ProcessRecordsInput.class);
verify(mockRecordProcessor).processRecords(priCaptor.capture());
processedRecords = priCaptor.getValue().getRecords();