diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ShardConsumer.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ShardConsumer.java index 21a9bdc3..5545fc03 100644 --- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ShardConsumer.java +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ShardConsumer.java @@ -20,6 +20,7 @@ import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; +import java.util.concurrent.RejectedExecutionException; import java.util.function.Function; import org.reactivestreams.Subscription; @@ -164,6 +165,9 @@ public class ShardConsumer { } else if (needsInitialization) { if (stateChangeFuture != null) { if (stateChangeFuture.get()) { + // Task rejection during the subscribe() call will not be propagated back as it not executed + // in the context of the Scheduler thread. Hence we should not assume the subscription will + // always be successful. subscribe(); needsInitialization = false; } @@ -177,6 +181,11 @@ public class ShardConsumer { // } catch (ExecutionException e) { throw new RuntimeException(e); + } catch (RejectedExecutionException e) { + // It is possible the tasks submitted to the executor service by the Scheduler thread might get rejected + // due to various reasons. Such failed executions must be captured and marked as failure to prevent + // the state transitions. + taskOutcome = TaskOutcome.FAILURE; } if (ConsumerStates.ShardConsumerState.PROCESSING.equals(currentState.state())) { diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ShardConsumerTest.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ShardConsumerTest.java index 7568cbde..320512e6 100644 --- a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ShardConsumerTest.java +++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ShardConsumerTest.java @@ -20,7 +20,9 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.notNullValue; import static org.hamcrest.Matchers.nullValue; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThat; +import static org.junit.Assert.fail; import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyLong; import static org.mockito.Matchers.eq; @@ -28,6 +30,7 @@ import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; @@ -58,6 +61,7 @@ import org.junit.Test; import org.junit.rules.TestName; import org.junit.runner.RunWith; import org.mockito.Mock; +import org.mockito.invocation.InvocationOnMock; import org.mockito.runners.MockitoJUnitRunner; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; @@ -69,6 +73,7 @@ import software.amazon.kinesis.common.InitialPositionInStreamExtended; import software.amazon.kinesis.leases.ShardInfo; import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput; import software.amazon.kinesis.lifecycle.events.TaskExecutionListenerInput; +import software.amazon.kinesis.lifecycle.ConsumerStates.ShardConsumerState; import software.amazon.kinesis.retrieval.RecordsPublisher; import software.amazon.kinesis.retrieval.RecordsRetrieved; import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber; @@ -95,6 +100,10 @@ public class ShardConsumerTest { @Mock private ShutdownNotification shutdownNotification; @Mock + private ConsumerState blockedOnParentsState; + @Mock + private ConsumerTask blockedOnParentsTask; + @Mock private ConsumerState initialState; @Mock private ConsumerTask initializeTask; @@ -111,6 +120,8 @@ public class ShardConsumerTest { @Mock private TaskResult processingTaskResult; @Mock + private TaskResult blockOnParentsTaskResult; + @Mock private ConsumerState shutdownCompleteState; @Mock private ShardConsumerArgument shardConsumerArgument; @@ -441,6 +452,144 @@ public class ShardConsumerTest { verify(initialState, never()).shutdownTransition(any()); } + /** + * Test method to verify consumer undergoes the transition WAITING_ON_PARENT_SHARDS -> INITIALIZING -> PROCESSING + */ + @SuppressWarnings("unchecked") + @Test + public final void testSuccessfulConsumerStateTransition() throws Exception { + ExecutorService directExecutorService = spy(executorService); + + doAnswer(invocation -> directlyExecuteRunnable(invocation)) + .when(directExecutorService).execute(any()); + + ShardConsumer consumer = new ShardConsumer(recordsPublisher, directExecutorService, shardInfo, + logWarningForTaskAfterMillis, shardConsumerArgument, blockedOnParentsState, + t -> t, 1, taskExecutionListener, 0); + + mockSuccessfulUnblockOnParents(); + mockSuccessfulInitializeWithFailureTransition(); + mockSuccessfulProcessing(null); + + int arbitraryExecutionCount = 3; + do { + try { + consumer.executeLifecycle(); + } catch (Exception e) { + // Suppress any exception like the scheduler. + fail("Unexpected exception while executing consumer lifecycle"); + } + } while (--arbitraryExecutionCount > 0); + + assertEquals(ShardConsumerState.PROCESSING.consumerState().state(), consumer.currentState().state()); + verify(directExecutorService, times(2)).execute(any()); + } + + /** + * Test method to verify consumer does not transition to PROCESSING from WAITING_ON_PARENT_SHARDS when + * INITIALIZING tasks gets rejected. + */ + @SuppressWarnings("unchecked") + @Test + public final void testConsumerNotTransitionsToProcessingWhenInitializationFails() { + ExecutorService failingService = spy(executorService); + ShardConsumer consumer = new ShardConsumer(recordsPublisher, failingService, shardInfo, + logWarningForTaskAfterMillis, shardConsumerArgument, blockedOnParentsState, + t -> t, 1, taskExecutionListener, 0); + + mockSuccessfulUnblockOnParents(); + mockSuccessfulInitializeWithFailureTransition(); + mockSuccessfulProcessing(null); + + // Failing the initialization task and all other attempts after that. + doAnswer(invocation -> directlyExecuteRunnable(invocation)) + .doThrow(new RejectedExecutionException()) + .when(failingService).execute(any()); + + int arbitraryExecutionCount = 5; + do { + try { + consumer.executeLifecycle(); + } catch (Exception e) { + // Suppress any exception like the scheduler. + fail("Unexpected exception while executing consumer lifecycle"); + } + } while (--arbitraryExecutionCount > 0); + + assertEquals(ShardConsumerState.INITIALIZING.consumerState().state(), consumer.currentState().state()); + verify(failingService, times(5)).execute(any()); + } + + /** + * Test method to verify consumer transition to PROCESSING from WAITING_ON_PARENT_SHARDS with + * intermittent INITIALIZING task rejections. + */ + @SuppressWarnings("unchecked") + @Test + public final void testConsumerTransitionsToProcessingWithIntermittentInitializationFailures() { + ExecutorService failingService = spy(executorService); + ShardConsumer consumer = new ShardConsumer(recordsPublisher, failingService, shardInfo, + logWarningForTaskAfterMillis, shardConsumerArgument, blockedOnParentsState, + t -> t, 1, taskExecutionListener, 0); + + mockSuccessfulUnblockOnParents(); + mockSuccessfulInitializeWithFailureTransition(); + mockSuccessfulProcessing(null); + + // Failing the initialization task and few other attempts after that. + doAnswer(invocation -> directlyExecuteRunnable(invocation)) + .doThrow(new RejectedExecutionException()) + .doThrow(new RejectedExecutionException()) + .doThrow(new RejectedExecutionException()) + .doAnswer(invocation -> directlyExecuteRunnable(invocation)) + .when(failingService).execute(any()); + + int arbitraryExecutionCount = 6; + do { + try { + consumer.executeLifecycle(); + } catch (Exception e) { + // Suppress any exception like the scheduler. + fail("Unexpected exception while executing consumer lifecycle"); + } + } while (--arbitraryExecutionCount > 0); + + assertEquals(ShardConsumerState.PROCESSING.consumerState().state(), consumer.currentState().state()); + verify(failingService, times(5)).execute(any()); + } + + /** + * Test method to verify consumer does not transition to INITIALIZING when WAITING_ON_PARENT_SHARDS task rejected. + */ + @SuppressWarnings("unchecked") + @Test + public final void testConsumerNotTransitionsToInitializingWhenWaitingOnParentsFails() { + ExecutorService failingService = spy(executorService); + ShardConsumer consumer = new ShardConsumer(recordsPublisher, failingService, shardInfo, + logWarningForTaskAfterMillis, shardConsumerArgument, blockedOnParentsState, + t -> t, 1, taskExecutionListener, 0); + + mockSuccessfulUnblockOnParentsWithFailureTransition(); + mockSuccessfulInitializeWithFailureTransition(); + + // Failing the waiting_on_parents task and few other attempts after that. + doThrow(new RejectedExecutionException()) + .when(failingService).execute(any()); + + int arbitraryExecutionCount = 5; + do { + try { + consumer.executeLifecycle(); + } catch (Exception e) { + // Suppress any exception like the scheduler. + fail("Unexpected exception while executing consumer lifecycle"); + } + } while (--arbitraryExecutionCount > 0); + + assertEquals(ShardConsumerState.WAITING_ON_PARENT_SHARDS.consumerState().state(), consumer.currentState().state()); + verify(failingService, times(5)).execute(any()); + } + /** * Test method to verify consumer stays in INITIALIZING state when InitializationTask fails. */ @@ -742,6 +891,11 @@ public class ShardConsumerTest { when(processingState.state()).thenReturn(ConsumerStates.ShardConsumerState.PROCESSING); } + private void mockSuccessfulInitializeWithFailureTransition() { + mockSuccessfulInitialize(null, null); + when(initialState.failureTransition()).thenReturn(initialState); + } + private void mockSuccessfulInitialize(CyclicBarrier taskCallBarrier) { mockSuccessfulInitialize(taskCallBarrier, null); } @@ -763,6 +917,22 @@ public class ShardConsumerTest { } + private void mockSuccessfulUnblockOnParentsWithFailureTransition() { + mockSuccessfulUnblockOnParents(); + when(blockedOnParentsState.failureTransition()).thenReturn(blockedOnParentsState); + } + + private void mockSuccessfulUnblockOnParents() { + when(blockedOnParentsState.createTask(eq(shardConsumerArgument), any(), any())).thenReturn(blockedOnParentsTask); + when(blockedOnParentsState.taskType()).thenReturn(TaskType.BLOCK_ON_PARENT_SHARDS); + when(blockedOnParentsTask.taskType()).thenReturn(TaskType.BLOCK_ON_PARENT_SHARDS); + when(blockedOnParentsTask.call()).thenAnswer(i -> blockOnParentsTaskResult); + when(blockOnParentsTaskResult.getException()).thenReturn(null); + when(blockedOnParentsState.requiresDataAvailability()).thenReturn(false); + when(blockedOnParentsState.successTransition()).thenReturn(initialState); + when(blockedOnParentsState.state()).thenReturn(ShardConsumerState.WAITING_ON_PARENT_SHARDS); + } + private void awaitBarrier(CyclicBarrier barrier) throws Exception { if (barrier != null) { barrier.await(); @@ -773,4 +943,12 @@ public class ShardConsumerTest { barrier.await(); barrier.reset(); } + + private Object directlyExecuteRunnable(InvocationOnMock invocation) { + Object[] args = invocation.getArguments(); + Runnable runnable = (Runnable) args[0]; + runnable.run(); + return null; + } + }