fixed tests
This commit is contained in:
parent
1c07b45166
commit
024d86da76
3 changed files with 25 additions and 77 deletions
|
|
@ -519,7 +519,7 @@ public class Worker implements Runnable {
|
|||
boolean foundCompletedShard = false;
|
||||
Set<ShardInfo> assignedShards = new HashSet<>();
|
||||
for (ShardInfo shardInfo : getShardInfoForAssignments()) {
|
||||
ShardConsumer shardConsumer = createOrGetShardConsumer(shardInfo, recordProcessorFactory, recordsFetcherFactory);
|
||||
ShardConsumer shardConsumer = createOrGetShardConsumer(shardInfo, recordProcessorFactory);
|
||||
if (shardConsumer.isShutdown() && shardConsumer.getShutdownReason().equals(ShutdownReason.TERMINATE)) {
|
||||
foundCompletedShard = true;
|
||||
} else {
|
||||
|
|
@ -891,11 +891,9 @@ public class Worker implements Runnable {
|
|||
* Kinesis shard info
|
||||
* @param processorFactory
|
||||
* RecordProcessor factory
|
||||
* @param fetcherFactory
|
||||
* RecordFetcher factory
|
||||
* @return ShardConsumer for the shard
|
||||
*/
|
||||
ShardConsumer createOrGetShardConsumer(ShardInfo shardInfo, IRecordProcessorFactory processorFactory, RecordsFetcherFactory fetcherFactory) {
|
||||
ShardConsumer createOrGetShardConsumer(ShardInfo shardInfo, IRecordProcessorFactory processorFactory) {
|
||||
ShardConsumer consumer = shardInfoShardConsumerMap.get(shardInfo);
|
||||
// Instantiate a new consumer if we don't have one, or the one we
|
||||
// had was from an earlier
|
||||
|
|
@ -904,17 +902,17 @@ public class Worker implements Runnable {
|
|||
// completely processed (shutdown reason terminate).
|
||||
if ((consumer == null)
|
||||
|| (consumer.isShutdown() && consumer.getShutdownReason().equals(ShutdownReason.ZOMBIE))) {
|
||||
consumer = buildConsumer(shardInfo, processorFactory, fetcherFactory);
|
||||
consumer = buildConsumer(shardInfo, processorFactory);
|
||||
shardInfoShardConsumerMap.put(shardInfo, consumer);
|
||||
wlog.infoForce("Created new shardConsumer for : " + shardInfo);
|
||||
}
|
||||
return consumer;
|
||||
}
|
||||
|
||||
protected ShardConsumer buildConsumer(ShardInfo shardInfo, IRecordProcessorFactory processorFactory, RecordsFetcherFactory fetcherFactory) {
|
||||
protected ShardConsumer buildConsumer(ShardInfo shardInfo, IRecordProcessorFactory processorFactory) {
|
||||
IRecordProcessor recordProcessor = processorFactory.createProcessor();
|
||||
|
||||
return new ShardConsumer(shardInfo, streamConfig, checkpointTracker, recordProcessor, fetcherFactory,
|
||||
return new ShardConsumer(shardInfo, streamConfig, checkpointTracker, recordProcessor, recordsFetcherFactory,
|
||||
leaseCoordinator.getLeaseManager(), parentShardPollIntervalMillis, cleanupLeasesUponShardCompletion,
|
||||
executorService, metricsFactory, taskBackoffTimeMillis,
|
||||
skipShardSyncAtWorkerInitializationIfLeasesExist, retryGetRecordsInSeconds, maxGetRecordsThreadPool);
|
||||
|
|
|
|||
|
|
@ -57,6 +57,8 @@ public class ConsumerStatesTest {
|
|||
@Mock
|
||||
private IRecordProcessor recordProcessor;
|
||||
@Mock
|
||||
private RecordsFetcherFactory recordsFetcherFactory;
|
||||
@Mock
|
||||
private RecordProcessorCheckpointer recordProcessorCheckpointer;
|
||||
@Mock
|
||||
private ExecutorService executorService;
|
||||
|
|
@ -76,6 +78,10 @@ public class ConsumerStatesTest {
|
|||
private IKinesisProxy kinesisProxy;
|
||||
@Mock
|
||||
private InitialPositionInStreamExtended initialPositionInStream;
|
||||
@Mock
|
||||
private SynchronousGetRecordsRetrievalStrategy getRecordsRetrievalStrategy;
|
||||
@Mock
|
||||
private GetRecordsCache recordsFetcher;
|
||||
|
||||
private long parentShardPollIntervalMillis = 0xCAFE;
|
||||
private boolean cleanupLeasesOfCompletedShards = true;
|
||||
|
|
@ -86,6 +92,7 @@ public class ConsumerStatesTest {
|
|||
public void setup() {
|
||||
when(consumer.getStreamConfig()).thenReturn(streamConfig);
|
||||
when(consumer.getRecordProcessor()).thenReturn(recordProcessor);
|
||||
when(consumer.getRecordsFetcherFactory()).thenReturn(recordsFetcherFactory);
|
||||
when(consumer.getRecordProcessorCheckpointer()).thenReturn(recordProcessorCheckpointer);
|
||||
when(consumer.getExecutorService()).thenReturn(executorService);
|
||||
when(consumer.getShardInfo()).thenReturn(shardInfo);
|
||||
|
|
@ -153,68 +160,6 @@ public class ConsumerStatesTest {
|
|||
assertThat(state.getTaskType(), equalTo(TaskType.INITIALIZE));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void processingStateTestSynchronous() {
|
||||
when(consumer.getMaxGetRecordsThreadPool()).thenReturn(Optional.empty());
|
||||
when(consumer.getRetryGetRecordsInSeconds()).thenReturn(Optional.empty());
|
||||
|
||||
ConsumerState state = ShardConsumerState.PROCESSING.getConsumerState();
|
||||
ITask task = state.createTask(consumer);
|
||||
|
||||
assertThat(task, procTask(ShardInfo.class, "shardInfo", equalTo(shardInfo)));
|
||||
assertThat(task, procTask(IRecordProcessor.class, "recordProcessor", equalTo(recordProcessor)));
|
||||
assertThat(task, procTask(RecordProcessorCheckpointer.class, "recordProcessorCheckpointer",
|
||||
equalTo(recordProcessorCheckpointer)));
|
||||
assertThat(task, procTask(KinesisDataFetcher.class, "dataFetcher", equalTo(dataFetcher)));
|
||||
assertThat(task, procTask(StreamConfig.class, "streamConfig", equalTo(streamConfig)));
|
||||
assertThat(task, procTask(Long.class, "backoffTimeMillis", equalTo(taskBackoffTimeMillis)));
|
||||
assertThat(task, procTask(GetRecordsRetrievalStrategy.class, "getRecordsRetrievalStrategy", instanceOf(SynchronousGetRecordsRetrievalStrategy.class) ));
|
||||
|
||||
assertThat(state.successTransition(), equalTo(ShardConsumerState.PROCESSING.getConsumerState()));
|
||||
|
||||
assertThat(state.shutdownTransition(ShutdownReason.ZOMBIE),
|
||||
equalTo(ShardConsumerState.SHUTTING_DOWN.getConsumerState()));
|
||||
assertThat(state.shutdownTransition(ShutdownReason.TERMINATE),
|
||||
equalTo(ShardConsumerState.SHUTTING_DOWN.getConsumerState()));
|
||||
assertThat(state.shutdownTransition(ShutdownReason.REQUESTED),
|
||||
equalTo(ShardConsumerState.SHUTDOWN_REQUESTED.getConsumerState()));
|
||||
|
||||
assertThat(state.getState(), equalTo(ShardConsumerState.PROCESSING));
|
||||
assertThat(state.getTaskType(), equalTo(TaskType.PROCESS));
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void processingStateTestAsynchronous() {
|
||||
when(consumer.getMaxGetRecordsThreadPool()).thenReturn(Optional.of(1));
|
||||
when(consumer.getRetryGetRecordsInSeconds()).thenReturn(Optional.of(2));
|
||||
|
||||
ConsumerState state = ShardConsumerState.PROCESSING.getConsumerState();
|
||||
ITask task = state.createTask(consumer);
|
||||
|
||||
assertThat(task, procTask(ShardInfo.class, "shardInfo", equalTo(shardInfo)));
|
||||
assertThat(task, procTask(IRecordProcessor.class, "recordProcessor", equalTo(recordProcessor)));
|
||||
assertThat(task, procTask(RecordProcessorCheckpointer.class, "recordProcessorCheckpointer",
|
||||
equalTo(recordProcessorCheckpointer)));
|
||||
assertThat(task, procTask(KinesisDataFetcher.class, "dataFetcher", equalTo(dataFetcher)));
|
||||
assertThat(task, procTask(StreamConfig.class, "streamConfig", equalTo(streamConfig)));
|
||||
assertThat(task, procTask(Long.class, "backoffTimeMillis", equalTo(taskBackoffTimeMillis)));
|
||||
assertThat(task, procTask(GetRecordsRetrievalStrategy.class, "getRecordsRetrievalStrategy", instanceOf(AsynchronousGetRecordsRetrievalStrategy.class) ));
|
||||
|
||||
assertThat(state.successTransition(), equalTo(ShardConsumerState.PROCESSING.getConsumerState()));
|
||||
|
||||
assertThat(state.shutdownTransition(ShutdownReason.ZOMBIE),
|
||||
equalTo(ShardConsumerState.SHUTTING_DOWN.getConsumerState()));
|
||||
assertThat(state.shutdownTransition(ShutdownReason.TERMINATE),
|
||||
equalTo(ShardConsumerState.SHUTTING_DOWN.getConsumerState()));
|
||||
assertThat(state.shutdownTransition(ShutdownReason.REQUESTED),
|
||||
equalTo(ShardConsumerState.SHUTDOWN_REQUESTED.getConsumerState()));
|
||||
|
||||
assertThat(state.getState(), equalTo(ShardConsumerState.PROCESSING));
|
||||
assertThat(state.getTaskType(), equalTo(TaskType.PROCESS));
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void shutdownRequestState() {
|
||||
ConsumerState state = ShardConsumerState.SHUTDOWN_REQUESTED.getConsumerState();
|
||||
|
|
|
|||
|
|
@ -78,6 +78,10 @@ public class ProcessTaskTest {
|
|||
private ThrottlingReporter throttlingReporter;
|
||||
@Mock
|
||||
private GetRecordsRetrievalStrategy mockGetRecordsRetrievalStrategy;
|
||||
@Mock
|
||||
private RecordsFetcherFactory mockRecordsFetcherFactory;
|
||||
@Mock
|
||||
private GetRecordsCache mockRecordsFetcher;
|
||||
|
||||
private List<Record> processedRecords;
|
||||
private ExtendedSequenceNumber newLargestPermittedCheckpointValue;
|
||||
|
|
@ -94,8 +98,9 @@ public class ProcessTaskTest {
|
|||
skipCheckpointValidationValue,
|
||||
INITIAL_POSITION_LATEST);
|
||||
final ShardInfo shardInfo = new ShardInfo(shardId, null, null, null);
|
||||
when(mockRecordsFetcherFactory.createRecordsFetcher(mockGetRecordsRetrievalStrategy)).thenReturn(mockRecordsFetcher);
|
||||
processTask = new ProcessTask(
|
||||
shardInfo, config, mockRecordProcessor, mockCheckpointer, mockDataFetcher, taskBackoffTimeMillis,
|
||||
shardInfo, config, mockRecordProcessor, mockRecordsFetcherFactory, mockCheckpointer, mockDataFetcher, taskBackoffTimeMillis,
|
||||
KinesisClientLibConfiguration.DEFAULT_SKIP_SHARD_SYNC_AT_STARTUP_IF_LEASES_EXIST, throttlingReporter, mockGetRecordsRetrievalStrategy);
|
||||
}
|
||||
|
||||
|
|
@ -103,13 +108,13 @@ public class ProcessTaskTest {
|
|||
public void testProcessTaskWithProvisionedThroughputExceededException() {
|
||||
// Set data fetcher to throw exception
|
||||
doReturn(false).when(mockDataFetcher).isShardEndReached();
|
||||
doThrow(new ProvisionedThroughputExceededException("Test Exception")).when(mockGetRecordsRetrievalStrategy)
|
||||
.getRecords(maxRecords);
|
||||
doThrow(new ProvisionedThroughputExceededException("Test Exception")).when(mockRecordsFetcher)
|
||||
.getNextResult();
|
||||
|
||||
TaskResult result = processTask.call();
|
||||
verify(throttlingReporter).throttled();
|
||||
verify(throttlingReporter, never()).success();
|
||||
verify(mockGetRecordsRetrievalStrategy).getRecords(eq(maxRecords));
|
||||
verify(mockRecordsFetcher).getNextResult();
|
||||
assertTrue("Result should contain ProvisionedThroughputExceededException",
|
||||
result.getException() instanceof ProvisionedThroughputExceededException);
|
||||
}
|
||||
|
|
@ -117,10 +122,10 @@ public class ProcessTaskTest {
|
|||
@Test
|
||||
public void testProcessTaskWithNonExistentStream() {
|
||||
// Data fetcher returns a null Result when the stream does not exist
|
||||
doReturn(null).when(mockGetRecordsRetrievalStrategy).getRecords(maxRecords);
|
||||
doReturn(new GetRecordsResult().withRecords(Collections.emptyList())).when(mockRecordsFetcher).getNextResult();
|
||||
|
||||
TaskResult result = processTask.call();
|
||||
verify(mockGetRecordsRetrievalStrategy).getRecords(eq(maxRecords));
|
||||
verify(mockRecordsFetcher).getNextResult();
|
||||
assertNull("Task should not throw an exception", result.getException());
|
||||
}
|
||||
|
||||
|
|
@ -304,14 +309,14 @@ public class ProcessTaskTest {
|
|||
private void testWithRecords(List<Record> records,
|
||||
ExtendedSequenceNumber lastCheckpointValue,
|
||||
ExtendedSequenceNumber largestPermittedCheckpointValue) {
|
||||
when(mockGetRecordsRetrievalStrategy.getRecords(anyInt())).thenReturn(
|
||||
when(mockRecordsFetcher.getNextResult()).thenReturn(
|
||||
new GetRecordsResult().withRecords(records));
|
||||
when(mockCheckpointer.getLastCheckpointValue()).thenReturn(lastCheckpointValue);
|
||||
when(mockCheckpointer.getLargestPermittedCheckpointValue()).thenReturn(largestPermittedCheckpointValue);
|
||||
processTask.call();
|
||||
verify(throttlingReporter).success();
|
||||
verify(throttlingReporter, never()).throttled();
|
||||
verify(mockGetRecordsRetrievalStrategy).getRecords(anyInt());
|
||||
verify(mockRecordsFetcher).getNextResult();
|
||||
ArgumentCaptor<ProcessRecordsInput> priCaptor = ArgumentCaptor.forClass(ProcessRecordsInput.class);
|
||||
verify(mockRecordProcessor).processRecords(priCaptor.capture());
|
||||
processedRecords = priCaptor.getValue().getRecords();
|
||||
|
|
|
|||
Loading…
Reference in a new issue