diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/common/StreamConfig.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/common/StreamConfig.java
index b1057f13..8ca75dec 100644
--- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/common/StreamConfig.java
+++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/common/StreamConfig.java
@@ -16,11 +16,13 @@
package software.amazon.kinesis.common;
import lombok.Data;
+import lombok.NonNull;
import lombok.experimental.Accessors;
@Data
@Accessors(fluent = true)
public class StreamConfig {
+ @NonNull
private final StreamIdentifier streamIdentifier;
private final InitialPositionInStreamExtended initialPositionInStreamExtended;
private String consumerArn;
diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/processor/ProcessorConfig.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/processor/ProcessorConfig.java
index 04ea6614..7641bc44 100644
--- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/processor/ProcessorConfig.java
+++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/processor/ProcessorConfig.java
@@ -15,9 +15,9 @@
package software.amazon.kinesis.processor;
- import lombok.Data;
- import lombok.NonNull;
- import lombok.experimental.Accessors;
+import lombok.Data;
+import lombok.NonNull;
+import lombok.experimental.Accessors;
/**
* Used by the KCL to configure the processor for processing the records.
diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/RetrievalConfig.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/RetrievalConfig.java
index 3f001057..8ada4970 100644
--- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/RetrievalConfig.java
+++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/RetrievalConfig.java
@@ -133,6 +133,8 @@ public class RetrievalConfig {
}
/**
+ * Convenience method to reconfigure the embedded {@link StreamTracker},
+ * but only when not in multi-stream mode.
*
* @param initialPositionInStreamExtended
*
@@ -142,62 +144,46 @@ public class RetrievalConfig {
*/
@Deprecated
public RetrievalConfig initialPositionInStreamExtended(InitialPositionInStreamExtended initialPositionInStreamExtended) {
- this.appStreamTracker.apply(multiStreamTracker -> {
+ if (streamTracker().isMultiStream()) {
throw new IllegalArgumentException(
"Cannot set initialPositionInStreamExtended when multiStreamTracker is set");
- }, sc -> {
- final StreamConfig updatedConfig = new StreamConfig(sc.streamIdentifier(), initialPositionInStreamExtended);
- streamTracker = new SingleStreamTracker(sc.streamIdentifier(), updatedConfig);
- appStreamTracker = Either.right(updatedConfig);
- });
+ };
+
+ final StreamIdentifier streamIdentifier = getSingleStreamIdentifier();
+ final StreamConfig updatedConfig = new StreamConfig(streamIdentifier, initialPositionInStreamExtended);
+ streamTracker = new SingleStreamTracker(streamIdentifier, updatedConfig);
+ appStreamTracker = Either.right(updatedConfig);
return this;
}
public RetrievalConfig retrievalSpecificConfig(RetrievalSpecificConfig retrievalSpecificConfig) {
+ retrievalSpecificConfig.validateState(streamTracker.isMultiStream());
this.retrievalSpecificConfig = retrievalSpecificConfig;
- validateFanoutConfig();
- validatePollingConfig();
return this;
}
public RetrievalFactory retrievalFactory() {
if (retrievalFactory == null) {
if (retrievalSpecificConfig == null) {
- retrievalSpecificConfig = new FanOutConfig(kinesisClient())
+ final FanOutConfig fanOutConfig = new FanOutConfig(kinesisClient())
.applicationName(applicationName());
- retrievalSpecificConfig = appStreamTracker.map(multiStreamTracker -> retrievalSpecificConfig,
- streamConfig -> ((FanOutConfig) retrievalSpecificConfig).streamName(streamConfig.streamIdentifier().streamName()));
+ if (!streamTracker.isMultiStream()) {
+ final String streamName = getSingleStreamIdentifier().streamName();
+ fanOutConfig.streamName(streamName);
+ }
+ retrievalSpecificConfig(fanOutConfig);
}
retrievalFactory = retrievalSpecificConfig.retrievalFactory();
}
return retrievalFactory;
}
- private void validateFanoutConfig() {
- // If we are in multistream mode and if retrievalSpecificConfig is an instance of FanOutConfig and if consumerArn is set throw exception.
- boolean isFanoutConfig = retrievalSpecificConfig instanceof FanOutConfig;
- boolean isInvalidFanoutConfig = isFanoutConfig && appStreamTracker.map(
- multiStreamTracker -> ((FanOutConfig) retrievalSpecificConfig).consumerArn() != null
- || ((FanOutConfig) retrievalSpecificConfig).streamName() != null,
- streamConfig -> streamConfig.streamIdentifier() == null
- || streamConfig.streamIdentifier().streamName() == null);
- if(isInvalidFanoutConfig) {
- throw new IllegalArgumentException(
- "Invalid config: Either in multi-stream mode with streamName/consumerArn configured or in single-stream mode with no streamName configured");
- }
+ /**
+ * Convenience method to return the {@link StreamIdentifier} from a
+ * single-stream tracker.
+ */
+ private StreamIdentifier getSingleStreamIdentifier() {
+ return streamTracker.streamConfigList().get(0).streamIdentifier();
}
- private void validatePollingConfig() {
- boolean isPollingConfig = retrievalSpecificConfig instanceof PollingConfig;
- boolean isInvalidPollingConfig = isPollingConfig && appStreamTracker.map(
- multiStreamTracker ->
- ((PollingConfig) retrievalSpecificConfig).streamName() != null,
- streamConfig ->
- streamConfig.streamIdentifier() == null || streamConfig.streamIdentifier().streamName() == null);
-
- if (isInvalidPollingConfig) {
- throw new IllegalArgumentException(
- "Invalid config: Either in multi-stream mode with streamName configured or in single-stream mode with no streamName configured");
- }
- }
}
diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/RetrievalSpecificConfig.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/RetrievalSpecificConfig.java
index 30562994..d38fe054 100644
--- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/RetrievalSpecificConfig.java
+++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/RetrievalSpecificConfig.java
@@ -15,9 +15,6 @@
package software.amazon.kinesis.retrieval;
-import java.util.function.Function;
-import software.amazon.kinesis.retrieval.polling.DataFetcher;
-
public interface RetrievalSpecificConfig {
/**
* Creates and returns a retrieval factory for the specific configuration
@@ -25,4 +22,23 @@ public interface RetrievalSpecificConfig {
* @return a retrieval factory that can create an appropriate retriever
*/
RetrievalFactory retrievalFactory();
+
+ /**
+ * Validates this instance is configured properly. For example, this
+ * method may validate that the stream name, if one is required, is
+ * non-null.
+ *
+ * If not in a valid state, an informative unchecked Exception -- for
+ * example, an {@link IllegalArgumentException} -- should be thrown so
+ * the caller may rectify the misconfiguration.
+ *
+ * @param isMultiStream whether state should be validated for multi-stream
+ *
+ * @deprecated remove keyword `default` to force implementation-specific behavior
+ */
+ @Deprecated
+ default void validateState(boolean isMultiStream) {
+ // TODO convert this to a non-default implementation in a "major" release
+ }
+
}
diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutConfig.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutConfig.java
index 9318b996..16307377 100644
--- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutConfig.java
+++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutConfig.java
@@ -80,10 +80,21 @@ public class FanOutConfig implements RetrievalSpecificConfig {
*/
private long retryBackoffMillis = 1000;
- @Override public RetrievalFactory retrievalFactory() {
+ @Override
+ public RetrievalFactory retrievalFactory() {
return new FanOutRetrievalFactory(kinesisClient, streamName, consumerArn, this::getOrCreateConsumerArn);
}
+ @Override
+ public void validateState(final boolean isMultiStream) {
+ if (isMultiStream) {
+ if ((streamName() != null) || (consumerArn() != null)) {
+ throw new IllegalArgumentException(
+ "FanOutConfig must not have streamName/consumerArn configured in multi-stream mode");
+ }
+ }
+ }
+
private String getOrCreateConsumerArn(String streamName) {
FanOutConsumerRegistration registration = createConsumerRegistration(streamName);
try {
diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/polling/PollingConfig.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/polling/PollingConfig.java
index a37e7121..4dd64016 100644
--- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/polling/PollingConfig.java
+++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/polling/PollingConfig.java
@@ -143,4 +143,14 @@ public class PollingConfig implements RetrievalSpecificConfig {
return new SynchronousBlockingRetrievalFactory(streamName(), kinesisClient(), recordsFetcherFactory,
maxRecords(), kinesisRequestTimeout, dataFetcherProvider);
}
+
+ @Override
+ public void validateState(final boolean isMultiStream) {
+ if (isMultiStream) {
+ if (streamName() != null) {
+ throw new IllegalArgumentException(
+ "PollingConfig must not have streamName configured in multi-stream mode");
+ }
+ }
+ }
}
diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/common/ConfigsBuilderTest.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/common/ConfigsBuilderTest.java
index 8ea8f818..87caaa34 100644
--- a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/common/ConfigsBuilderTest.java
+++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/common/ConfigsBuilderTest.java
@@ -22,10 +22,10 @@ import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.mockito.Mockito.mock;
-import org.junit.Before;
import org.junit.Test;
+import org.junit.runner.RunWith;
import org.mockito.Mock;
-import org.mockito.MockitoAnnotations;
+import org.mockito.runners.MockitoJUnitRunner;
import software.amazon.awssdk.services.cloudwatch.CloudWatchAsyncClient;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
import software.amazon.awssdk.services.kinesis.KinesisAsyncClient;
@@ -34,6 +34,7 @@ import software.amazon.kinesis.processor.ShardRecordProcessorFactory;
import software.amazon.kinesis.processor.SingleStreamTracker;
import software.amazon.kinesis.processor.StreamTracker;
+@RunWith(MockitoJUnitRunner.class)
public class ConfigsBuilderTest {
@Mock
@@ -51,11 +52,6 @@ public class ConfigsBuilderTest {
private static final String APPLICATION_NAME = ConfigsBuilderTest.class.getSimpleName();
private static final String WORKER_IDENTIFIER = "worker-id";
- @Before
- public void setUp() {
- MockitoAnnotations.initMocks(this);
- }
-
@Test
public void testTrackerConstruction() {
final String streamName = "single-stream";
@@ -77,6 +73,7 @@ public class ConfigsBuilderTest {
}
private ConfigsBuilder createConfig(String streamName) {
+ // intentional invocation of constructor where streamName is a String
return new ConfigsBuilder(streamName, APPLICATION_NAME, mockKinesisClient, mockDynamoClient,
mockCloudWatchClient, WORKER_IDENTIFIER, mockShardProcessorFactory);
}
diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/common/StreamConfigTest.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/common/StreamConfigTest.java
new file mode 100644
index 00000000..9ba3267d
--- /dev/null
+++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/common/StreamConfigTest.java
@@ -0,0 +1,14 @@
+package software.amazon.kinesis.common;
+
+import static software.amazon.kinesis.common.InitialPositionInStream.TRIM_HORIZON;
+
+import org.junit.Test;
+
+public class StreamConfigTest {
+
+ @Test(expected = NullPointerException.class)
+ public void testNullStreamIdentifier() {
+ new StreamConfig(null, InitialPositionInStreamExtended.newInitialPosition(TRIM_HORIZON));
+ }
+
+}
\ No newline at end of file
diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ShardConsumerTest.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ShardConsumerTest.java
index 46677fb9..62fd13ef 100644
--- a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ShardConsumerTest.java
+++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ShardConsumerTest.java
@@ -15,13 +15,15 @@
package software.amazon.kinesis.lifecycle;
-import static org.hamcrest.CoreMatchers.not;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
-import static org.hamcrest.Matchers.notNullValue;
-import static org.hamcrest.Matchers.nullValue;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertThat;
+import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyLong;
@@ -167,7 +169,7 @@ public class ShardConsumerTest {
@After
public void after() {
List remainder = executorService.shutdownNow();
- assertThat(remainder.isEmpty(), equalTo(true));
+ assertTrue(remainder.isEmpty());
}
private class TestPublisher implements RecordsPublisher {
@@ -267,8 +269,7 @@ public class ShardConsumerTest {
mockSuccessfulShutdown(null);
TestPublisher cache = new TestPublisher();
- ShardConsumer consumer = new ShardConsumer(cache, executorService, shardInfo, logWarningForTaskAfterMillis,
- shardConsumerArgument, initialState, Function.identity(), 1, taskExecutionListener, 0);
+ final ShardConsumer consumer = createShardConsumer(cache);
boolean initComplete = false;
while (!initComplete) {
@@ -321,8 +322,7 @@ public class ShardConsumerTest {
mockSuccessfulShutdown(null);
TestPublisher cache = new TestPublisher();
- ShardConsumer consumer = new ShardConsumer(cache, executorService, shardInfo, logWarningForTaskAfterMillis,
- shardConsumerArgument, initialState, Function.identity(), 1, taskExecutionListener, 0);
+ final ShardConsumer consumer = createShardConsumer(cache);
boolean initComplete = false;
while (!initComplete) {
@@ -341,7 +341,7 @@ public class ShardConsumerTest {
// This will block if a lock is held on ShardConsumer#this
//
consumer.executeLifecycle();
- assertThat(consumer.isShutdown(), equalTo(false));
+ assertFalse(consumer.isShutdown());
log.debug("Release processing task interlock");
awaitAndResetBarrier(processingTaskInterlock);
@@ -370,7 +370,6 @@ public class ShardConsumerTest {
@Test
public void testDataArrivesAfterProcessing2() throws Exception {
-
CyclicBarrier taskCallBarrier = new CyclicBarrier(2);
mockSuccessfulInitialize(null);
@@ -380,8 +379,7 @@ public class ShardConsumerTest {
mockSuccessfulShutdown(null);
TestPublisher cache = new TestPublisher();
- ShardConsumer consumer = new ShardConsumer(cache, executorService, shardInfo, logWarningForTaskAfterMillis,
- shardConsumerArgument, initialState, Function.identity(), 1, taskExecutionListener, 0);
+ final ShardConsumer consumer = createShardConsumer(cache);
boolean initComplete = false;
while (!initComplete) {
@@ -435,13 +433,10 @@ public class ShardConsumerTest {
verifyNoMoreInteractions(taskExecutionListener);
}
- @SuppressWarnings("unchecked")
@Test
@Ignore
public final void testInitializationStateUponFailure() throws Exception {
- ShardConsumer consumer = new ShardConsumer(recordsPublisher, executorService, shardInfo,
- logWarningForTaskAfterMillis, shardConsumerArgument, initialState, Function.identity(), 1,
- taskExecutionListener, 0);
+ final ShardConsumer consumer = createShardConsumer(recordsPublisher);
when(initialState.createTask(eq(shardConsumerArgument), eq(consumer), any())).thenReturn(initializeTask);
when(initializeTask.call()).thenReturn(new TaskResult(new Exception("Bad")));
@@ -468,17 +463,14 @@ public class ShardConsumerTest {
/**
* Test method to verify consumer undergoes the transition WAITING_ON_PARENT_SHARDS -> INITIALIZING -> PROCESSING
*/
- @SuppressWarnings("unchecked")
@Test
- public final void testSuccessfulConsumerStateTransition() throws Exception {
+ public final void testSuccessfulConsumerStateTransition() {
ExecutorService directExecutorService = spy(executorService);
- doAnswer(invocation -> directlyExecuteRunnable(invocation))
+ doAnswer(this::directlyExecuteRunnable)
.when(directExecutorService).execute(any());
- ShardConsumer consumer = new ShardConsumer(recordsPublisher, directExecutorService, shardInfo,
- logWarningForTaskAfterMillis, shardConsumerArgument, blockedOnParentsState,
- t -> t, 1, taskExecutionListener, 0);
+ final ShardConsumer consumer = createShardConsumer(directExecutorService, blockedOnParentsState);
mockSuccessfulUnblockOnParents();
mockSuccessfulInitializeWithFailureTransition();
@@ -502,20 +494,17 @@ public class ShardConsumerTest {
* Test method to verify consumer does not transition to PROCESSING from WAITING_ON_PARENT_SHARDS when
* INITIALIZING tasks gets rejected.
*/
- @SuppressWarnings("unchecked")
@Test
public final void testConsumerNotTransitionsToProcessingWhenInitializationFails() {
ExecutorService failingService = spy(executorService);
- ShardConsumer consumer = new ShardConsumer(recordsPublisher, failingService, shardInfo,
- logWarningForTaskAfterMillis, shardConsumerArgument, blockedOnParentsState,
- t -> t, 1, taskExecutionListener, 0);
+ final ShardConsumer consumer = createShardConsumer(failingService, blockedOnParentsState);
mockSuccessfulUnblockOnParents();
mockSuccessfulInitializeWithFailureTransition();
mockSuccessfulProcessing(null);
// Failing the initialization task and all other attempts after that.
- doAnswer(invocation -> directlyExecuteRunnable(invocation))
+ doAnswer(this::directlyExecuteRunnable)
.doThrow(new RejectedExecutionException())
.when(failingService).execute(any());
@@ -537,24 +526,21 @@ public class ShardConsumerTest {
* Test method to verify consumer transition to PROCESSING from WAITING_ON_PARENT_SHARDS with
* intermittent INITIALIZING task rejections.
*/
- @SuppressWarnings("unchecked")
@Test
public final void testConsumerTransitionsToProcessingWithIntermittentInitializationFailures() {
ExecutorService failingService = spy(executorService);
- ShardConsumer consumer = new ShardConsumer(recordsPublisher, failingService, shardInfo,
- logWarningForTaskAfterMillis, shardConsumerArgument, blockedOnParentsState,
- t -> t, 1, taskExecutionListener, 0);
+ final ShardConsumer consumer = createShardConsumer(failingService, blockedOnParentsState);
mockSuccessfulUnblockOnParents();
mockSuccessfulInitializeWithFailureTransition();
mockSuccessfulProcessing(null);
// Failing the initialization task and few other attempts after that.
- doAnswer(invocation -> directlyExecuteRunnable(invocation))
+ doAnswer(this::directlyExecuteRunnable)
.doThrow(new RejectedExecutionException())
.doThrow(new RejectedExecutionException())
.doThrow(new RejectedExecutionException())
- .doAnswer(invocation -> directlyExecuteRunnable(invocation))
+ .doAnswer(this::directlyExecuteRunnable)
.when(failingService).execute(any());
int arbitraryExecutionCount = 6;
@@ -574,13 +560,10 @@ public class ShardConsumerTest {
/**
* Test method to verify consumer does not transition to INITIALIZING when WAITING_ON_PARENT_SHARDS task rejected.
*/
- @SuppressWarnings("unchecked")
@Test
public final void testConsumerNotTransitionsToInitializingWhenWaitingOnParentsFails() {
ExecutorService failingService = spy(executorService);
- ShardConsumer consumer = new ShardConsumer(recordsPublisher, failingService, shardInfo,
- logWarningForTaskAfterMillis, shardConsumerArgument, blockedOnParentsState,
- t -> t, 1, taskExecutionListener, 0);
+ final ShardConsumer consumer = createShardConsumer(failingService, blockedOnParentsState);
mockSuccessfulUnblockOnParentsWithFailureTransition();
mockSuccessfulInitializeWithFailureTransition();
@@ -606,13 +589,10 @@ public class ShardConsumerTest {
/**
* Test method to verify consumer stays in INITIALIZING state when InitializationTask fails.
*/
- @SuppressWarnings("unchecked")
@Test(expected = RejectedExecutionException.class)
public final void testInitializationStateUponSubmissionFailure() throws Exception {
-
ExecutorService failingService = mock(ExecutorService.class);
- ShardConsumer consumer = new ShardConsumer(recordsPublisher, failingService, shardInfo,
- logWarningForTaskAfterMillis, shardConsumerArgument, initialState, t -> t, 1, taskExecutionListener, 0);
+ final ShardConsumer consumer = createShardConsumer(failingService, initialState);
doThrow(new RejectedExecutionException()).when(failingService).execute(any());
@@ -625,8 +605,7 @@ public class ShardConsumerTest {
@Test
public void testErrorThrowableInInitialization() throws Exception {
- ShardConsumer consumer = new ShardConsumer(recordsPublisher, executorService, shardInfo,
- logWarningForTaskAfterMillis, shardConsumerArgument, initialState, t -> t, 1, taskExecutionListener, 0);
+ final ShardConsumer consumer = createShardConsumer(recordsPublisher);
when(initialState.createTask(any(), any(), any())).thenReturn(initializeTask);
when(initialState.taskType()).thenReturn(TaskType.INITIALIZE);
@@ -645,12 +624,10 @@ public class ShardConsumerTest {
@Test
public void testRequestedShutdownWhileQuiet() throws Exception {
-
CyclicBarrier taskBarrier = new CyclicBarrier(2);
TestPublisher cache = new TestPublisher();
- ShardConsumer consumer = new ShardConsumer(cache, executorService, shardInfo, logWarningForTaskAfterMillis,
- shardConsumerArgument, initialState, t -> t, 1, taskExecutionListener, 0);
+ final ShardConsumer consumer = createShardConsumer(cache);
mockSuccessfulInitialize(null);
@@ -692,15 +669,15 @@ public class ShardConsumerTest {
consumer.gracefulShutdown(shutdownNotification);
boolean shutdownComplete = consumer.shutdownComplete().get();
- assertThat(shutdownComplete, equalTo(false));
+ assertFalse(shutdownComplete);
shutdownComplete = consumer.shutdownComplete().get();
- assertThat(shutdownComplete, equalTo(false));
+ assertFalse(shutdownComplete);
consumer.leaseLost();
shutdownComplete = consumer.shutdownComplete().get();
- assertThat(shutdownComplete, equalTo(false));
+ assertFalse(shutdownComplete);
shutdownComplete = consumer.shutdownComplete().get();
- assertThat(shutdownComplete, equalTo(true));
+ assertTrue(shutdownComplete);
verify(processingState, times(2)).createTask(any(), any(), any());
verify(shutdownRequestedState, never()).shutdownTransition(eq(ShutdownReason.LEASE_LOST));
@@ -776,7 +753,6 @@ public class ShardConsumerTest {
@Test
public void testLongRunningTasks() throws Exception {
-
TestPublisher cache = new TestPublisher();
ShardConsumer consumer = new ShardConsumer(cache, executorService, shardInfo, Optional.of(1L),
@@ -792,19 +768,19 @@ public class ShardConsumerTest {
CompletableFuture initSuccess = consumer.initializeComplete();
awaitAndResetBarrier(taskArriveBarrier);
- assertThat(consumer.taskRunningTime(), notNullValue());
+ assertNotNull(consumer.taskRunningTime());
consumer.healthCheck();
awaitAndResetBarrier(taskDepartBarrier);
- assertThat(initSuccess.get(), equalTo(false));
+ assertFalse(initSuccess.get());
verify(initializeTask).call();
initSuccess = consumer.initializeComplete();
verify(initializeTask).call();
- assertThat(initSuccess.get(), equalTo(true));
+ assertTrue(initSuccess.get());
consumer.healthCheck();
- assertThat(consumer.taskRunningTime(), nullValue());
+ assertNull(consumer.taskRunningTime());
consumer.subscribe();
cache.awaitInitialSetup();
@@ -813,14 +789,14 @@ public class ShardConsumerTest {
awaitAndResetBarrier(taskArriveBarrier);
Instant previousTaskStartTime = consumer.taskDispatchedAt();
- assertThat(consumer.taskRunningTime(), notNullValue());
+ assertNotNull(consumer.taskRunningTime());
consumer.healthCheck();
awaitAndResetBarrier(taskDepartBarrier);
consumer.healthCheck();
cache.requestBarrier.await();
- assertThat(consumer.taskRunningTime(), nullValue());
+ assertNull(consumer.taskRunningTime());
cache.requestBarrier.reset();
// Sleep for 10 millis before processing next task. If we don't; then the following
@@ -831,28 +807,28 @@ public class ShardConsumerTest {
awaitAndResetBarrier(taskArriveBarrier);
Instant currentTaskStartTime = consumer.taskDispatchedAt();
- assertThat(currentTaskStartTime, not(equalTo(previousTaskStartTime)));
+ assertNotEquals(currentTaskStartTime, previousTaskStartTime);
awaitAndResetBarrier(taskDepartBarrier);
cache.requestBarrier.await();
- assertThat(consumer.taskRunningTime(), nullValue());
+ assertNull(consumer.taskRunningTime());
cache.requestBarrier.reset();
consumer.leaseLost();
- assertThat(consumer.isShutdownRequested(), equalTo(true));
+ assertTrue(consumer.isShutdownRequested());
CompletableFuture shutdownComplete = consumer.shutdownComplete();
awaitAndResetBarrier(taskArriveBarrier);
- assertThat(consumer.taskRunningTime(), notNullValue());
+ assertNotNull(consumer.taskRunningTime());
awaitAndResetBarrier(taskDepartBarrier);
- assertThat(shutdownComplete.get(), equalTo(false));
+ assertFalse(shutdownComplete.get());
shutdownComplete = consumer.shutdownComplete();
- assertThat(shutdownComplete.get(), equalTo(true));
+ assertTrue(shutdownComplete.get());
- assertThat(consumer.taskRunningTime(), nullValue());
+ assertNull(consumer.taskRunningTime());
consumer.healthCheck();
verify(taskExecutionListener, times(1)).beforeTaskExecution(initialTaskInput);
@@ -918,7 +894,6 @@ public class ShardConsumerTest {
}
private void mockSuccessfulInitialize(CyclicBarrier taskCallBarrier, CyclicBarrier taskInterlockBarrier) {
-
when(initialState.createTask(eq(shardConsumerArgument), any(), any())).thenReturn(initializeTask);
when(initialState.taskType()).thenReturn(TaskType.INITIALIZE);
when(initializeTask.taskType()).thenReturn(TaskType.INITIALIZE);
@@ -968,4 +943,18 @@ public class ShardConsumerTest {
return null;
}
+ private ShardConsumer createShardConsumer(final RecordsPublisher publisher) {
+ return createShardConsumer(publisher, executorService, initialState);
+ }
+
+ private ShardConsumer createShardConsumer(final ExecutorService executorService, final ConsumerState state) {
+ return createShardConsumer(recordsPublisher, executorService, state);
+ }
+
+ private ShardConsumer createShardConsumer(final RecordsPublisher publisher,
+ final ExecutorService executorService, final ConsumerState state) {
+ return new ShardConsumer(publisher, executorService, shardInfo, logWarningForTaskAfterMillis,
+ shardConsumerArgument, state, Function.identity(), 1, taskExecutionListener, 0);
+ }
+
}
diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/RetrievalConfigTest.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/RetrievalConfigTest.java
index 041ac71e..464459d5 100644
--- a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/RetrievalConfigTest.java
+++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/RetrievalConfigTest.java
@@ -5,14 +5,19 @@ import java.util.Optional;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNull;
+import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
import static software.amazon.kinesis.common.InitialPositionInStream.LATEST;
import static software.amazon.kinesis.common.InitialPositionInStream.TRIM_HORIZON;
+import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
+import org.junit.runner.RunWith;
import org.mockito.Mock;
-import org.mockito.MockitoAnnotations;
+import org.mockito.runners.MockitoJUnitRunner;
import software.amazon.awssdk.services.kinesis.KinesisAsyncClient;
import software.amazon.kinesis.common.InitialPositionInStreamExtended;
import software.amazon.kinesis.common.StreamConfig;
@@ -20,6 +25,7 @@ import software.amazon.kinesis.processor.MultiStreamTracker;
import software.amazon.kinesis.processor.SingleStreamTracker;
import software.amazon.kinesis.processor.StreamTracker;
+@RunWith(MockitoJUnitRunner.class)
public class RetrievalConfigTest {
private static final String APPLICATION_NAME = RetrievalConfigTest.class.getSimpleName();
@@ -27,9 +33,12 @@ public class RetrievalConfigTest {
@Mock
private KinesisAsyncClient mockKinesisClient;
+ @Mock
+ private MultiStreamTracker mockMultiStreamTracker;
+
@Before
public void setUp() {
- MockitoAnnotations.initMocks(this);
+ when(mockMultiStreamTracker.isMultiStream()).thenReturn(true);
}
@Test
@@ -69,11 +78,33 @@ public class RetrievalConfigTest {
@Test(expected = IllegalArgumentException.class)
public void testUpdateInitialPositionInMultiStream() {
- final RetrievalConfig config = createConfig(mock(MultiStreamTracker.class));
- config.initialPositionInStreamExtended(
+ createConfig(mockMultiStreamTracker).initialPositionInStreamExtended(
InitialPositionInStreamExtended.newInitialPosition(TRIM_HORIZON));
}
+ /**
+ * Test that an invalid {@link RetrievalSpecificConfig} does not overwrite
+ * a valid one.
+ */
+ @Test
+ public void testInvalidRetrievalSpecificConfig() {
+ final RetrievalSpecificConfig validConfig = mock(RetrievalSpecificConfig.class);
+ final RetrievalSpecificConfig invalidConfig = mock(RetrievalSpecificConfig.class);
+ doThrow(new IllegalArgumentException("womp womp")).when(invalidConfig).validateState(true);
+
+ final RetrievalConfig config = createConfig(mockMultiStreamTracker);
+ assertNull(config.retrievalSpecificConfig());
+ config.retrievalSpecificConfig(validConfig);
+ assertEquals(validConfig, config.retrievalSpecificConfig());
+
+ try {
+ config.retrievalSpecificConfig(invalidConfig);
+ Assert.fail("should throw");
+ } catch (RuntimeException re) {
+ assertEquals(validConfig, config.retrievalSpecificConfig());
+ }
+ }
+
private RetrievalConfig createConfig(String streamName) {
return new RetrievalConfig(mockKinesisClient, streamName, APPLICATION_NAME);
}
diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/fanout/FanOutConfigTest.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/fanout/FanOutConfigTest.java
index 4fee3d08..32ca17ce 100644
--- a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/fanout/FanOutConfigTest.java
+++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/fanout/FanOutConfigTest.java
@@ -15,16 +15,20 @@
package software.amazon.kinesis.retrieval.fanout;
-import static org.hamcrest.CoreMatchers.equalTo;
-import static org.hamcrest.CoreMatchers.not;
-import static org.hamcrest.CoreMatchers.nullValue;
-import static org.junit.Assert.assertThat;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
+import static org.mockito.Mockito.anyString;
import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when;
+import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
@@ -55,132 +59,150 @@ public class FanOutConfigTest {
@Mock
private StreamConfig streamConfig;
+ private FanOutConfig config;
+
@Before
public void setup() {
- when(streamConfig.consumerArn()).thenReturn(null);
+ config = spy(new FanOutConfig(kinesisClient))
+ // DRY: set the most commonly-used parameters
+ .applicationName(TEST_APPLICATION_NAME)
+ .streamName(TEST_STREAM_NAME);
+ doReturn(consumerRegistration).when(config)
+ .createConsumerRegistration(eq(kinesisClient), anyString(), anyString());
}
@Test
- public void testNoRegisterIfConsumerArnSet() throws Exception {
- FanOutConfig config = new TestingConfig(kinesisClient).consumerArn(TEST_CONSUMER_ARN);
+ public void testNoRegisterIfConsumerArnSet() {
+ config.consumerArn(TEST_CONSUMER_ARN)
+ // unset common parameters
+ .applicationName(null).streamName(null);
+
RetrievalFactory retrievalFactory = config.retrievalFactory();
- assertThat(retrievalFactory, not(nullValue()));
- verify(consumerRegistration, never()).getOrCreateStreamConsumerArn();
+ assertNotNull(retrievalFactory);
+ verifyZeroInteractions(consumerRegistration);
}
@Test
public void testRegisterCalledWhenConsumerArnUnset() throws Exception {
- FanOutConfig config = new TestingConfig(kinesisClient).applicationName(TEST_APPLICATION_NAME)
- .streamName(TEST_STREAM_NAME);
- RetrievalFactory retrievalFactory = config.retrievalFactory();
- ShardInfo shardInfo = mock(ShardInfo.class);
-// doReturn(Optional.of(StreamIdentifier.singleStreamInstance(TEST_STREAM_NAME).serialize())).when(shardInfo).streamIdentifier();
- doReturn(Optional.empty()).when(shardInfo).streamIdentifierSerOpt();
- retrievalFactory.createGetRecordsCache(shardInfo, streamConfig, mock(MetricsFactory.class));
- assertThat(retrievalFactory, not(nullValue()));
+ getRecordsCache(null);
+
verify(consumerRegistration).getOrCreateStreamConsumerArn();
}
@Test
public void testRegisterNotCalledWhenConsumerArnSetInMultiStreamMode() throws Exception {
when(streamConfig.consumerArn()).thenReturn("consumerArn");
- FanOutConfig config = new TestingConfig(kinesisClient).applicationName(TEST_APPLICATION_NAME)
- .streamName(TEST_STREAM_NAME);
- RetrievalFactory retrievalFactory = config.retrievalFactory();
- ShardInfo shardInfo = mock(ShardInfo.class);
- doReturn(Optional.of("account:stream:12345")).when(shardInfo).streamIdentifierSerOpt();
- retrievalFactory.createGetRecordsCache(shardInfo, streamConfig, mock(MetricsFactory.class));
- assertThat(retrievalFactory, not(nullValue()));
+
+ getRecordsCache("account:stream:12345");
+
verify(consumerRegistration, never()).getOrCreateStreamConsumerArn();
}
@Test
public void testRegisterCalledWhenConsumerArnNotSetInMultiStreamMode() throws Exception {
- FanOutConfig config = new TestingConfig(kinesisClient).applicationName(TEST_APPLICATION_NAME)
- .streamName(TEST_STREAM_NAME);
- RetrievalFactory retrievalFactory = config.retrievalFactory();
- ShardInfo shardInfo = mock(ShardInfo.class);
- doReturn(Optional.of("account:stream:12345")).when(shardInfo).streamIdentifierSerOpt();
- retrievalFactory.createGetRecordsCache(shardInfo, streamConfig, mock(MetricsFactory.class));
- assertThat(retrievalFactory, not(nullValue()));
+ getRecordsCache("account:stream:12345");
+
verify(consumerRegistration).getOrCreateStreamConsumerArn();
}
@Test
public void testDependencyExceptionInConsumerCreation() throws Exception {
- FanOutConfig config = new TestingConfig(kinesisClient).applicationName(TEST_APPLICATION_NAME)
- .streamName(TEST_STREAM_NAME);
DependencyException de = new DependencyException("Bad", null);
when(consumerRegistration.getOrCreateStreamConsumerArn()).thenThrow(de);
+
try {
- config.retrievalFactory();
+ getRecordsCache(null);
+ Assert.fail("should throw");
} catch (RuntimeException e) {
verify(consumerRegistration).getOrCreateStreamConsumerArn();
- assertThat(e.getCause(), equalTo(de));
+ assertEquals(de, e.getCause());
}
}
@Test
- public void testCreationWithApplicationName() throws Exception {
- FanOutConfig config = new TestingConfig(kinesisClient).applicationName(TEST_APPLICATION_NAME)
- .streamName(TEST_STREAM_NAME);
- RetrievalFactory factory = config.retrievalFactory();
- ShardInfo shardInfo = mock(ShardInfo.class);
- doReturn(Optional.empty()).when(shardInfo).streamIdentifierSerOpt();
- factory.createGetRecordsCache(shardInfo, streamConfig, mock(MetricsFactory.class));
- assertThat(factory, not(nullValue()));
+ public void testCreationWithApplicationName() {
+ getRecordsCache(null);
- TestingConfig testingConfig = (TestingConfig) config;
- assertThat(testingConfig.stream, equalTo(TEST_STREAM_NAME));
- assertThat(testingConfig.consumerToCreate, equalTo(TEST_APPLICATION_NAME));
+ assertEquals(TEST_STREAM_NAME, config.streamName());
+ assertEquals(TEST_APPLICATION_NAME, config.applicationName());
}
@Test
- public void testCreationWithConsumerName() throws Exception {
- FanOutConfig config = new TestingConfig(kinesisClient).consumerName(TEST_CONSUMER_NAME)
- .streamName(TEST_STREAM_NAME);
- RetrievalFactory factory = config.retrievalFactory();
- ShardInfo shardInfo = mock(ShardInfo.class);
- doReturn(Optional.empty()).when(shardInfo).streamIdentifierSerOpt();
- factory.createGetRecordsCache(shardInfo, streamConfig, mock(MetricsFactory.class));
- assertThat(factory, not(nullValue()));
- TestingConfig testingConfig = (TestingConfig) config;
- assertThat(testingConfig.stream, equalTo(TEST_STREAM_NAME));
- assertThat(testingConfig.consumerToCreate, equalTo(TEST_CONSUMER_NAME));
+ public void testCreationWithConsumerName() {
+ config.consumerName(TEST_CONSUMER_NAME)
+ // unset common parameters
+ .applicationName(null);
+
+ getRecordsCache(null);
+
+ assertEquals(TEST_STREAM_NAME, config.streamName());
+ assertEquals(TEST_CONSUMER_NAME, config.consumerName());
}
@Test
- public void testCreationWithBothConsumerApplication() throws Exception {
- FanOutConfig config = new TestingConfig(kinesisClient).applicationName(TEST_APPLICATION_NAME)
- .consumerName(TEST_CONSUMER_NAME).streamName(TEST_STREAM_NAME);
- RetrievalFactory factory = config.retrievalFactory();
- ShardInfo shardInfo = mock(ShardInfo.class);
- doReturn(Optional.empty()).when(shardInfo).streamIdentifierSerOpt();
- factory.createGetRecordsCache(shardInfo, streamConfig, mock(MetricsFactory.class));
- assertThat(factory, not(nullValue()));
+ public void testCreationWithBothConsumerApplication() {
+ config = config.consumerName(TEST_CONSUMER_NAME);
- TestingConfig testingConfig = (TestingConfig) config;
- assertThat(testingConfig.stream, equalTo(TEST_STREAM_NAME));
- assertThat(testingConfig.consumerToCreate, equalTo(TEST_CONSUMER_NAME));
+ getRecordsCache(null);
+
+ assertEquals(TEST_STREAM_NAME, config.streamName());
+ assertEquals(TEST_CONSUMER_NAME, config.consumerName());
}
- private class TestingConfig extends FanOutConfig {
+ @Test
+ public void testValidState() {
+ assertNull(config.consumerArn());
+ assertNotNull(config.streamName());
- String stream;
- String consumerToCreate;
+ config.validateState(false);
- public TestingConfig(KinesisAsyncClient kinesisClient) {
- super(kinesisClient);
+ // both streamName and consumerArn are non-null
+ config.consumerArn(TEST_CONSUMER_ARN);
+ config.validateState(false);
+
+ config.consumerArn(null);
+ config.streamName(null);
+ config.validateState(false);
+ config.validateState(true);
+
+ assertNull(config.streamName());
+ assertNull(config.consumerArn());
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void testInvalidStateMultiWithStreamName() {
+ testInvalidState(TEST_STREAM_NAME, null);
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void testInvalidStateMultiWithConsumerArn() {
+ testInvalidState(null, TEST_CONSUMER_ARN);
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void testInvalidStateMultiWithStreamNameAndConsumerArn() {
+ testInvalidState(TEST_STREAM_NAME, TEST_CONSUMER_ARN);
+ }
+
+ private void testInvalidState(final String streamName, final String consumerArn) {
+ config.streamName(streamName);
+ config.consumerArn(consumerArn);
+
+ try {
+ config.validateState(true);
+ } finally {
+ assertEquals(streamName, config.streamName());
+ assertEquals(consumerArn, config.consumerArn());
}
+ }
- @Override
- protected FanOutConsumerRegistration createConsumerRegistration(KinesisAsyncClient client, String stream,
- String consumerToCreate) {
- this.stream = stream;
- this.consumerToCreate = consumerToCreate;
- return consumerRegistration;
- }
+ private void getRecordsCache(final String streamIdentifer) {
+ final ShardInfo shardInfo = mock(ShardInfo.class);
+ when(shardInfo.streamIdentifierSerOpt()).thenReturn(Optional.ofNullable(streamIdentifer));
+
+ final RetrievalFactory factory = config.retrievalFactory();
+ factory.createGetRecordsCache(shardInfo, streamConfig, mock(MetricsFactory.class));
}
}
\ No newline at end of file
diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/polling/PollingConfigTest.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/polling/PollingConfigTest.java
new file mode 100644
index 00000000..760c6dce
--- /dev/null
+++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/polling/PollingConfigTest.java
@@ -0,0 +1,47 @@
+package software.amazon.kinesis.retrieval.polling;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.runners.MockitoJUnitRunner;
+import software.amazon.awssdk.services.kinesis.KinesisAsyncClient;
+
+@RunWith(MockitoJUnitRunner.class)
+public class PollingConfigTest {
+
+ private static final String STREAM_NAME = PollingConfigTest.class.getSimpleName();
+
+ @Mock
+ private KinesisAsyncClient mockKinesisClinet;
+
+ private PollingConfig config;
+
+ @Before
+ public void setUp() {
+ config = new PollingConfig(mockKinesisClinet);
+ }
+
+ @Test
+ public void testValidState() {
+ assertNull(config.streamName());
+
+ config.validateState(true);
+ config.validateState(false);
+
+ config.streamName(STREAM_NAME);
+ config.validateState(false);
+ assertEquals(STREAM_NAME, config.streamName());
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void testInvalidStateMultiWithStreamName() {
+ config.streamName(STREAM_NAME);
+
+ config.validateState(true);
+ }
+
+}
\ No newline at end of file