Code cleanup to introduce better testing and simplify future removal of (#1094)

deprecated parameters (e.g., `Either<L, R> appStreamTracker`).
This commit is contained in:
stair 2023-04-18 14:58:27 -04:00 committed by GitHub
parent 7b23ae9b3c
commit 5e7d4788ec
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 324 additions and 199 deletions

View file

@ -16,11 +16,13 @@
package software.amazon.kinesis.common;
import lombok.Data;
import lombok.NonNull;
import lombok.experimental.Accessors;
@Data
@Accessors(fluent = true)
public class StreamConfig {
@NonNull
private final StreamIdentifier streamIdentifier;
private final InitialPositionInStreamExtended initialPositionInStreamExtended;
private String consumerArn;

View file

@ -15,9 +15,9 @@
package software.amazon.kinesis.processor;
import lombok.Data;
import lombok.NonNull;
import lombok.experimental.Accessors;
import lombok.Data;
import lombok.NonNull;
import lombok.experimental.Accessors;
/**
* Used by the KCL to configure the processor for processing the records.

View file

@ -133,6 +133,8 @@ public class RetrievalConfig {
}
/**
* Convenience method to reconfigure the embedded {@link StreamTracker},
* but only when <b>not</b> in multi-stream mode.
*
* @param initialPositionInStreamExtended
*
@ -142,62 +144,46 @@ public class RetrievalConfig {
*/
@Deprecated
public RetrievalConfig initialPositionInStreamExtended(InitialPositionInStreamExtended initialPositionInStreamExtended) {
this.appStreamTracker.apply(multiStreamTracker -> {
if (streamTracker().isMultiStream()) {
throw new IllegalArgumentException(
"Cannot set initialPositionInStreamExtended when multiStreamTracker is set");
}, sc -> {
final StreamConfig updatedConfig = new StreamConfig(sc.streamIdentifier(), initialPositionInStreamExtended);
streamTracker = new SingleStreamTracker(sc.streamIdentifier(), updatedConfig);
appStreamTracker = Either.right(updatedConfig);
});
};
final StreamIdentifier streamIdentifier = getSingleStreamIdentifier();
final StreamConfig updatedConfig = new StreamConfig(streamIdentifier, initialPositionInStreamExtended);
streamTracker = new SingleStreamTracker(streamIdentifier, updatedConfig);
appStreamTracker = Either.right(updatedConfig);
return this;
}
public RetrievalConfig retrievalSpecificConfig(RetrievalSpecificConfig retrievalSpecificConfig) {
retrievalSpecificConfig.validateState(streamTracker.isMultiStream());
this.retrievalSpecificConfig = retrievalSpecificConfig;
validateFanoutConfig();
validatePollingConfig();
return this;
}
public RetrievalFactory retrievalFactory() {
if (retrievalFactory == null) {
if (retrievalSpecificConfig == null) {
retrievalSpecificConfig = new FanOutConfig(kinesisClient())
final FanOutConfig fanOutConfig = new FanOutConfig(kinesisClient())
.applicationName(applicationName());
retrievalSpecificConfig = appStreamTracker.map(multiStreamTracker -> retrievalSpecificConfig,
streamConfig -> ((FanOutConfig) retrievalSpecificConfig).streamName(streamConfig.streamIdentifier().streamName()));
if (!streamTracker.isMultiStream()) {
final String streamName = getSingleStreamIdentifier().streamName();
fanOutConfig.streamName(streamName);
}
retrievalSpecificConfig(fanOutConfig);
}
retrievalFactory = retrievalSpecificConfig.retrievalFactory();
}
return retrievalFactory;
}
private void validateFanoutConfig() {
// If we are in multistream mode and if retrievalSpecificConfig is an instance of FanOutConfig and if consumerArn is set throw exception.
boolean isFanoutConfig = retrievalSpecificConfig instanceof FanOutConfig;
boolean isInvalidFanoutConfig = isFanoutConfig && appStreamTracker.map(
multiStreamTracker -> ((FanOutConfig) retrievalSpecificConfig).consumerArn() != null
|| ((FanOutConfig) retrievalSpecificConfig).streamName() != null,
streamConfig -> streamConfig.streamIdentifier() == null
|| streamConfig.streamIdentifier().streamName() == null);
if(isInvalidFanoutConfig) {
throw new IllegalArgumentException(
"Invalid config: Either in multi-stream mode with streamName/consumerArn configured or in single-stream mode with no streamName configured");
}
/**
* Convenience method to return the {@link StreamIdentifier} from a
* single-stream tracker.
*/
private StreamIdentifier getSingleStreamIdentifier() {
return streamTracker.streamConfigList().get(0).streamIdentifier();
}
private void validatePollingConfig() {
boolean isPollingConfig = retrievalSpecificConfig instanceof PollingConfig;
boolean isInvalidPollingConfig = isPollingConfig && appStreamTracker.map(
multiStreamTracker ->
((PollingConfig) retrievalSpecificConfig).streamName() != null,
streamConfig ->
streamConfig.streamIdentifier() == null || streamConfig.streamIdentifier().streamName() == null);
if (isInvalidPollingConfig) {
throw new IllegalArgumentException(
"Invalid config: Either in multi-stream mode with streamName configured or in single-stream mode with no streamName configured");
}
}
}

View file

@ -15,9 +15,6 @@
package software.amazon.kinesis.retrieval;
import java.util.function.Function;
import software.amazon.kinesis.retrieval.polling.DataFetcher;
public interface RetrievalSpecificConfig {
/**
* Creates and returns a retrieval factory for the specific configuration
@ -25,4 +22,23 @@ public interface RetrievalSpecificConfig {
* @return a retrieval factory that can create an appropriate retriever
*/
RetrievalFactory retrievalFactory();
/**
* Validates this instance is configured properly. For example, this
* method may validate that the stream name, if one is required, is
* non-null.
* <br/><br/>
* If not in a valid state, an informative unchecked Exception -- for
* example, an {@link IllegalArgumentException} -- should be thrown so
* the caller may rectify the misconfiguration.
*
* @param isMultiStream whether state should be validated for multi-stream
*
* @deprecated remove keyword `default` to force implementation-specific behavior
*/
@Deprecated
default void validateState(boolean isMultiStream) {
// TODO convert this to a non-default implementation in a "major" release
}
}

View file

@ -80,10 +80,21 @@ public class FanOutConfig implements RetrievalSpecificConfig {
*/
private long retryBackoffMillis = 1000;
@Override public RetrievalFactory retrievalFactory() {
@Override
public RetrievalFactory retrievalFactory() {
return new FanOutRetrievalFactory(kinesisClient, streamName, consumerArn, this::getOrCreateConsumerArn);
}
@Override
public void validateState(final boolean isMultiStream) {
if (isMultiStream) {
if ((streamName() != null) || (consumerArn() != null)) {
throw new IllegalArgumentException(
"FanOutConfig must not have streamName/consumerArn configured in multi-stream mode");
}
}
}
private String getOrCreateConsumerArn(String streamName) {
FanOutConsumerRegistration registration = createConsumerRegistration(streamName);
try {

View file

@ -143,4 +143,14 @@ public class PollingConfig implements RetrievalSpecificConfig {
return new SynchronousBlockingRetrievalFactory(streamName(), kinesisClient(), recordsFetcherFactory,
maxRecords(), kinesisRequestTimeout, dataFetcherProvider);
}
@Override
public void validateState(final boolean isMultiStream) {
if (isMultiStream) {
if (streamName() != null) {
throw new IllegalArgumentException(
"PollingConfig must not have streamName configured in multi-stream mode");
}
}
}
}

View file

@ -22,10 +22,10 @@ import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.mockito.Mockito.mock;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.mockito.runners.MockitoJUnitRunner;
import software.amazon.awssdk.services.cloudwatch.CloudWatchAsyncClient;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
import software.amazon.awssdk.services.kinesis.KinesisAsyncClient;
@ -34,6 +34,7 @@ import software.amazon.kinesis.processor.ShardRecordProcessorFactory;
import software.amazon.kinesis.processor.SingleStreamTracker;
import software.amazon.kinesis.processor.StreamTracker;
@RunWith(MockitoJUnitRunner.class)
public class ConfigsBuilderTest {
@Mock
@ -51,11 +52,6 @@ public class ConfigsBuilderTest {
private static final String APPLICATION_NAME = ConfigsBuilderTest.class.getSimpleName();
private static final String WORKER_IDENTIFIER = "worker-id";
@Before
public void setUp() {
MockitoAnnotations.initMocks(this);
}
@Test
public void testTrackerConstruction() {
final String streamName = "single-stream";
@ -77,6 +73,7 @@ public class ConfigsBuilderTest {
}
private ConfigsBuilder createConfig(String streamName) {
// intentional invocation of constructor where streamName is a String
return new ConfigsBuilder(streamName, APPLICATION_NAME, mockKinesisClient, mockDynamoClient,
mockCloudWatchClient, WORKER_IDENTIFIER, mockShardProcessorFactory);
}

View file

@ -0,0 +1,14 @@
package software.amazon.kinesis.common;
import static software.amazon.kinesis.common.InitialPositionInStream.TRIM_HORIZON;
import org.junit.Test;
public class StreamConfigTest {
@Test(expected = NullPointerException.class)
public void testNullStreamIdentifier() {
new StreamConfig(null, InitialPositionInStreamExtended.newInitialPosition(TRIM_HORIZON));
}
}

View file

@ -15,13 +15,15 @@
package software.amazon.kinesis.lifecycle;
import static org.hamcrest.CoreMatchers.not;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.nullValue;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyLong;
@ -167,7 +169,7 @@ public class ShardConsumerTest {
@After
public void after() {
List<Runnable> remainder = executorService.shutdownNow();
assertThat(remainder.isEmpty(), equalTo(true));
assertTrue(remainder.isEmpty());
}
private class TestPublisher implements RecordsPublisher {
@ -267,8 +269,7 @@ public class ShardConsumerTest {
mockSuccessfulShutdown(null);
TestPublisher cache = new TestPublisher();
ShardConsumer consumer = new ShardConsumer(cache, executorService, shardInfo, logWarningForTaskAfterMillis,
shardConsumerArgument, initialState, Function.identity(), 1, taskExecutionListener, 0);
final ShardConsumer consumer = createShardConsumer(cache);
boolean initComplete = false;
while (!initComplete) {
@ -321,8 +322,7 @@ public class ShardConsumerTest {
mockSuccessfulShutdown(null);
TestPublisher cache = new TestPublisher();
ShardConsumer consumer = new ShardConsumer(cache, executorService, shardInfo, logWarningForTaskAfterMillis,
shardConsumerArgument, initialState, Function.identity(), 1, taskExecutionListener, 0);
final ShardConsumer consumer = createShardConsumer(cache);
boolean initComplete = false;
while (!initComplete) {
@ -341,7 +341,7 @@ public class ShardConsumerTest {
// This will block if a lock is held on ShardConsumer#this
//
consumer.executeLifecycle();
assertThat(consumer.isShutdown(), equalTo(false));
assertFalse(consumer.isShutdown());
log.debug("Release processing task interlock");
awaitAndResetBarrier(processingTaskInterlock);
@ -370,7 +370,6 @@ public class ShardConsumerTest {
@Test
public void testDataArrivesAfterProcessing2() throws Exception {
CyclicBarrier taskCallBarrier = new CyclicBarrier(2);
mockSuccessfulInitialize(null);
@ -380,8 +379,7 @@ public class ShardConsumerTest {
mockSuccessfulShutdown(null);
TestPublisher cache = new TestPublisher();
ShardConsumer consumer = new ShardConsumer(cache, executorService, shardInfo, logWarningForTaskAfterMillis,
shardConsumerArgument, initialState, Function.identity(), 1, taskExecutionListener, 0);
final ShardConsumer consumer = createShardConsumer(cache);
boolean initComplete = false;
while (!initComplete) {
@ -435,13 +433,10 @@ public class ShardConsumerTest {
verifyNoMoreInteractions(taskExecutionListener);
}
@SuppressWarnings("unchecked")
@Test
@Ignore
public final void testInitializationStateUponFailure() throws Exception {
ShardConsumer consumer = new ShardConsumer(recordsPublisher, executorService, shardInfo,
logWarningForTaskAfterMillis, shardConsumerArgument, initialState, Function.identity(), 1,
taskExecutionListener, 0);
final ShardConsumer consumer = createShardConsumer(recordsPublisher);
when(initialState.createTask(eq(shardConsumerArgument), eq(consumer), any())).thenReturn(initializeTask);
when(initializeTask.call()).thenReturn(new TaskResult(new Exception("Bad")));
@ -468,17 +463,14 @@ public class ShardConsumerTest {
/**
* Test method to verify consumer undergoes the transition WAITING_ON_PARENT_SHARDS -> INITIALIZING -> PROCESSING
*/
@SuppressWarnings("unchecked")
@Test
public final void testSuccessfulConsumerStateTransition() throws Exception {
public final void testSuccessfulConsumerStateTransition() {
ExecutorService directExecutorService = spy(executorService);
doAnswer(invocation -> directlyExecuteRunnable(invocation))
doAnswer(this::directlyExecuteRunnable)
.when(directExecutorService).execute(any());
ShardConsumer consumer = new ShardConsumer(recordsPublisher, directExecutorService, shardInfo,
logWarningForTaskAfterMillis, shardConsumerArgument, blockedOnParentsState,
t -> t, 1, taskExecutionListener, 0);
final ShardConsumer consumer = createShardConsumer(directExecutorService, blockedOnParentsState);
mockSuccessfulUnblockOnParents();
mockSuccessfulInitializeWithFailureTransition();
@ -502,20 +494,17 @@ public class ShardConsumerTest {
* Test method to verify consumer does not transition to PROCESSING from WAITING_ON_PARENT_SHARDS when
* INITIALIZING tasks gets rejected.
*/
@SuppressWarnings("unchecked")
@Test
public final void testConsumerNotTransitionsToProcessingWhenInitializationFails() {
ExecutorService failingService = spy(executorService);
ShardConsumer consumer = new ShardConsumer(recordsPublisher, failingService, shardInfo,
logWarningForTaskAfterMillis, shardConsumerArgument, blockedOnParentsState,
t -> t, 1, taskExecutionListener, 0);
final ShardConsumer consumer = createShardConsumer(failingService, blockedOnParentsState);
mockSuccessfulUnblockOnParents();
mockSuccessfulInitializeWithFailureTransition();
mockSuccessfulProcessing(null);
// Failing the initialization task and all other attempts after that.
doAnswer(invocation -> directlyExecuteRunnable(invocation))
doAnswer(this::directlyExecuteRunnable)
.doThrow(new RejectedExecutionException())
.when(failingService).execute(any());
@ -537,24 +526,21 @@ public class ShardConsumerTest {
* Test method to verify consumer transition to PROCESSING from WAITING_ON_PARENT_SHARDS with
* intermittent INITIALIZING task rejections.
*/
@SuppressWarnings("unchecked")
@Test
public final void testConsumerTransitionsToProcessingWithIntermittentInitializationFailures() {
ExecutorService failingService = spy(executorService);
ShardConsumer consumer = new ShardConsumer(recordsPublisher, failingService, shardInfo,
logWarningForTaskAfterMillis, shardConsumerArgument, blockedOnParentsState,
t -> t, 1, taskExecutionListener, 0);
final ShardConsumer consumer = createShardConsumer(failingService, blockedOnParentsState);
mockSuccessfulUnblockOnParents();
mockSuccessfulInitializeWithFailureTransition();
mockSuccessfulProcessing(null);
// Failing the initialization task and few other attempts after that.
doAnswer(invocation -> directlyExecuteRunnable(invocation))
doAnswer(this::directlyExecuteRunnable)
.doThrow(new RejectedExecutionException())
.doThrow(new RejectedExecutionException())
.doThrow(new RejectedExecutionException())
.doAnswer(invocation -> directlyExecuteRunnable(invocation))
.doAnswer(this::directlyExecuteRunnable)
.when(failingService).execute(any());
int arbitraryExecutionCount = 6;
@ -574,13 +560,10 @@ public class ShardConsumerTest {
/**
* Test method to verify consumer does not transition to INITIALIZING when WAITING_ON_PARENT_SHARDS task rejected.
*/
@SuppressWarnings("unchecked")
@Test
public final void testConsumerNotTransitionsToInitializingWhenWaitingOnParentsFails() {
ExecutorService failingService = spy(executorService);
ShardConsumer consumer = new ShardConsumer(recordsPublisher, failingService, shardInfo,
logWarningForTaskAfterMillis, shardConsumerArgument, blockedOnParentsState,
t -> t, 1, taskExecutionListener, 0);
final ShardConsumer consumer = createShardConsumer(failingService, blockedOnParentsState);
mockSuccessfulUnblockOnParentsWithFailureTransition();
mockSuccessfulInitializeWithFailureTransition();
@ -606,13 +589,10 @@ public class ShardConsumerTest {
/**
* Test method to verify consumer stays in INITIALIZING state when InitializationTask fails.
*/
@SuppressWarnings("unchecked")
@Test(expected = RejectedExecutionException.class)
public final void testInitializationStateUponSubmissionFailure() throws Exception {
ExecutorService failingService = mock(ExecutorService.class);
ShardConsumer consumer = new ShardConsumer(recordsPublisher, failingService, shardInfo,
logWarningForTaskAfterMillis, shardConsumerArgument, initialState, t -> t, 1, taskExecutionListener, 0);
final ShardConsumer consumer = createShardConsumer(failingService, initialState);
doThrow(new RejectedExecutionException()).when(failingService).execute(any());
@ -625,8 +605,7 @@ public class ShardConsumerTest {
@Test
public void testErrorThrowableInInitialization() throws Exception {
ShardConsumer consumer = new ShardConsumer(recordsPublisher, executorService, shardInfo,
logWarningForTaskAfterMillis, shardConsumerArgument, initialState, t -> t, 1, taskExecutionListener, 0);
final ShardConsumer consumer = createShardConsumer(recordsPublisher);
when(initialState.createTask(any(), any(), any())).thenReturn(initializeTask);
when(initialState.taskType()).thenReturn(TaskType.INITIALIZE);
@ -645,12 +624,10 @@ public class ShardConsumerTest {
@Test
public void testRequestedShutdownWhileQuiet() throws Exception {
CyclicBarrier taskBarrier = new CyclicBarrier(2);
TestPublisher cache = new TestPublisher();
ShardConsumer consumer = new ShardConsumer(cache, executorService, shardInfo, logWarningForTaskAfterMillis,
shardConsumerArgument, initialState, t -> t, 1, taskExecutionListener, 0);
final ShardConsumer consumer = createShardConsumer(cache);
mockSuccessfulInitialize(null);
@ -692,15 +669,15 @@ public class ShardConsumerTest {
consumer.gracefulShutdown(shutdownNotification);
boolean shutdownComplete = consumer.shutdownComplete().get();
assertThat(shutdownComplete, equalTo(false));
assertFalse(shutdownComplete);
shutdownComplete = consumer.shutdownComplete().get();
assertThat(shutdownComplete, equalTo(false));
assertFalse(shutdownComplete);
consumer.leaseLost();
shutdownComplete = consumer.shutdownComplete().get();
assertThat(shutdownComplete, equalTo(false));
assertFalse(shutdownComplete);
shutdownComplete = consumer.shutdownComplete().get();
assertThat(shutdownComplete, equalTo(true));
assertTrue(shutdownComplete);
verify(processingState, times(2)).createTask(any(), any(), any());
verify(shutdownRequestedState, never()).shutdownTransition(eq(ShutdownReason.LEASE_LOST));
@ -776,7 +753,6 @@ public class ShardConsumerTest {
@Test
public void testLongRunningTasks() throws Exception {
TestPublisher cache = new TestPublisher();
ShardConsumer consumer = new ShardConsumer(cache, executorService, shardInfo, Optional.of(1L),
@ -792,19 +768,19 @@ public class ShardConsumerTest {
CompletableFuture<Boolean> initSuccess = consumer.initializeComplete();
awaitAndResetBarrier(taskArriveBarrier);
assertThat(consumer.taskRunningTime(), notNullValue());
assertNotNull(consumer.taskRunningTime());
consumer.healthCheck();
awaitAndResetBarrier(taskDepartBarrier);
assertThat(initSuccess.get(), equalTo(false));
assertFalse(initSuccess.get());
verify(initializeTask).call();
initSuccess = consumer.initializeComplete();
verify(initializeTask).call();
assertThat(initSuccess.get(), equalTo(true));
assertTrue(initSuccess.get());
consumer.healthCheck();
assertThat(consumer.taskRunningTime(), nullValue());
assertNull(consumer.taskRunningTime());
consumer.subscribe();
cache.awaitInitialSetup();
@ -813,14 +789,14 @@ public class ShardConsumerTest {
awaitAndResetBarrier(taskArriveBarrier);
Instant previousTaskStartTime = consumer.taskDispatchedAt();
assertThat(consumer.taskRunningTime(), notNullValue());
assertNotNull(consumer.taskRunningTime());
consumer.healthCheck();
awaitAndResetBarrier(taskDepartBarrier);
consumer.healthCheck();
cache.requestBarrier.await();
assertThat(consumer.taskRunningTime(), nullValue());
assertNull(consumer.taskRunningTime());
cache.requestBarrier.reset();
// Sleep for 10 millis before processing next task. If we don't; then the following
@ -831,28 +807,28 @@ public class ShardConsumerTest {
awaitAndResetBarrier(taskArriveBarrier);
Instant currentTaskStartTime = consumer.taskDispatchedAt();
assertThat(currentTaskStartTime, not(equalTo(previousTaskStartTime)));
assertNotEquals(currentTaskStartTime, previousTaskStartTime);
awaitAndResetBarrier(taskDepartBarrier);
cache.requestBarrier.await();
assertThat(consumer.taskRunningTime(), nullValue());
assertNull(consumer.taskRunningTime());
cache.requestBarrier.reset();
consumer.leaseLost();
assertThat(consumer.isShutdownRequested(), equalTo(true));
assertTrue(consumer.isShutdownRequested());
CompletableFuture<Boolean> shutdownComplete = consumer.shutdownComplete();
awaitAndResetBarrier(taskArriveBarrier);
assertThat(consumer.taskRunningTime(), notNullValue());
assertNotNull(consumer.taskRunningTime());
awaitAndResetBarrier(taskDepartBarrier);
assertThat(shutdownComplete.get(), equalTo(false));
assertFalse(shutdownComplete.get());
shutdownComplete = consumer.shutdownComplete();
assertThat(shutdownComplete.get(), equalTo(true));
assertTrue(shutdownComplete.get());
assertThat(consumer.taskRunningTime(), nullValue());
assertNull(consumer.taskRunningTime());
consumer.healthCheck();
verify(taskExecutionListener, times(1)).beforeTaskExecution(initialTaskInput);
@ -918,7 +894,6 @@ public class ShardConsumerTest {
}
private void mockSuccessfulInitialize(CyclicBarrier taskCallBarrier, CyclicBarrier taskInterlockBarrier) {
when(initialState.createTask(eq(shardConsumerArgument), any(), any())).thenReturn(initializeTask);
when(initialState.taskType()).thenReturn(TaskType.INITIALIZE);
when(initializeTask.taskType()).thenReturn(TaskType.INITIALIZE);
@ -968,4 +943,18 @@ public class ShardConsumerTest {
return null;
}
private ShardConsumer createShardConsumer(final RecordsPublisher publisher) {
return createShardConsumer(publisher, executorService, initialState);
}
private ShardConsumer createShardConsumer(final ExecutorService executorService, final ConsumerState state) {
return createShardConsumer(recordsPublisher, executorService, state);
}
private ShardConsumer createShardConsumer(final RecordsPublisher publisher,
final ExecutorService executorService, final ConsumerState state) {
return new ShardConsumer(publisher, executorService, shardInfo, logWarningForTaskAfterMillis,
shardConsumerArgument, state, Function.identity(), 1, taskExecutionListener, 0);
}
}

View file

@ -5,14 +5,19 @@ import java.util.Optional;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNull;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static software.amazon.kinesis.common.InitialPositionInStream.LATEST;
import static software.amazon.kinesis.common.InitialPositionInStream.TRIM_HORIZON;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.mockito.runners.MockitoJUnitRunner;
import software.amazon.awssdk.services.kinesis.KinesisAsyncClient;
import software.amazon.kinesis.common.InitialPositionInStreamExtended;
import software.amazon.kinesis.common.StreamConfig;
@ -20,6 +25,7 @@ import software.amazon.kinesis.processor.MultiStreamTracker;
import software.amazon.kinesis.processor.SingleStreamTracker;
import software.amazon.kinesis.processor.StreamTracker;
@RunWith(MockitoJUnitRunner.class)
public class RetrievalConfigTest {
private static final String APPLICATION_NAME = RetrievalConfigTest.class.getSimpleName();
@ -27,9 +33,12 @@ public class RetrievalConfigTest {
@Mock
private KinesisAsyncClient mockKinesisClient;
@Mock
private MultiStreamTracker mockMultiStreamTracker;
@Before
public void setUp() {
MockitoAnnotations.initMocks(this);
when(mockMultiStreamTracker.isMultiStream()).thenReturn(true);
}
@Test
@ -69,11 +78,33 @@ public class RetrievalConfigTest {
@Test(expected = IllegalArgumentException.class)
public void testUpdateInitialPositionInMultiStream() {
final RetrievalConfig config = createConfig(mock(MultiStreamTracker.class));
config.initialPositionInStreamExtended(
createConfig(mockMultiStreamTracker).initialPositionInStreamExtended(
InitialPositionInStreamExtended.newInitialPosition(TRIM_HORIZON));
}
/**
* Test that an invalid {@link RetrievalSpecificConfig} does not overwrite
* a valid one.
*/
@Test
public void testInvalidRetrievalSpecificConfig() {
final RetrievalSpecificConfig validConfig = mock(RetrievalSpecificConfig.class);
final RetrievalSpecificConfig invalidConfig = mock(RetrievalSpecificConfig.class);
doThrow(new IllegalArgumentException("womp womp")).when(invalidConfig).validateState(true);
final RetrievalConfig config = createConfig(mockMultiStreamTracker);
assertNull(config.retrievalSpecificConfig());
config.retrievalSpecificConfig(validConfig);
assertEquals(validConfig, config.retrievalSpecificConfig());
try {
config.retrievalSpecificConfig(invalidConfig);
Assert.fail("should throw");
} catch (RuntimeException re) {
assertEquals(validConfig, config.retrievalSpecificConfig());
}
}
private RetrievalConfig createConfig(String streamName) {
return new RetrievalConfig(mockKinesisClient, streamName, APPLICATION_NAME);
}

View file

@ -15,16 +15,20 @@
package software.amazon.kinesis.retrieval.fanout;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.not;
import static org.hamcrest.CoreMatchers.nullValue;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.mockito.Mockito.anyString;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
@ -55,132 +59,150 @@ public class FanOutConfigTest {
@Mock
private StreamConfig streamConfig;
private FanOutConfig config;
@Before
public void setup() {
when(streamConfig.consumerArn()).thenReturn(null);
config = spy(new FanOutConfig(kinesisClient))
// DRY: set the most commonly-used parameters
.applicationName(TEST_APPLICATION_NAME)
.streamName(TEST_STREAM_NAME);
doReturn(consumerRegistration).when(config)
.createConsumerRegistration(eq(kinesisClient), anyString(), anyString());
}
@Test
public void testNoRegisterIfConsumerArnSet() throws Exception {
FanOutConfig config = new TestingConfig(kinesisClient).consumerArn(TEST_CONSUMER_ARN);
public void testNoRegisterIfConsumerArnSet() {
config.consumerArn(TEST_CONSUMER_ARN)
// unset common parameters
.applicationName(null).streamName(null);
RetrievalFactory retrievalFactory = config.retrievalFactory();
assertThat(retrievalFactory, not(nullValue()));
verify(consumerRegistration, never()).getOrCreateStreamConsumerArn();
assertNotNull(retrievalFactory);
verifyZeroInteractions(consumerRegistration);
}
@Test
public void testRegisterCalledWhenConsumerArnUnset() throws Exception {
FanOutConfig config = new TestingConfig(kinesisClient).applicationName(TEST_APPLICATION_NAME)
.streamName(TEST_STREAM_NAME);
RetrievalFactory retrievalFactory = config.retrievalFactory();
ShardInfo shardInfo = mock(ShardInfo.class);
// doReturn(Optional.of(StreamIdentifier.singleStreamInstance(TEST_STREAM_NAME).serialize())).when(shardInfo).streamIdentifier();
doReturn(Optional.empty()).when(shardInfo).streamIdentifierSerOpt();
retrievalFactory.createGetRecordsCache(shardInfo, streamConfig, mock(MetricsFactory.class));
assertThat(retrievalFactory, not(nullValue()));
getRecordsCache(null);
verify(consumerRegistration).getOrCreateStreamConsumerArn();
}
@Test
public void testRegisterNotCalledWhenConsumerArnSetInMultiStreamMode() throws Exception {
when(streamConfig.consumerArn()).thenReturn("consumerArn");
FanOutConfig config = new TestingConfig(kinesisClient).applicationName(TEST_APPLICATION_NAME)
.streamName(TEST_STREAM_NAME);
RetrievalFactory retrievalFactory = config.retrievalFactory();
ShardInfo shardInfo = mock(ShardInfo.class);
doReturn(Optional.of("account:stream:12345")).when(shardInfo).streamIdentifierSerOpt();
retrievalFactory.createGetRecordsCache(shardInfo, streamConfig, mock(MetricsFactory.class));
assertThat(retrievalFactory, not(nullValue()));
getRecordsCache("account:stream:12345");
verify(consumerRegistration, never()).getOrCreateStreamConsumerArn();
}
@Test
public void testRegisterCalledWhenConsumerArnNotSetInMultiStreamMode() throws Exception {
FanOutConfig config = new TestingConfig(kinesisClient).applicationName(TEST_APPLICATION_NAME)
.streamName(TEST_STREAM_NAME);
RetrievalFactory retrievalFactory = config.retrievalFactory();
ShardInfo shardInfo = mock(ShardInfo.class);
doReturn(Optional.of("account:stream:12345")).when(shardInfo).streamIdentifierSerOpt();
retrievalFactory.createGetRecordsCache(shardInfo, streamConfig, mock(MetricsFactory.class));
assertThat(retrievalFactory, not(nullValue()));
getRecordsCache("account:stream:12345");
verify(consumerRegistration).getOrCreateStreamConsumerArn();
}
@Test
public void testDependencyExceptionInConsumerCreation() throws Exception {
FanOutConfig config = new TestingConfig(kinesisClient).applicationName(TEST_APPLICATION_NAME)
.streamName(TEST_STREAM_NAME);
DependencyException de = new DependencyException("Bad", null);
when(consumerRegistration.getOrCreateStreamConsumerArn()).thenThrow(de);
try {
config.retrievalFactory();
getRecordsCache(null);
Assert.fail("should throw");
} catch (RuntimeException e) {
verify(consumerRegistration).getOrCreateStreamConsumerArn();
assertThat(e.getCause(), equalTo(de));
assertEquals(de, e.getCause());
}
}
@Test
public void testCreationWithApplicationName() throws Exception {
FanOutConfig config = new TestingConfig(kinesisClient).applicationName(TEST_APPLICATION_NAME)
.streamName(TEST_STREAM_NAME);
RetrievalFactory factory = config.retrievalFactory();
ShardInfo shardInfo = mock(ShardInfo.class);
doReturn(Optional.empty()).when(shardInfo).streamIdentifierSerOpt();
factory.createGetRecordsCache(shardInfo, streamConfig, mock(MetricsFactory.class));
assertThat(factory, not(nullValue()));
public void testCreationWithApplicationName() {
getRecordsCache(null);
TestingConfig testingConfig = (TestingConfig) config;
assertThat(testingConfig.stream, equalTo(TEST_STREAM_NAME));
assertThat(testingConfig.consumerToCreate, equalTo(TEST_APPLICATION_NAME));
assertEquals(TEST_STREAM_NAME, config.streamName());
assertEquals(TEST_APPLICATION_NAME, config.applicationName());
}
@Test
public void testCreationWithConsumerName() throws Exception {
FanOutConfig config = new TestingConfig(kinesisClient).consumerName(TEST_CONSUMER_NAME)
.streamName(TEST_STREAM_NAME);
RetrievalFactory factory = config.retrievalFactory();
ShardInfo shardInfo = mock(ShardInfo.class);
doReturn(Optional.empty()).when(shardInfo).streamIdentifierSerOpt();
factory.createGetRecordsCache(shardInfo, streamConfig, mock(MetricsFactory.class));
assertThat(factory, not(nullValue()));
TestingConfig testingConfig = (TestingConfig) config;
assertThat(testingConfig.stream, equalTo(TEST_STREAM_NAME));
assertThat(testingConfig.consumerToCreate, equalTo(TEST_CONSUMER_NAME));
public void testCreationWithConsumerName() {
config.consumerName(TEST_CONSUMER_NAME)
// unset common parameters
.applicationName(null);
getRecordsCache(null);
assertEquals(TEST_STREAM_NAME, config.streamName());
assertEquals(TEST_CONSUMER_NAME, config.consumerName());
}
@Test
public void testCreationWithBothConsumerApplication() throws Exception {
FanOutConfig config = new TestingConfig(kinesisClient).applicationName(TEST_APPLICATION_NAME)
.consumerName(TEST_CONSUMER_NAME).streamName(TEST_STREAM_NAME);
RetrievalFactory factory = config.retrievalFactory();
ShardInfo shardInfo = mock(ShardInfo.class);
doReturn(Optional.empty()).when(shardInfo).streamIdentifierSerOpt();
factory.createGetRecordsCache(shardInfo, streamConfig, mock(MetricsFactory.class));
assertThat(factory, not(nullValue()));
public void testCreationWithBothConsumerApplication() {
config = config.consumerName(TEST_CONSUMER_NAME);
TestingConfig testingConfig = (TestingConfig) config;
assertThat(testingConfig.stream, equalTo(TEST_STREAM_NAME));
assertThat(testingConfig.consumerToCreate, equalTo(TEST_CONSUMER_NAME));
getRecordsCache(null);
assertEquals(TEST_STREAM_NAME, config.streamName());
assertEquals(TEST_CONSUMER_NAME, config.consumerName());
}
private class TestingConfig extends FanOutConfig {
@Test
public void testValidState() {
assertNull(config.consumerArn());
assertNotNull(config.streamName());
String stream;
String consumerToCreate;
config.validateState(false);
public TestingConfig(KinesisAsyncClient kinesisClient) {
super(kinesisClient);
// both streamName and consumerArn are non-null
config.consumerArn(TEST_CONSUMER_ARN);
config.validateState(false);
config.consumerArn(null);
config.streamName(null);
config.validateState(false);
config.validateState(true);
assertNull(config.streamName());
assertNull(config.consumerArn());
}
@Test(expected = IllegalArgumentException.class)
public void testInvalidStateMultiWithStreamName() {
testInvalidState(TEST_STREAM_NAME, null);
}
@Test(expected = IllegalArgumentException.class)
public void testInvalidStateMultiWithConsumerArn() {
testInvalidState(null, TEST_CONSUMER_ARN);
}
@Test(expected = IllegalArgumentException.class)
public void testInvalidStateMultiWithStreamNameAndConsumerArn() {
testInvalidState(TEST_STREAM_NAME, TEST_CONSUMER_ARN);
}
private void testInvalidState(final String streamName, final String consumerArn) {
config.streamName(streamName);
config.consumerArn(consumerArn);
try {
config.validateState(true);
} finally {
assertEquals(streamName, config.streamName());
assertEquals(consumerArn, config.consumerArn());
}
}
@Override
protected FanOutConsumerRegistration createConsumerRegistration(KinesisAsyncClient client, String stream,
String consumerToCreate) {
this.stream = stream;
this.consumerToCreate = consumerToCreate;
return consumerRegistration;
}
private void getRecordsCache(final String streamIdentifer) {
final ShardInfo shardInfo = mock(ShardInfo.class);
when(shardInfo.streamIdentifierSerOpt()).thenReturn(Optional.ofNullable(streamIdentifer));
final RetrievalFactory factory = config.retrievalFactory();
factory.createGetRecordsCache(shardInfo, streamConfig, mock(MetricsFactory.class));
}
}

View file

@ -0,0 +1,47 @@
package software.amazon.kinesis.retrieval.polling;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.runners.MockitoJUnitRunner;
import software.amazon.awssdk.services.kinesis.KinesisAsyncClient;
@RunWith(MockitoJUnitRunner.class)
public class PollingConfigTest {
private static final String STREAM_NAME = PollingConfigTest.class.getSimpleName();
@Mock
private KinesisAsyncClient mockKinesisClinet;
private PollingConfig config;
@Before
public void setUp() {
config = new PollingConfig(mockKinesisClinet);
}
@Test
public void testValidState() {
assertNull(config.streamName());
config.validateState(true);
config.validateState(false);
config.streamName(STREAM_NAME);
config.validateState(false);
assertEquals(STREAM_NAME, config.streamName());
}
@Test(expected = IllegalArgumentException.class)
public void testInvalidStateMultiWithStreamName() {
config.streamName(STREAM_NAME);
config.validateState(true);
}
}