From 5e7d4788ecd6aaa7dda634ebd3169a605b7032eb Mon Sep 17 00:00:00 2001 From: stair <123031771+stair-aws@users.noreply.github.com> Date: Tue, 18 Apr 2023 14:58:27 -0400 Subject: [PATCH] Code cleanup to introduce better testing and simplify future removal of (#1094) deprecated parameters (e.g., `Either appStreamTracker`). --- .../amazon/kinesis/common/StreamConfig.java | 2 + .../kinesis/processor/ProcessorConfig.java | 6 +- .../kinesis/retrieval/RetrievalConfig.java | 58 +++--- .../retrieval/RetrievalSpecificConfig.java | 22 ++- .../retrieval/fanout/FanOutConfig.java | 13 +- .../retrieval/polling/PollingConfig.java | 10 + .../kinesis/common/ConfigsBuilderTest.java | 11 +- .../kinesis/common/StreamConfigTest.java | 14 ++ .../kinesis/lifecycle/ShardConsumerTest.java | 119 ++++++------ .../retrieval/RetrievalConfigTest.java | 39 +++- .../retrieval/fanout/FanOutConfigTest.java | 182 ++++++++++-------- .../retrieval/polling/PollingConfigTest.java | 47 +++++ 12 files changed, 324 insertions(+), 199 deletions(-) create mode 100644 amazon-kinesis-client/src/test/java/software/amazon/kinesis/common/StreamConfigTest.java create mode 100644 amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/polling/PollingConfigTest.java diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/common/StreamConfig.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/common/StreamConfig.java index b1057f13..8ca75dec 100644 --- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/common/StreamConfig.java +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/common/StreamConfig.java @@ -16,11 +16,13 @@ package software.amazon.kinesis.common; import lombok.Data; +import lombok.NonNull; import lombok.experimental.Accessors; @Data @Accessors(fluent = true) public class StreamConfig { + @NonNull private final StreamIdentifier streamIdentifier; private final InitialPositionInStreamExtended initialPositionInStreamExtended; private String consumerArn; diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/processor/ProcessorConfig.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/processor/ProcessorConfig.java index 04ea6614..7641bc44 100644 --- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/processor/ProcessorConfig.java +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/processor/ProcessorConfig.java @@ -15,9 +15,9 @@ package software.amazon.kinesis.processor; - import lombok.Data; - import lombok.NonNull; - import lombok.experimental.Accessors; +import lombok.Data; +import lombok.NonNull; +import lombok.experimental.Accessors; /** * Used by the KCL to configure the processor for processing the records. diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/RetrievalConfig.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/RetrievalConfig.java index 3f001057..8ada4970 100644 --- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/RetrievalConfig.java +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/RetrievalConfig.java @@ -133,6 +133,8 @@ public class RetrievalConfig { } /** + * Convenience method to reconfigure the embedded {@link StreamTracker}, + * but only when not in multi-stream mode. * * @param initialPositionInStreamExtended * @@ -142,62 +144,46 @@ public class RetrievalConfig { */ @Deprecated public RetrievalConfig initialPositionInStreamExtended(InitialPositionInStreamExtended initialPositionInStreamExtended) { - this.appStreamTracker.apply(multiStreamTracker -> { + if (streamTracker().isMultiStream()) { throw new IllegalArgumentException( "Cannot set initialPositionInStreamExtended when multiStreamTracker is set"); - }, sc -> { - final StreamConfig updatedConfig = new StreamConfig(sc.streamIdentifier(), initialPositionInStreamExtended); - streamTracker = new SingleStreamTracker(sc.streamIdentifier(), updatedConfig); - appStreamTracker = Either.right(updatedConfig); - }); + }; + + final StreamIdentifier streamIdentifier = getSingleStreamIdentifier(); + final StreamConfig updatedConfig = new StreamConfig(streamIdentifier, initialPositionInStreamExtended); + streamTracker = new SingleStreamTracker(streamIdentifier, updatedConfig); + appStreamTracker = Either.right(updatedConfig); return this; } public RetrievalConfig retrievalSpecificConfig(RetrievalSpecificConfig retrievalSpecificConfig) { + retrievalSpecificConfig.validateState(streamTracker.isMultiStream()); this.retrievalSpecificConfig = retrievalSpecificConfig; - validateFanoutConfig(); - validatePollingConfig(); return this; } public RetrievalFactory retrievalFactory() { if (retrievalFactory == null) { if (retrievalSpecificConfig == null) { - retrievalSpecificConfig = new FanOutConfig(kinesisClient()) + final FanOutConfig fanOutConfig = new FanOutConfig(kinesisClient()) .applicationName(applicationName()); - retrievalSpecificConfig = appStreamTracker.map(multiStreamTracker -> retrievalSpecificConfig, - streamConfig -> ((FanOutConfig) retrievalSpecificConfig).streamName(streamConfig.streamIdentifier().streamName())); + if (!streamTracker.isMultiStream()) { + final String streamName = getSingleStreamIdentifier().streamName(); + fanOutConfig.streamName(streamName); + } + retrievalSpecificConfig(fanOutConfig); } retrievalFactory = retrievalSpecificConfig.retrievalFactory(); } return retrievalFactory; } - private void validateFanoutConfig() { - // If we are in multistream mode and if retrievalSpecificConfig is an instance of FanOutConfig and if consumerArn is set throw exception. - boolean isFanoutConfig = retrievalSpecificConfig instanceof FanOutConfig; - boolean isInvalidFanoutConfig = isFanoutConfig && appStreamTracker.map( - multiStreamTracker -> ((FanOutConfig) retrievalSpecificConfig).consumerArn() != null - || ((FanOutConfig) retrievalSpecificConfig).streamName() != null, - streamConfig -> streamConfig.streamIdentifier() == null - || streamConfig.streamIdentifier().streamName() == null); - if(isInvalidFanoutConfig) { - throw new IllegalArgumentException( - "Invalid config: Either in multi-stream mode with streamName/consumerArn configured or in single-stream mode with no streamName configured"); - } + /** + * Convenience method to return the {@link StreamIdentifier} from a + * single-stream tracker. + */ + private StreamIdentifier getSingleStreamIdentifier() { + return streamTracker.streamConfigList().get(0).streamIdentifier(); } - private void validatePollingConfig() { - boolean isPollingConfig = retrievalSpecificConfig instanceof PollingConfig; - boolean isInvalidPollingConfig = isPollingConfig && appStreamTracker.map( - multiStreamTracker -> - ((PollingConfig) retrievalSpecificConfig).streamName() != null, - streamConfig -> - streamConfig.streamIdentifier() == null || streamConfig.streamIdentifier().streamName() == null); - - if (isInvalidPollingConfig) { - throw new IllegalArgumentException( - "Invalid config: Either in multi-stream mode with streamName configured or in single-stream mode with no streamName configured"); - } - } } diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/RetrievalSpecificConfig.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/RetrievalSpecificConfig.java index 30562994..d38fe054 100644 --- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/RetrievalSpecificConfig.java +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/RetrievalSpecificConfig.java @@ -15,9 +15,6 @@ package software.amazon.kinesis.retrieval; -import java.util.function.Function; -import software.amazon.kinesis.retrieval.polling.DataFetcher; - public interface RetrievalSpecificConfig { /** * Creates and returns a retrieval factory for the specific configuration @@ -25,4 +22,23 @@ public interface RetrievalSpecificConfig { * @return a retrieval factory that can create an appropriate retriever */ RetrievalFactory retrievalFactory(); + + /** + * Validates this instance is configured properly. For example, this + * method may validate that the stream name, if one is required, is + * non-null. + *

+ * If not in a valid state, an informative unchecked Exception -- for + * example, an {@link IllegalArgumentException} -- should be thrown so + * the caller may rectify the misconfiguration. + * + * @param isMultiStream whether state should be validated for multi-stream + * + * @deprecated remove keyword `default` to force implementation-specific behavior + */ + @Deprecated + default void validateState(boolean isMultiStream) { + // TODO convert this to a non-default implementation in a "major" release + } + } diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutConfig.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutConfig.java index 9318b996..16307377 100644 --- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutConfig.java +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutConfig.java @@ -80,10 +80,21 @@ public class FanOutConfig implements RetrievalSpecificConfig { */ private long retryBackoffMillis = 1000; - @Override public RetrievalFactory retrievalFactory() { + @Override + public RetrievalFactory retrievalFactory() { return new FanOutRetrievalFactory(kinesisClient, streamName, consumerArn, this::getOrCreateConsumerArn); } + @Override + public void validateState(final boolean isMultiStream) { + if (isMultiStream) { + if ((streamName() != null) || (consumerArn() != null)) { + throw new IllegalArgumentException( + "FanOutConfig must not have streamName/consumerArn configured in multi-stream mode"); + } + } + } + private String getOrCreateConsumerArn(String streamName) { FanOutConsumerRegistration registration = createConsumerRegistration(streamName); try { diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/polling/PollingConfig.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/polling/PollingConfig.java index a37e7121..4dd64016 100644 --- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/polling/PollingConfig.java +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/polling/PollingConfig.java @@ -143,4 +143,14 @@ public class PollingConfig implements RetrievalSpecificConfig { return new SynchronousBlockingRetrievalFactory(streamName(), kinesisClient(), recordsFetcherFactory, maxRecords(), kinesisRequestTimeout, dataFetcherProvider); } + + @Override + public void validateState(final boolean isMultiStream) { + if (isMultiStream) { + if (streamName() != null) { + throw new IllegalArgumentException( + "PollingConfig must not have streamName configured in multi-stream mode"); + } + } + } } diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/common/ConfigsBuilderTest.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/common/ConfigsBuilderTest.java index 8ea8f818..87caaa34 100644 --- a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/common/ConfigsBuilderTest.java +++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/common/ConfigsBuilderTest.java @@ -22,10 +22,10 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.mockito.Mockito.mock; -import org.junit.Before; import org.junit.Test; +import org.junit.runner.RunWith; import org.mockito.Mock; -import org.mockito.MockitoAnnotations; +import org.mockito.runners.MockitoJUnitRunner; import software.amazon.awssdk.services.cloudwatch.CloudWatchAsyncClient; import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; import software.amazon.awssdk.services.kinesis.KinesisAsyncClient; @@ -34,6 +34,7 @@ import software.amazon.kinesis.processor.ShardRecordProcessorFactory; import software.amazon.kinesis.processor.SingleStreamTracker; import software.amazon.kinesis.processor.StreamTracker; +@RunWith(MockitoJUnitRunner.class) public class ConfigsBuilderTest { @Mock @@ -51,11 +52,6 @@ public class ConfigsBuilderTest { private static final String APPLICATION_NAME = ConfigsBuilderTest.class.getSimpleName(); private static final String WORKER_IDENTIFIER = "worker-id"; - @Before - public void setUp() { - MockitoAnnotations.initMocks(this); - } - @Test public void testTrackerConstruction() { final String streamName = "single-stream"; @@ -77,6 +73,7 @@ public class ConfigsBuilderTest { } private ConfigsBuilder createConfig(String streamName) { + // intentional invocation of constructor where streamName is a String return new ConfigsBuilder(streamName, APPLICATION_NAME, mockKinesisClient, mockDynamoClient, mockCloudWatchClient, WORKER_IDENTIFIER, mockShardProcessorFactory); } diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/common/StreamConfigTest.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/common/StreamConfigTest.java new file mode 100644 index 00000000..9ba3267d --- /dev/null +++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/common/StreamConfigTest.java @@ -0,0 +1,14 @@ +package software.amazon.kinesis.common; + +import static software.amazon.kinesis.common.InitialPositionInStream.TRIM_HORIZON; + +import org.junit.Test; + +public class StreamConfigTest { + + @Test(expected = NullPointerException.class) + public void testNullStreamIdentifier() { + new StreamConfig(null, InitialPositionInStreamExtended.newInitialPosition(TRIM_HORIZON)); + } + +} \ No newline at end of file diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ShardConsumerTest.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ShardConsumerTest.java index 46677fb9..62fd13ef 100644 --- a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ShardConsumerTest.java +++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ShardConsumerTest.java @@ -15,13 +15,15 @@ package software.amazon.kinesis.lifecycle; -import static org.hamcrest.CoreMatchers.not; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; -import static org.hamcrest.Matchers.notNullValue; -import static org.hamcrest.Matchers.nullValue; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyLong; @@ -167,7 +169,7 @@ public class ShardConsumerTest { @After public void after() { List remainder = executorService.shutdownNow(); - assertThat(remainder.isEmpty(), equalTo(true)); + assertTrue(remainder.isEmpty()); } private class TestPublisher implements RecordsPublisher { @@ -267,8 +269,7 @@ public class ShardConsumerTest { mockSuccessfulShutdown(null); TestPublisher cache = new TestPublisher(); - ShardConsumer consumer = new ShardConsumer(cache, executorService, shardInfo, logWarningForTaskAfterMillis, - shardConsumerArgument, initialState, Function.identity(), 1, taskExecutionListener, 0); + final ShardConsumer consumer = createShardConsumer(cache); boolean initComplete = false; while (!initComplete) { @@ -321,8 +322,7 @@ public class ShardConsumerTest { mockSuccessfulShutdown(null); TestPublisher cache = new TestPublisher(); - ShardConsumer consumer = new ShardConsumer(cache, executorService, shardInfo, logWarningForTaskAfterMillis, - shardConsumerArgument, initialState, Function.identity(), 1, taskExecutionListener, 0); + final ShardConsumer consumer = createShardConsumer(cache); boolean initComplete = false; while (!initComplete) { @@ -341,7 +341,7 @@ public class ShardConsumerTest { // This will block if a lock is held on ShardConsumer#this // consumer.executeLifecycle(); - assertThat(consumer.isShutdown(), equalTo(false)); + assertFalse(consumer.isShutdown()); log.debug("Release processing task interlock"); awaitAndResetBarrier(processingTaskInterlock); @@ -370,7 +370,6 @@ public class ShardConsumerTest { @Test public void testDataArrivesAfterProcessing2() throws Exception { - CyclicBarrier taskCallBarrier = new CyclicBarrier(2); mockSuccessfulInitialize(null); @@ -380,8 +379,7 @@ public class ShardConsumerTest { mockSuccessfulShutdown(null); TestPublisher cache = new TestPublisher(); - ShardConsumer consumer = new ShardConsumer(cache, executorService, shardInfo, logWarningForTaskAfterMillis, - shardConsumerArgument, initialState, Function.identity(), 1, taskExecutionListener, 0); + final ShardConsumer consumer = createShardConsumer(cache); boolean initComplete = false; while (!initComplete) { @@ -435,13 +433,10 @@ public class ShardConsumerTest { verifyNoMoreInteractions(taskExecutionListener); } - @SuppressWarnings("unchecked") @Test @Ignore public final void testInitializationStateUponFailure() throws Exception { - ShardConsumer consumer = new ShardConsumer(recordsPublisher, executorService, shardInfo, - logWarningForTaskAfterMillis, shardConsumerArgument, initialState, Function.identity(), 1, - taskExecutionListener, 0); + final ShardConsumer consumer = createShardConsumer(recordsPublisher); when(initialState.createTask(eq(shardConsumerArgument), eq(consumer), any())).thenReturn(initializeTask); when(initializeTask.call()).thenReturn(new TaskResult(new Exception("Bad"))); @@ -468,17 +463,14 @@ public class ShardConsumerTest { /** * Test method to verify consumer undergoes the transition WAITING_ON_PARENT_SHARDS -> INITIALIZING -> PROCESSING */ - @SuppressWarnings("unchecked") @Test - public final void testSuccessfulConsumerStateTransition() throws Exception { + public final void testSuccessfulConsumerStateTransition() { ExecutorService directExecutorService = spy(executorService); - doAnswer(invocation -> directlyExecuteRunnable(invocation)) + doAnswer(this::directlyExecuteRunnable) .when(directExecutorService).execute(any()); - ShardConsumer consumer = new ShardConsumer(recordsPublisher, directExecutorService, shardInfo, - logWarningForTaskAfterMillis, shardConsumerArgument, blockedOnParentsState, - t -> t, 1, taskExecutionListener, 0); + final ShardConsumer consumer = createShardConsumer(directExecutorService, blockedOnParentsState); mockSuccessfulUnblockOnParents(); mockSuccessfulInitializeWithFailureTransition(); @@ -502,20 +494,17 @@ public class ShardConsumerTest { * Test method to verify consumer does not transition to PROCESSING from WAITING_ON_PARENT_SHARDS when * INITIALIZING tasks gets rejected. */ - @SuppressWarnings("unchecked") @Test public final void testConsumerNotTransitionsToProcessingWhenInitializationFails() { ExecutorService failingService = spy(executorService); - ShardConsumer consumer = new ShardConsumer(recordsPublisher, failingService, shardInfo, - logWarningForTaskAfterMillis, shardConsumerArgument, blockedOnParentsState, - t -> t, 1, taskExecutionListener, 0); + final ShardConsumer consumer = createShardConsumer(failingService, blockedOnParentsState); mockSuccessfulUnblockOnParents(); mockSuccessfulInitializeWithFailureTransition(); mockSuccessfulProcessing(null); // Failing the initialization task and all other attempts after that. - doAnswer(invocation -> directlyExecuteRunnable(invocation)) + doAnswer(this::directlyExecuteRunnable) .doThrow(new RejectedExecutionException()) .when(failingService).execute(any()); @@ -537,24 +526,21 @@ public class ShardConsumerTest { * Test method to verify consumer transition to PROCESSING from WAITING_ON_PARENT_SHARDS with * intermittent INITIALIZING task rejections. */ - @SuppressWarnings("unchecked") @Test public final void testConsumerTransitionsToProcessingWithIntermittentInitializationFailures() { ExecutorService failingService = spy(executorService); - ShardConsumer consumer = new ShardConsumer(recordsPublisher, failingService, shardInfo, - logWarningForTaskAfterMillis, shardConsumerArgument, blockedOnParentsState, - t -> t, 1, taskExecutionListener, 0); + final ShardConsumer consumer = createShardConsumer(failingService, blockedOnParentsState); mockSuccessfulUnblockOnParents(); mockSuccessfulInitializeWithFailureTransition(); mockSuccessfulProcessing(null); // Failing the initialization task and few other attempts after that. - doAnswer(invocation -> directlyExecuteRunnable(invocation)) + doAnswer(this::directlyExecuteRunnable) .doThrow(new RejectedExecutionException()) .doThrow(new RejectedExecutionException()) .doThrow(new RejectedExecutionException()) - .doAnswer(invocation -> directlyExecuteRunnable(invocation)) + .doAnswer(this::directlyExecuteRunnable) .when(failingService).execute(any()); int arbitraryExecutionCount = 6; @@ -574,13 +560,10 @@ public class ShardConsumerTest { /** * Test method to verify consumer does not transition to INITIALIZING when WAITING_ON_PARENT_SHARDS task rejected. */ - @SuppressWarnings("unchecked") @Test public final void testConsumerNotTransitionsToInitializingWhenWaitingOnParentsFails() { ExecutorService failingService = spy(executorService); - ShardConsumer consumer = new ShardConsumer(recordsPublisher, failingService, shardInfo, - logWarningForTaskAfterMillis, shardConsumerArgument, blockedOnParentsState, - t -> t, 1, taskExecutionListener, 0); + final ShardConsumer consumer = createShardConsumer(failingService, blockedOnParentsState); mockSuccessfulUnblockOnParentsWithFailureTransition(); mockSuccessfulInitializeWithFailureTransition(); @@ -606,13 +589,10 @@ public class ShardConsumerTest { /** * Test method to verify consumer stays in INITIALIZING state when InitializationTask fails. */ - @SuppressWarnings("unchecked") @Test(expected = RejectedExecutionException.class) public final void testInitializationStateUponSubmissionFailure() throws Exception { - ExecutorService failingService = mock(ExecutorService.class); - ShardConsumer consumer = new ShardConsumer(recordsPublisher, failingService, shardInfo, - logWarningForTaskAfterMillis, shardConsumerArgument, initialState, t -> t, 1, taskExecutionListener, 0); + final ShardConsumer consumer = createShardConsumer(failingService, initialState); doThrow(new RejectedExecutionException()).when(failingService).execute(any()); @@ -625,8 +605,7 @@ public class ShardConsumerTest { @Test public void testErrorThrowableInInitialization() throws Exception { - ShardConsumer consumer = new ShardConsumer(recordsPublisher, executorService, shardInfo, - logWarningForTaskAfterMillis, shardConsumerArgument, initialState, t -> t, 1, taskExecutionListener, 0); + final ShardConsumer consumer = createShardConsumer(recordsPublisher); when(initialState.createTask(any(), any(), any())).thenReturn(initializeTask); when(initialState.taskType()).thenReturn(TaskType.INITIALIZE); @@ -645,12 +624,10 @@ public class ShardConsumerTest { @Test public void testRequestedShutdownWhileQuiet() throws Exception { - CyclicBarrier taskBarrier = new CyclicBarrier(2); TestPublisher cache = new TestPublisher(); - ShardConsumer consumer = new ShardConsumer(cache, executorService, shardInfo, logWarningForTaskAfterMillis, - shardConsumerArgument, initialState, t -> t, 1, taskExecutionListener, 0); + final ShardConsumer consumer = createShardConsumer(cache); mockSuccessfulInitialize(null); @@ -692,15 +669,15 @@ public class ShardConsumerTest { consumer.gracefulShutdown(shutdownNotification); boolean shutdownComplete = consumer.shutdownComplete().get(); - assertThat(shutdownComplete, equalTo(false)); + assertFalse(shutdownComplete); shutdownComplete = consumer.shutdownComplete().get(); - assertThat(shutdownComplete, equalTo(false)); + assertFalse(shutdownComplete); consumer.leaseLost(); shutdownComplete = consumer.shutdownComplete().get(); - assertThat(shutdownComplete, equalTo(false)); + assertFalse(shutdownComplete); shutdownComplete = consumer.shutdownComplete().get(); - assertThat(shutdownComplete, equalTo(true)); + assertTrue(shutdownComplete); verify(processingState, times(2)).createTask(any(), any(), any()); verify(shutdownRequestedState, never()).shutdownTransition(eq(ShutdownReason.LEASE_LOST)); @@ -776,7 +753,6 @@ public class ShardConsumerTest { @Test public void testLongRunningTasks() throws Exception { - TestPublisher cache = new TestPublisher(); ShardConsumer consumer = new ShardConsumer(cache, executorService, shardInfo, Optional.of(1L), @@ -792,19 +768,19 @@ public class ShardConsumerTest { CompletableFuture initSuccess = consumer.initializeComplete(); awaitAndResetBarrier(taskArriveBarrier); - assertThat(consumer.taskRunningTime(), notNullValue()); + assertNotNull(consumer.taskRunningTime()); consumer.healthCheck(); awaitAndResetBarrier(taskDepartBarrier); - assertThat(initSuccess.get(), equalTo(false)); + assertFalse(initSuccess.get()); verify(initializeTask).call(); initSuccess = consumer.initializeComplete(); verify(initializeTask).call(); - assertThat(initSuccess.get(), equalTo(true)); + assertTrue(initSuccess.get()); consumer.healthCheck(); - assertThat(consumer.taskRunningTime(), nullValue()); + assertNull(consumer.taskRunningTime()); consumer.subscribe(); cache.awaitInitialSetup(); @@ -813,14 +789,14 @@ public class ShardConsumerTest { awaitAndResetBarrier(taskArriveBarrier); Instant previousTaskStartTime = consumer.taskDispatchedAt(); - assertThat(consumer.taskRunningTime(), notNullValue()); + assertNotNull(consumer.taskRunningTime()); consumer.healthCheck(); awaitAndResetBarrier(taskDepartBarrier); consumer.healthCheck(); cache.requestBarrier.await(); - assertThat(consumer.taskRunningTime(), nullValue()); + assertNull(consumer.taskRunningTime()); cache.requestBarrier.reset(); // Sleep for 10 millis before processing next task. If we don't; then the following @@ -831,28 +807,28 @@ public class ShardConsumerTest { awaitAndResetBarrier(taskArriveBarrier); Instant currentTaskStartTime = consumer.taskDispatchedAt(); - assertThat(currentTaskStartTime, not(equalTo(previousTaskStartTime))); + assertNotEquals(currentTaskStartTime, previousTaskStartTime); awaitAndResetBarrier(taskDepartBarrier); cache.requestBarrier.await(); - assertThat(consumer.taskRunningTime(), nullValue()); + assertNull(consumer.taskRunningTime()); cache.requestBarrier.reset(); consumer.leaseLost(); - assertThat(consumer.isShutdownRequested(), equalTo(true)); + assertTrue(consumer.isShutdownRequested()); CompletableFuture shutdownComplete = consumer.shutdownComplete(); awaitAndResetBarrier(taskArriveBarrier); - assertThat(consumer.taskRunningTime(), notNullValue()); + assertNotNull(consumer.taskRunningTime()); awaitAndResetBarrier(taskDepartBarrier); - assertThat(shutdownComplete.get(), equalTo(false)); + assertFalse(shutdownComplete.get()); shutdownComplete = consumer.shutdownComplete(); - assertThat(shutdownComplete.get(), equalTo(true)); + assertTrue(shutdownComplete.get()); - assertThat(consumer.taskRunningTime(), nullValue()); + assertNull(consumer.taskRunningTime()); consumer.healthCheck(); verify(taskExecutionListener, times(1)).beforeTaskExecution(initialTaskInput); @@ -918,7 +894,6 @@ public class ShardConsumerTest { } private void mockSuccessfulInitialize(CyclicBarrier taskCallBarrier, CyclicBarrier taskInterlockBarrier) { - when(initialState.createTask(eq(shardConsumerArgument), any(), any())).thenReturn(initializeTask); when(initialState.taskType()).thenReturn(TaskType.INITIALIZE); when(initializeTask.taskType()).thenReturn(TaskType.INITIALIZE); @@ -968,4 +943,18 @@ public class ShardConsumerTest { return null; } + private ShardConsumer createShardConsumer(final RecordsPublisher publisher) { + return createShardConsumer(publisher, executorService, initialState); + } + + private ShardConsumer createShardConsumer(final ExecutorService executorService, final ConsumerState state) { + return createShardConsumer(recordsPublisher, executorService, state); + } + + private ShardConsumer createShardConsumer(final RecordsPublisher publisher, + final ExecutorService executorService, final ConsumerState state) { + return new ShardConsumer(publisher, executorService, shardInfo, logWarningForTaskAfterMillis, + shardConsumerArgument, state, Function.identity(), 1, taskExecutionListener, 0); + } + } diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/RetrievalConfigTest.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/RetrievalConfigTest.java index 041ac71e..464459d5 100644 --- a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/RetrievalConfigTest.java +++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/RetrievalConfigTest.java @@ -5,14 +5,19 @@ import java.util.Optional; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; import static software.amazon.kinesis.common.InitialPositionInStream.LATEST; import static software.amazon.kinesis.common.InitialPositionInStream.TRIM_HORIZON; +import org.junit.Assert; import org.junit.Before; import org.junit.Test; +import org.junit.runner.RunWith; import org.mockito.Mock; -import org.mockito.MockitoAnnotations; +import org.mockito.runners.MockitoJUnitRunner; import software.amazon.awssdk.services.kinesis.KinesisAsyncClient; import software.amazon.kinesis.common.InitialPositionInStreamExtended; import software.amazon.kinesis.common.StreamConfig; @@ -20,6 +25,7 @@ import software.amazon.kinesis.processor.MultiStreamTracker; import software.amazon.kinesis.processor.SingleStreamTracker; import software.amazon.kinesis.processor.StreamTracker; +@RunWith(MockitoJUnitRunner.class) public class RetrievalConfigTest { private static final String APPLICATION_NAME = RetrievalConfigTest.class.getSimpleName(); @@ -27,9 +33,12 @@ public class RetrievalConfigTest { @Mock private KinesisAsyncClient mockKinesisClient; + @Mock + private MultiStreamTracker mockMultiStreamTracker; + @Before public void setUp() { - MockitoAnnotations.initMocks(this); + when(mockMultiStreamTracker.isMultiStream()).thenReturn(true); } @Test @@ -69,11 +78,33 @@ public class RetrievalConfigTest { @Test(expected = IllegalArgumentException.class) public void testUpdateInitialPositionInMultiStream() { - final RetrievalConfig config = createConfig(mock(MultiStreamTracker.class)); - config.initialPositionInStreamExtended( + createConfig(mockMultiStreamTracker).initialPositionInStreamExtended( InitialPositionInStreamExtended.newInitialPosition(TRIM_HORIZON)); } + /** + * Test that an invalid {@link RetrievalSpecificConfig} does not overwrite + * a valid one. + */ + @Test + public void testInvalidRetrievalSpecificConfig() { + final RetrievalSpecificConfig validConfig = mock(RetrievalSpecificConfig.class); + final RetrievalSpecificConfig invalidConfig = mock(RetrievalSpecificConfig.class); + doThrow(new IllegalArgumentException("womp womp")).when(invalidConfig).validateState(true); + + final RetrievalConfig config = createConfig(mockMultiStreamTracker); + assertNull(config.retrievalSpecificConfig()); + config.retrievalSpecificConfig(validConfig); + assertEquals(validConfig, config.retrievalSpecificConfig()); + + try { + config.retrievalSpecificConfig(invalidConfig); + Assert.fail("should throw"); + } catch (RuntimeException re) { + assertEquals(validConfig, config.retrievalSpecificConfig()); + } + } + private RetrievalConfig createConfig(String streamName) { return new RetrievalConfig(mockKinesisClient, streamName, APPLICATION_NAME); } diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/fanout/FanOutConfigTest.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/fanout/FanOutConfigTest.java index 4fee3d08..32ca17ce 100644 --- a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/fanout/FanOutConfigTest.java +++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/fanout/FanOutConfigTest.java @@ -15,16 +15,20 @@ package software.amazon.kinesis.retrieval.fanout; -import static org.hamcrest.CoreMatchers.equalTo; -import static org.hamcrest.CoreMatchers.not; -import static org.hamcrest.CoreMatchers.nullValue; -import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.mockito.Mockito.anyString; import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyZeroInteractions; import static org.mockito.Mockito.when; +import org.junit.Assert; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -55,132 +59,150 @@ public class FanOutConfigTest { @Mock private StreamConfig streamConfig; + private FanOutConfig config; + @Before public void setup() { - when(streamConfig.consumerArn()).thenReturn(null); + config = spy(new FanOutConfig(kinesisClient)) + // DRY: set the most commonly-used parameters + .applicationName(TEST_APPLICATION_NAME) + .streamName(TEST_STREAM_NAME); + doReturn(consumerRegistration).when(config) + .createConsumerRegistration(eq(kinesisClient), anyString(), anyString()); } @Test - public void testNoRegisterIfConsumerArnSet() throws Exception { - FanOutConfig config = new TestingConfig(kinesisClient).consumerArn(TEST_CONSUMER_ARN); + public void testNoRegisterIfConsumerArnSet() { + config.consumerArn(TEST_CONSUMER_ARN) + // unset common parameters + .applicationName(null).streamName(null); + RetrievalFactory retrievalFactory = config.retrievalFactory(); - assertThat(retrievalFactory, not(nullValue())); - verify(consumerRegistration, never()).getOrCreateStreamConsumerArn(); + assertNotNull(retrievalFactory); + verifyZeroInteractions(consumerRegistration); } @Test public void testRegisterCalledWhenConsumerArnUnset() throws Exception { - FanOutConfig config = new TestingConfig(kinesisClient).applicationName(TEST_APPLICATION_NAME) - .streamName(TEST_STREAM_NAME); - RetrievalFactory retrievalFactory = config.retrievalFactory(); - ShardInfo shardInfo = mock(ShardInfo.class); -// doReturn(Optional.of(StreamIdentifier.singleStreamInstance(TEST_STREAM_NAME).serialize())).when(shardInfo).streamIdentifier(); - doReturn(Optional.empty()).when(shardInfo).streamIdentifierSerOpt(); - retrievalFactory.createGetRecordsCache(shardInfo, streamConfig, mock(MetricsFactory.class)); - assertThat(retrievalFactory, not(nullValue())); + getRecordsCache(null); + verify(consumerRegistration).getOrCreateStreamConsumerArn(); } @Test public void testRegisterNotCalledWhenConsumerArnSetInMultiStreamMode() throws Exception { when(streamConfig.consumerArn()).thenReturn("consumerArn"); - FanOutConfig config = new TestingConfig(kinesisClient).applicationName(TEST_APPLICATION_NAME) - .streamName(TEST_STREAM_NAME); - RetrievalFactory retrievalFactory = config.retrievalFactory(); - ShardInfo shardInfo = mock(ShardInfo.class); - doReturn(Optional.of("account:stream:12345")).when(shardInfo).streamIdentifierSerOpt(); - retrievalFactory.createGetRecordsCache(shardInfo, streamConfig, mock(MetricsFactory.class)); - assertThat(retrievalFactory, not(nullValue())); + + getRecordsCache("account:stream:12345"); + verify(consumerRegistration, never()).getOrCreateStreamConsumerArn(); } @Test public void testRegisterCalledWhenConsumerArnNotSetInMultiStreamMode() throws Exception { - FanOutConfig config = new TestingConfig(kinesisClient).applicationName(TEST_APPLICATION_NAME) - .streamName(TEST_STREAM_NAME); - RetrievalFactory retrievalFactory = config.retrievalFactory(); - ShardInfo shardInfo = mock(ShardInfo.class); - doReturn(Optional.of("account:stream:12345")).when(shardInfo).streamIdentifierSerOpt(); - retrievalFactory.createGetRecordsCache(shardInfo, streamConfig, mock(MetricsFactory.class)); - assertThat(retrievalFactory, not(nullValue())); + getRecordsCache("account:stream:12345"); + verify(consumerRegistration).getOrCreateStreamConsumerArn(); } @Test public void testDependencyExceptionInConsumerCreation() throws Exception { - FanOutConfig config = new TestingConfig(kinesisClient).applicationName(TEST_APPLICATION_NAME) - .streamName(TEST_STREAM_NAME); DependencyException de = new DependencyException("Bad", null); when(consumerRegistration.getOrCreateStreamConsumerArn()).thenThrow(de); + try { - config.retrievalFactory(); + getRecordsCache(null); + Assert.fail("should throw"); } catch (RuntimeException e) { verify(consumerRegistration).getOrCreateStreamConsumerArn(); - assertThat(e.getCause(), equalTo(de)); + assertEquals(de, e.getCause()); } } @Test - public void testCreationWithApplicationName() throws Exception { - FanOutConfig config = new TestingConfig(kinesisClient).applicationName(TEST_APPLICATION_NAME) - .streamName(TEST_STREAM_NAME); - RetrievalFactory factory = config.retrievalFactory(); - ShardInfo shardInfo = mock(ShardInfo.class); - doReturn(Optional.empty()).when(shardInfo).streamIdentifierSerOpt(); - factory.createGetRecordsCache(shardInfo, streamConfig, mock(MetricsFactory.class)); - assertThat(factory, not(nullValue())); + public void testCreationWithApplicationName() { + getRecordsCache(null); - TestingConfig testingConfig = (TestingConfig) config; - assertThat(testingConfig.stream, equalTo(TEST_STREAM_NAME)); - assertThat(testingConfig.consumerToCreate, equalTo(TEST_APPLICATION_NAME)); + assertEquals(TEST_STREAM_NAME, config.streamName()); + assertEquals(TEST_APPLICATION_NAME, config.applicationName()); } @Test - public void testCreationWithConsumerName() throws Exception { - FanOutConfig config = new TestingConfig(kinesisClient).consumerName(TEST_CONSUMER_NAME) - .streamName(TEST_STREAM_NAME); - RetrievalFactory factory = config.retrievalFactory(); - ShardInfo shardInfo = mock(ShardInfo.class); - doReturn(Optional.empty()).when(shardInfo).streamIdentifierSerOpt(); - factory.createGetRecordsCache(shardInfo, streamConfig, mock(MetricsFactory.class)); - assertThat(factory, not(nullValue())); - TestingConfig testingConfig = (TestingConfig) config; - assertThat(testingConfig.stream, equalTo(TEST_STREAM_NAME)); - assertThat(testingConfig.consumerToCreate, equalTo(TEST_CONSUMER_NAME)); + public void testCreationWithConsumerName() { + config.consumerName(TEST_CONSUMER_NAME) + // unset common parameters + .applicationName(null); + + getRecordsCache(null); + + assertEquals(TEST_STREAM_NAME, config.streamName()); + assertEquals(TEST_CONSUMER_NAME, config.consumerName()); } @Test - public void testCreationWithBothConsumerApplication() throws Exception { - FanOutConfig config = new TestingConfig(kinesisClient).applicationName(TEST_APPLICATION_NAME) - .consumerName(TEST_CONSUMER_NAME).streamName(TEST_STREAM_NAME); - RetrievalFactory factory = config.retrievalFactory(); - ShardInfo shardInfo = mock(ShardInfo.class); - doReturn(Optional.empty()).when(shardInfo).streamIdentifierSerOpt(); - factory.createGetRecordsCache(shardInfo, streamConfig, mock(MetricsFactory.class)); - assertThat(factory, not(nullValue())); + public void testCreationWithBothConsumerApplication() { + config = config.consumerName(TEST_CONSUMER_NAME); - TestingConfig testingConfig = (TestingConfig) config; - assertThat(testingConfig.stream, equalTo(TEST_STREAM_NAME)); - assertThat(testingConfig.consumerToCreate, equalTo(TEST_CONSUMER_NAME)); + getRecordsCache(null); + + assertEquals(TEST_STREAM_NAME, config.streamName()); + assertEquals(TEST_CONSUMER_NAME, config.consumerName()); } - private class TestingConfig extends FanOutConfig { + @Test + public void testValidState() { + assertNull(config.consumerArn()); + assertNotNull(config.streamName()); - String stream; - String consumerToCreate; + config.validateState(false); - public TestingConfig(KinesisAsyncClient kinesisClient) { - super(kinesisClient); + // both streamName and consumerArn are non-null + config.consumerArn(TEST_CONSUMER_ARN); + config.validateState(false); + + config.consumerArn(null); + config.streamName(null); + config.validateState(false); + config.validateState(true); + + assertNull(config.streamName()); + assertNull(config.consumerArn()); + } + + @Test(expected = IllegalArgumentException.class) + public void testInvalidStateMultiWithStreamName() { + testInvalidState(TEST_STREAM_NAME, null); + } + + @Test(expected = IllegalArgumentException.class) + public void testInvalidStateMultiWithConsumerArn() { + testInvalidState(null, TEST_CONSUMER_ARN); + } + + @Test(expected = IllegalArgumentException.class) + public void testInvalidStateMultiWithStreamNameAndConsumerArn() { + testInvalidState(TEST_STREAM_NAME, TEST_CONSUMER_ARN); + } + + private void testInvalidState(final String streamName, final String consumerArn) { + config.streamName(streamName); + config.consumerArn(consumerArn); + + try { + config.validateState(true); + } finally { + assertEquals(streamName, config.streamName()); + assertEquals(consumerArn, config.consumerArn()); } + } - @Override - protected FanOutConsumerRegistration createConsumerRegistration(KinesisAsyncClient client, String stream, - String consumerToCreate) { - this.stream = stream; - this.consumerToCreate = consumerToCreate; - return consumerRegistration; - } + private void getRecordsCache(final String streamIdentifer) { + final ShardInfo shardInfo = mock(ShardInfo.class); + when(shardInfo.streamIdentifierSerOpt()).thenReturn(Optional.ofNullable(streamIdentifer)); + + final RetrievalFactory factory = config.retrievalFactory(); + factory.createGetRecordsCache(shardInfo, streamConfig, mock(MetricsFactory.class)); } } \ No newline at end of file diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/polling/PollingConfigTest.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/polling/PollingConfigTest.java new file mode 100644 index 00000000..760c6dce --- /dev/null +++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/polling/PollingConfigTest.java @@ -0,0 +1,47 @@ +package software.amazon.kinesis.retrieval.polling; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.runners.MockitoJUnitRunner; +import software.amazon.awssdk.services.kinesis.KinesisAsyncClient; + +@RunWith(MockitoJUnitRunner.class) +public class PollingConfigTest { + + private static final String STREAM_NAME = PollingConfigTest.class.getSimpleName(); + + @Mock + private KinesisAsyncClient mockKinesisClinet; + + private PollingConfig config; + + @Before + public void setUp() { + config = new PollingConfig(mockKinesisClinet); + } + + @Test + public void testValidState() { + assertNull(config.streamName()); + + config.validateState(true); + config.validateState(false); + + config.streamName(STREAM_NAME); + config.validateState(false); + assertEquals(STREAM_NAME, config.streamName()); + } + + @Test(expected = IllegalArgumentException.class) + public void testInvalidStateMultiWithStreamName() { + config.streamName(STREAM_NAME); + + config.validateState(true); + } + +} \ No newline at end of file