diff --git a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShardConsumerTest.java b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShardConsumerTest.java index 2fc19f8f..3dd49fcf 100644 --- a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShardConsumerTest.java +++ b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShardConsumerTest.java @@ -105,6 +105,8 @@ public class ShardConsumerTest { private GetRecordsCache getRecordsCache; + private KinesisDataFetcher dataFetcher; + @Mock private IRecordProcessor processor; @Mock @@ -121,6 +123,7 @@ public class ShardConsumerTest { @Before public void setup() { getRecordsCache = null; + dataFetcher = null; recordsFetcherFactory = spy(new SimpleRecordsFetcherFactory(maxRecords)); when(config.getRecordsFetcherFactory()).thenReturn(recordsFetcherFactory); @@ -342,7 +345,7 @@ public class ShardConsumerTest { ) ); - KinesisDataFetcher dataFetcher = new KinesisDataFetcher(streamConfig.getStreamProxy(), shardInfo); + dataFetcher = new KinesisDataFetcher(streamConfig.getStreamProxy(), shardInfo); getRecordsCache = spy(new BlockingGetRecordsCache(maxRecords, new SynchronousGetRecordsRetrievalStrategy(dataFetcher))); @@ -474,18 +477,42 @@ public class ShardConsumerTest { skipCheckpointValidationValue, INITIAL_POSITION_LATEST); ShardInfo shardInfo = new ShardInfo(streamShardId, testConcurrencyToken, null, null); + + dataFetcher = new KinesisDataFetcher(streamConfig.getStreamProxy(), shardInfo); + + getRecordsCache = spy(new BlockingGetRecordsCache(maxRecords, + new SynchronousGetRecordsRetrievalStrategy(dataFetcher))); + when(recordsFetcherFactory.createRecordsFetcher(any(GetRecordsRetrievalStrategy.class), anyString(), + any(IMetricsFactory.class))) + .thenReturn(getRecordsCache); + + RecordProcessorCheckpointer recordProcessorCheckpointer = new RecordProcessorCheckpointer( + shardInfo, + checkpoint, + new SequenceNumberValidator( + streamConfig.getStreamProxy(), + shardInfo.getShardId(), + streamConfig.shouldValidateSequenceNumberBeforeCheckpointing() + ) + ); + ShardConsumer consumer = new ShardConsumer(shardInfo, streamConfig, checkpoint, processor, + recordProcessorCheckpointer, leaseManager, parentShardPollIntervalMillis, cleanupLeasesOfCompletedShards, executorService, metricsFactory, taskBackoffTimeMillis, - KinesisClientLibConfiguration.DEFAULT_SKIP_SHARD_SYNC_AT_STARTUP_IF_LEASES_EXIST); + KinesisClientLibConfiguration.DEFAULT_SKIP_SHARD_SYNC_AT_STARTUP_IF_LEASES_EXIST, + dataFetcher, + Optional.empty(), + Optional.empty(), + config); assertThat(consumer.getCurrentState(), is(equalTo(ConsumerStates.ShardConsumerState.WAITING_ON_PARENT_SHARDS))); consumer.consumeShard(); // check on parent shards @@ -494,12 +521,13 @@ public class ShardConsumerTest { assertThat(consumer.getCurrentState(), is(equalTo(ConsumerStates.ShardConsumerState.INITIALIZING))); consumer.consumeShard(); // initialize processor.getInitializeLatch().await(5, TimeUnit.SECONDS); + verify(getRecordsCache).start(); // We expect to process all records in numRecs calls for (int i = 0; i < numRecs;) { boolean newTaskSubmitted = consumer.consumeShard(); if (newTaskSubmitted) { - LOG.debug("New processing task was submitted, call # " + i); + LOG.info("New processing task was submitted, call # " + i); assertThat(consumer.getCurrentState(), is(equalTo(ConsumerStates.ShardConsumerState.PROCESSING))); // CHECKSTYLE:IGNORE ModifiedControlVariable FOR NEXT 1 LINES i += maxRecords; @@ -537,6 +565,8 @@ public class ShardConsumerTest { assertThat(processor.getShutdownReason(), is(equalTo(ShutdownReason.TERMINATE))); + verify(getRecordsCache).shutdown(); + executorService.shutdown(); executorService.awaitTermination(60, TimeUnit.SECONDS); @@ -594,7 +624,7 @@ public class ShardConsumerTest { ) ); - KinesisDataFetcher dataFetcher = new KinesisDataFetcher(streamConfig.getStreamProxy(), shardInfo); + dataFetcher = new KinesisDataFetcher(streamConfig.getStreamProxy(), shardInfo); getRecordsCache = spy(new BlockingGetRecordsCache(maxRecords, new SynchronousGetRecordsRetrievalStrategy(dataFetcher)));