Fix for invalid ShardConsumer state transitions due to rejected executions (#560)

* Fix to prevent ShardConsumer state transition, when the source state task execution is rejected by the executor service.

* Unit test case improvements

* Optimized imports

* Removed unnecessary sleep in unit test case

* Fixing imports

* Fixing import again with wildcard removed

* Adding asserts to exception cases in ShardConsumerTest
This commit is contained in:
ashwing 2019-07-08 16:30:27 -07:00 committed by Sahil Palvia
parent b6236d8077
commit 9e2d6fa497
2 changed files with 187 additions and 0 deletions

View file

@ -20,6 +20,7 @@ import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.RejectedExecutionException;
import java.util.function.Function;
import org.reactivestreams.Subscription;
@ -164,6 +165,9 @@ public class ShardConsumer {
} else if (needsInitialization) {
if (stateChangeFuture != null) {
if (stateChangeFuture.get()) {
// Task rejection during the subscribe() call will not be propagated back as it is not executed
// in the context of the Scheduler thread. Hence we should not assume the subscription will
// always be successful.
subscribe();
needsInitialization = false;
}
@ -177,6 +181,11 @@ public class ShardConsumer {
//
} catch (ExecutionException e) {
throw new RuntimeException(e);
} catch (RejectedExecutionException e) {
// It is possible the tasks submitted to the executor service by the Scheduler thread might get rejected
// due to various reasons. Such failed executions must be captured and marked as failure to prevent
// the state transitions.
taskOutcome = TaskOutcome.FAILURE;
}
if (ConsumerStates.ShardConsumerState.PROCESSING.equals(currentState.state())) {

View file

@ -20,7 +20,9 @@ import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.nullValue;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.fail;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyLong;
import static org.mockito.Matchers.eq;
@ -28,6 +30,7 @@ import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
@ -58,6 +61,7 @@ import org.junit.Test;
import org.junit.rules.TestName;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.runners.MockitoJUnitRunner;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
@ -69,6 +73,7 @@ import software.amazon.kinesis.common.InitialPositionInStreamExtended;
import software.amazon.kinesis.leases.ShardInfo;
import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput;
import software.amazon.kinesis.lifecycle.events.TaskExecutionListenerInput;
import software.amazon.kinesis.lifecycle.ConsumerStates.ShardConsumerState;
import software.amazon.kinesis.retrieval.RecordsPublisher;
import software.amazon.kinesis.retrieval.RecordsRetrieved;
import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber;
@ -95,6 +100,10 @@ public class ShardConsumerTest {
@Mock
private ShutdownNotification shutdownNotification;
@Mock
private ConsumerState blockedOnParentsState;
@Mock
private ConsumerTask blockedOnParentsTask;
@Mock
private ConsumerState initialState;
@Mock
private ConsumerTask initializeTask;
@ -111,6 +120,8 @@ public class ShardConsumerTest {
@Mock
private TaskResult processingTaskResult;
@Mock
private TaskResult blockOnParentsTaskResult;
@Mock
private ConsumerState shutdownCompleteState;
@Mock
private ShardConsumerArgument shardConsumerArgument;
@ -441,6 +452,144 @@ public class ShardConsumerTest {
verify(initialState, never()).shutdownTransition(any());
}
/**
 * Verifies that the consumer undergoes the transition
 * WAITING_ON_PARENT_SHARDS -> INITIALIZING -> PROCESSING when every submitted task is executed.
 */
@SuppressWarnings("unchecked")
@Test
public final void testSuccessfulConsumerStateTransition() throws Exception {
    // Spy on the executor so that every submitted runnable is run directly by the stubbed answer.
    ExecutorService directExecutorService = spy(executorService);
    doAnswer(invocation -> directlyExecuteRunnable(invocation))
            .when(directExecutorService).execute(any());

    ShardConsumer consumer = new ShardConsumer(recordsPublisher, directExecutorService, shardInfo,
            logWarningForTaskAfterMillis, shardConsumerArgument, blockedOnParentsState,
            t -> t, 1, taskExecutionListener, 0);

    mockSuccessfulUnblockOnParents();
    mockSuccessfulInitializeWithFailureTransition();
    mockSuccessfulProcessing(null);

    final int lifecycleExecutions = 3;
    for (int attempt = 0; attempt < lifecycleExecutions; attempt++) {
        try {
            consumer.executeLifecycle();
        } catch (Exception e) {
            // No exception is expected here; fail the test but keep the original cause for diagnosis.
            throw new AssertionError("Unexpected exception while executing consumer lifecycle", e);
        }
    }

    assertEquals(ShardConsumerState.PROCESSING.consumerState().state(), consumer.currentState().state());
    verify(directExecutorService, times(2)).execute(any());
}
/**
 * Verifies that the consumer, starting from WAITING_ON_PARENT_SHARDS, does not transition to PROCESSING
 * when the INITIALIZING task is rejected by the executor service.
 */
@SuppressWarnings("unchecked")
@Test
public final void testConsumerNotTransitionsToProcessingWhenInitializationFails() {
    ExecutorService failingService = spy(executorService);
    ShardConsumer consumer = new ShardConsumer(recordsPublisher, failingService, shardInfo,
            logWarningForTaskAfterMillis, shardConsumerArgument, blockedOnParentsState,
            t -> t, 1, taskExecutionListener, 0);

    mockSuccessfulUnblockOnParents();
    mockSuccessfulInitializeWithFailureTransition();
    mockSuccessfulProcessing(null);

    // The first submission is run directly; the initialization task and every later submission are rejected.
    doAnswer(invocation -> directlyExecuteRunnable(invocation))
            .doThrow(new RejectedExecutionException())
            .when(failingService).execute(any());

    final int lifecycleExecutions = 5;
    for (int attempt = 0; attempt < lifecycleExecutions; attempt++) {
        try {
            consumer.executeLifecycle();
        } catch (Exception e) {
            // No exception is expected here; fail the test but keep the original cause for diagnosis.
            throw new AssertionError("Unexpected exception while executing consumer lifecycle", e);
        }
    }

    assertEquals(ShardConsumerState.INITIALIZING.consumerState().state(), consumer.currentState().state());
    verify(failingService, times(5)).execute(any());
}
/**
 * Verifies that the consumer, starting from WAITING_ON_PARENT_SHARDS, still reaches PROCESSING when the
 * INITIALIZING task is rejected intermittently by the executor service.
 */
@SuppressWarnings("unchecked")
@Test
public final void testConsumerTransitionsToProcessingWithIntermittentInitializationFailures() {
    ExecutorService failingService = spy(executorService);
    ShardConsumer consumer = new ShardConsumer(recordsPublisher, failingService, shardInfo,
            logWarningForTaskAfterMillis, shardConsumerArgument, blockedOnParentsState,
            t -> t, 1, taskExecutionListener, 0);

    mockSuccessfulUnblockOnParents();
    mockSuccessfulInitializeWithFailureTransition();
    mockSuccessfulProcessing(null);

    // The first submission is run directly, the next three are rejected, and later ones are run directly again.
    doAnswer(invocation -> directlyExecuteRunnable(invocation))
            .doThrow(new RejectedExecutionException())
            .doThrow(new RejectedExecutionException())
            .doThrow(new RejectedExecutionException())
            .doAnswer(invocation -> directlyExecuteRunnable(invocation))
            .when(failingService).execute(any());

    final int lifecycleExecutions = 6;
    for (int attempt = 0; attempt < lifecycleExecutions; attempt++) {
        try {
            consumer.executeLifecycle();
        } catch (Exception e) {
            // No exception is expected here; fail the test but keep the original cause for diagnosis.
            throw new AssertionError("Unexpected exception while executing consumer lifecycle", e);
        }
    }

    assertEquals(ShardConsumerState.PROCESSING.consumerState().state(), consumer.currentState().state());
    verify(failingService, times(5)).execute(any());
}
/**
 * Verifies that the consumer does not transition to INITIALIZING when the WAITING_ON_PARENT_SHARDS task
 * is rejected by the executor service.
 */
@SuppressWarnings("unchecked")
@Test
public final void testConsumerNotTransitionsToInitializingWhenWaitingOnParentsFails() {
    ExecutorService failingService = spy(executorService);
    ShardConsumer consumer = new ShardConsumer(recordsPublisher, failingService, shardInfo,
            logWarningForTaskAfterMillis, shardConsumerArgument, blockedOnParentsState,
            t -> t, 1, taskExecutionListener, 0);

    mockSuccessfulUnblockOnParentsWithFailureTransition();
    mockSuccessfulInitializeWithFailureTransition();

    // Every submission is rejected, beginning with the waiting-on-parents task.
    doThrow(new RejectedExecutionException())
            .when(failingService).execute(any());

    final int lifecycleExecutions = 5;
    for (int attempt = 0; attempt < lifecycleExecutions; attempt++) {
        try {
            consumer.executeLifecycle();
        } catch (Exception e) {
            // No exception is expected here; fail the test but keep the original cause for diagnosis.
            throw new AssertionError("Unexpected exception while executing consumer lifecycle", e);
        }
    }

    assertEquals(ShardConsumerState.WAITING_ON_PARENT_SHARDS.consumerState().state(),
            consumer.currentState().state());
    verify(failingService, times(5)).execute(any());
}
/**
* Test method to verify consumer stays in INITIALIZING state when InitializationTask fails.
*/
@ -742,6 +891,11 @@ public class ShardConsumerTest {
when(processingState.state()).thenReturn(ConsumerStates.ShardConsumerState.PROCESSING);
}
/**
 * Delegates to {@code mockSuccessfulInitialize(null, null)} (overload defined outside this view) and
 * additionally stubs {@code initialState.failureTransition()} to return {@code initialState} itself.
 */
private void mockSuccessfulInitializeWithFailureTransition() {
mockSuccessfulInitialize(null, null);
// A failure transition hands back the same state object.
when(initialState.failureTransition()).thenReturn(initialState);
}
/**
 * Convenience overload that forwards to the two-argument {@code mockSuccessfulInitialize}, passing
 * {@code null} as the second argument.
 *
 * @param taskCallBarrier barrier handed through unchanged to the two-argument overload
 */
private void mockSuccessfulInitialize(CyclicBarrier taskCallBarrier) {
mockSuccessfulInitialize(taskCallBarrier, null);
}
@ -763,6 +917,22 @@ public class ShardConsumerTest {
}
/**
 * Applies the stubs from {@code mockSuccessfulUnblockOnParents()} and additionally stubs
 * {@code blockedOnParentsState.failureTransition()} to return {@code blockedOnParentsState} itself.
 */
private void mockSuccessfulUnblockOnParentsWithFailureTransition() {
mockSuccessfulUnblockOnParents();
// A failure transition hands back the same state object.
when(blockedOnParentsState.failureTransition()).thenReturn(blockedOnParentsState);
}
/**
 * Stubs the blocked-on-parents state, its task and the task result: the task yields a result without an
 * exception, and the state's success transition returns {@code initialState}.
 */
private void mockSuccessfulUnblockOnParents() {
    // Task result: reports no exception.
    when(blockOnParentsTaskResult.getException()).thenReturn(null);

    // Task: reports its type and yields the stubbed result when called.
    when(blockedOnParentsTask.taskType()).thenReturn(TaskType.BLOCK_ON_PARENT_SHARDS);
    when(blockedOnParentsTask.call()).thenAnswer(invocation -> blockOnParentsTaskResult);

    // State: identity, task creation and the transition taken on success.
    when(blockedOnParentsState.state()).thenReturn(ShardConsumerState.WAITING_ON_PARENT_SHARDS);
    when(blockedOnParentsState.taskType()).thenReturn(TaskType.BLOCK_ON_PARENT_SHARDS);
    when(blockedOnParentsState.requiresDataAvailability()).thenReturn(false);
    when(blockedOnParentsState.createTask(eq(shardConsumerArgument), any(), any()))
            .thenReturn(blockedOnParentsTask);
    when(blockedOnParentsState.successTransition()).thenReturn(initialState);
}
private void awaitBarrier(CyclicBarrier barrier) throws Exception {
if (barrier != null) {
barrier.await();
@ -773,4 +943,12 @@ public class ShardConsumerTest {
barrier.await();
barrier.reset();
}
/**
 * Runs the {@link Runnable} passed as the first argument of the intercepted invocation by calling
 * {@code run()} on it directly.
 *
 * @param invocation the intercepted invocation; its first argument is cast to {@link Runnable}
 * @return always {@code null}
 */
private Object directlyExecuteRunnable(InvocationOnMock invocation) {
    final Runnable submittedTask = (Runnable) invocation.getArguments()[0];
    submittedTask.run();
    return null;
}
}