diff --git a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/Worker.java b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/Worker.java index 402c95c6..d4e2a898 100644 --- a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/Worker.java +++ b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/Worker.java @@ -519,7 +519,7 @@ public class Worker implements Runnable { boolean foundCompletedShard = false; Set assignedShards = new HashSet<>(); for (ShardInfo shardInfo : getShardInfoForAssignments()) { - ShardConsumer shardConsumer = createOrGetShardConsumer(shardInfo, recordProcessorFactory, recordsFetcherFactory); + ShardConsumer shardConsumer = createOrGetShardConsumer(shardInfo, recordProcessorFactory); if (shardConsumer.isShutdown() && shardConsumer.getShutdownReason().equals(ShutdownReason.TERMINATE)) { foundCompletedShard = true; } else { @@ -891,11 +891,9 @@ public class Worker implements Runnable { * Kinesis shard info * @param processorFactory * RecordProcessor factory - * @param fetcherFactory - * RecordFetcher factory * @return ShardConsumer for the shard */ - ShardConsumer createOrGetShardConsumer(ShardInfo shardInfo, IRecordProcessorFactory processorFactory, RecordsFetcherFactory fetcherFactory) { + ShardConsumer createOrGetShardConsumer(ShardInfo shardInfo, IRecordProcessorFactory processorFactory) { ShardConsumer consumer = shardInfoShardConsumerMap.get(shardInfo); // Instantiate a new consumer if we don't have one, or the one we // had was from an earlier @@ -904,17 +902,17 @@ public class Worker implements Runnable { // completely processed (shutdown reason terminate). if ((consumer == null) || (consumer.isShutdown() && consumer.getShutdownReason().equals(ShutdownReason.ZOMBIE))) { - consumer = buildConsumer(shardInfo, processorFactory, fetcherFactory); + consumer = buildConsumer(shardInfo, processorFactory); shardInfoShardConsumerMap.put(shardInfo, consumer); wlog.infoForce("Created new shardConsumer for : " + shardInfo); } return consumer; } - protected ShardConsumer buildConsumer(ShardInfo shardInfo, IRecordProcessorFactory processorFactory, RecordsFetcherFactory fetcherFactory) { + protected ShardConsumer buildConsumer(ShardInfo shardInfo, IRecordProcessorFactory processorFactory) { IRecordProcessor recordProcessor = processorFactory.createProcessor(); - return new ShardConsumer(shardInfo, streamConfig, checkpointTracker, recordProcessor, fetcherFactory, + return new ShardConsumer(shardInfo, streamConfig, checkpointTracker, recordProcessor, recordsFetcherFactory, leaseCoordinator.getLeaseManager(), parentShardPollIntervalMillis, cleanupLeasesUponShardCompletion, executorService, metricsFactory, taskBackoffTimeMillis, skipShardSyncAtWorkerInitializationIfLeasesExist, retryGetRecordsInSeconds, maxGetRecordsThreadPool); diff --git a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ConsumerStatesTest.java b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ConsumerStatesTest.java index 307aa6b8..89a582a4 100644 --- a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ConsumerStatesTest.java +++ b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ConsumerStatesTest.java @@ -57,6 +57,8 @@ public class ConsumerStatesTest { @Mock private IRecordProcessor recordProcessor; @Mock + private RecordsFetcherFactory recordsFetcherFactory; + @Mock private RecordProcessorCheckpointer recordProcessorCheckpointer; @Mock private ExecutorService executorService; @@ -76,6 +78,10 @@ public class ConsumerStatesTest { private IKinesisProxy kinesisProxy; @Mock private InitialPositionInStreamExtended initialPositionInStream; + @Mock + private SynchronousGetRecordsRetrievalStrategy getRecordsRetrievalStrategy; + @Mock + private GetRecordsCache recordsFetcher; private long parentShardPollIntervalMillis = 0xCAFE; private boolean cleanupLeasesOfCompletedShards = true; @@ -86,6 +92,7 @@ public class ConsumerStatesTest { public void setup() { when(consumer.getStreamConfig()).thenReturn(streamConfig); when(consumer.getRecordProcessor()).thenReturn(recordProcessor); + when(consumer.getRecordsFetcherFactory()).thenReturn(recordsFetcherFactory); when(consumer.getRecordProcessorCheckpointer()).thenReturn(recordProcessorCheckpointer); when(consumer.getExecutorService()).thenReturn(executorService); when(consumer.getShardInfo()).thenReturn(shardInfo); @@ -153,68 +160,6 @@ public class ConsumerStatesTest { assertThat(state.getTaskType(), equalTo(TaskType.INITIALIZE)); } - @Test - public void processingStateTestSynchronous() { - when(consumer.getMaxGetRecordsThreadPool()).thenReturn(Optional.empty()); - when(consumer.getRetryGetRecordsInSeconds()).thenReturn(Optional.empty()); - - ConsumerState state = ShardConsumerState.PROCESSING.getConsumerState(); - ITask task = state.createTask(consumer); - - assertThat(task, procTask(ShardInfo.class, "shardInfo", equalTo(shardInfo))); - assertThat(task, procTask(IRecordProcessor.class, "recordProcessor", equalTo(recordProcessor))); - assertThat(task, procTask(RecordProcessorCheckpointer.class, "recordProcessorCheckpointer", - equalTo(recordProcessorCheckpointer))); - assertThat(task, procTask(KinesisDataFetcher.class, "dataFetcher", equalTo(dataFetcher))); - assertThat(task, procTask(StreamConfig.class, "streamConfig", equalTo(streamConfig))); - assertThat(task, procTask(Long.class, "backoffTimeMillis", equalTo(taskBackoffTimeMillis))); - assertThat(task, procTask(GetRecordsRetrievalStrategy.class, "getRecordsRetrievalStrategy", instanceOf(SynchronousGetRecordsRetrievalStrategy.class) )); - - assertThat(state.successTransition(), equalTo(ShardConsumerState.PROCESSING.getConsumerState())); - - assertThat(state.shutdownTransition(ShutdownReason.ZOMBIE), - equalTo(ShardConsumerState.SHUTTING_DOWN.getConsumerState())); - assertThat(state.shutdownTransition(ShutdownReason.TERMINATE), - equalTo(ShardConsumerState.SHUTTING_DOWN.getConsumerState())); - assertThat(state.shutdownTransition(ShutdownReason.REQUESTED), - equalTo(ShardConsumerState.SHUTDOWN_REQUESTED.getConsumerState())); - - assertThat(state.getState(), equalTo(ShardConsumerState.PROCESSING)); - assertThat(state.getTaskType(), equalTo(TaskType.PROCESS)); - - } - - @Test - public void processingStateTestAsynchronous() { - when(consumer.getMaxGetRecordsThreadPool()).thenReturn(Optional.of(1)); - when(consumer.getRetryGetRecordsInSeconds()).thenReturn(Optional.of(2)); - - ConsumerState state = ShardConsumerState.PROCESSING.getConsumerState(); - ITask task = state.createTask(consumer); - - assertThat(task, procTask(ShardInfo.class, "shardInfo", equalTo(shardInfo))); - assertThat(task, procTask(IRecordProcessor.class, "recordProcessor", equalTo(recordProcessor))); - assertThat(task, procTask(RecordProcessorCheckpointer.class, "recordProcessorCheckpointer", - equalTo(recordProcessorCheckpointer))); - assertThat(task, procTask(KinesisDataFetcher.class, "dataFetcher", equalTo(dataFetcher))); - assertThat(task, procTask(StreamConfig.class, "streamConfig", equalTo(streamConfig))); - assertThat(task, procTask(Long.class, "backoffTimeMillis", equalTo(taskBackoffTimeMillis))); - assertThat(task, procTask(GetRecordsRetrievalStrategy.class, "getRecordsRetrievalStrategy", instanceOf(AsynchronousGetRecordsRetrievalStrategy.class) )); - - assertThat(state.successTransition(), equalTo(ShardConsumerState.PROCESSING.getConsumerState())); - - assertThat(state.shutdownTransition(ShutdownReason.ZOMBIE), - equalTo(ShardConsumerState.SHUTTING_DOWN.getConsumerState())); - assertThat(state.shutdownTransition(ShutdownReason.TERMINATE), - equalTo(ShardConsumerState.SHUTTING_DOWN.getConsumerState())); - assertThat(state.shutdownTransition(ShutdownReason.REQUESTED), - equalTo(ShardConsumerState.SHUTDOWN_REQUESTED.getConsumerState())); - - assertThat(state.getState(), equalTo(ShardConsumerState.PROCESSING)); - assertThat(state.getTaskType(), equalTo(TaskType.PROCESS)); - - } - @Test public void shutdownRequestState() { ConsumerState state = ShardConsumerState.SHUTDOWN_REQUESTED.getConsumerState(); diff --git a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ProcessTaskTest.java b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ProcessTaskTest.java index 0c47e9b9..892d1483 100644 --- a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ProcessTaskTest.java +++ b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ProcessTaskTest.java @@ -78,6 +78,10 @@ public class ProcessTaskTest { private ThrottlingReporter throttlingReporter; @Mock private GetRecordsRetrievalStrategy mockGetRecordsRetrievalStrategy; + @Mock + private RecordsFetcherFactory mockRecordsFetcherFactory; + @Mock + private GetRecordsCache mockRecordsFetcher; private List processedRecords; private ExtendedSequenceNumber newLargestPermittedCheckpointValue; @@ -94,8 +98,9 @@ public class ProcessTaskTest { skipCheckpointValidationValue, INITIAL_POSITION_LATEST); final ShardInfo shardInfo = new ShardInfo(shardId, null, null, null); + when(mockRecordsFetcherFactory.createRecordsFetcher(mockGetRecordsRetrievalStrategy)).thenReturn(mockRecordsFetcher); processTask = new ProcessTask( - shardInfo, config, mockRecordProcessor, mockCheckpointer, mockDataFetcher, taskBackoffTimeMillis, + shardInfo, config, mockRecordProcessor, mockRecordsFetcherFactory, mockCheckpointer, mockDataFetcher, taskBackoffTimeMillis, KinesisClientLibConfiguration.DEFAULT_SKIP_SHARD_SYNC_AT_STARTUP_IF_LEASES_EXIST, throttlingReporter, mockGetRecordsRetrievalStrategy); } @@ -103,13 +108,13 @@ public class ProcessTaskTest { public void testProcessTaskWithProvisionedThroughputExceededException() { // Set data fetcher to throw exception doReturn(false).when(mockDataFetcher).isShardEndReached(); - doThrow(new ProvisionedThroughputExceededException("Test Exception")).when(mockGetRecordsRetrievalStrategy) - .getRecords(maxRecords); + doThrow(new ProvisionedThroughputExceededException("Test Exception")).when(mockRecordsFetcher) + .getNextResult(); TaskResult result = processTask.call(); verify(throttlingReporter).throttled(); verify(throttlingReporter, never()).success(); - verify(mockGetRecordsRetrievalStrategy).getRecords(eq(maxRecords)); + verify(mockRecordsFetcher).getNextResult(); assertTrue("Result should contain ProvisionedThroughputExceededException", result.getException() instanceof ProvisionedThroughputExceededException); } @@ -117,10 +122,10 @@ public class ProcessTaskTest { @Test public void testProcessTaskWithNonExistentStream() { // Data fetcher returns a null Result when the stream does not exist - doReturn(null).when(mockGetRecordsRetrievalStrategy).getRecords(maxRecords); + doReturn(new GetRecordsResult().withRecords(Collections.emptyList())).when(mockRecordsFetcher).getNextResult(); TaskResult result = processTask.call(); - verify(mockGetRecordsRetrievalStrategy).getRecords(eq(maxRecords)); + verify(mockRecordsFetcher).getNextResult(); assertNull("Task should not throw an exception", result.getException()); } @@ -304,14 +309,14 @@ public class ProcessTaskTest { private void testWithRecords(List records, ExtendedSequenceNumber lastCheckpointValue, ExtendedSequenceNumber largestPermittedCheckpointValue) { - when(mockGetRecordsRetrievalStrategy.getRecords(anyInt())).thenReturn( + when(mockRecordsFetcher.getNextResult()).thenReturn( new GetRecordsResult().withRecords(records)); when(mockCheckpointer.getLastCheckpointValue()).thenReturn(lastCheckpointValue); when(mockCheckpointer.getLargestPermittedCheckpointValue()).thenReturn(largestPermittedCheckpointValue); processTask.call(); verify(throttlingReporter).success(); verify(throttlingReporter, never()).throttled(); - verify(mockGetRecordsRetrievalStrategy).getRecords(anyInt()); + verify(mockRecordsFetcher).getNextResult(); ArgumentCaptor priCaptor = ArgumentCaptor.forClass(ProcessRecordsInput.class); verify(mockRecordProcessor).processRecords(priCaptor.capture()); processedRecords = priCaptor.getValue().getRecords();