Fix a race condition between ShardConsumer shutdown and initialization (#1319)

* Fix a race condition between ShardConsumer shutdown and initialization

When Kinesis shards have no data, there can be a race condition where
the shard-end record processing from RecordProcessorThread
interleaves with Scheduler performing initialization.
This leads to ShardConsumer making incorrect state transition
during initialization (moves from PROCESSING -> SHUTTING_DOWN) state
and during shutdown handling it moves from SHUTTING_DOWN -> SHUTDOWN_COMPLETE
without running the ShutdownTask.

This can cause the ShardConsumer to not perform proper shutdown
processing that is required for a child shard processing
to be unblocked. So the child shard could be blocked forever unless the
lease for the parent shard moves to a new worker and that worker does
not run into the race condition.

This patch fixes the race condition as follows:

The intializationComplete invocation is not needed after
needsInitialization has been set to false. Because initializationComplete
is mean to perform initialization in an async manner, but once
its done, the async task is a no-op in happy-path, but it can
perform incorrect state transition during a race condition.
This commit is contained in:
Aravinda Kidambi Srinivasan 2024-05-02 14:54:59 -07:00 committed by GitHub
parent 69cf5996c5
commit 16e8404dc4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 129 additions and 0 deletions

View file

@ -179,6 +179,10 @@ public class ShardConsumer {
// Task rejection during the subscribe() call will not be propagated back as it not executed
// in the context of the Scheduler thread. Hence we should not assume the subscription will
// always be successful.
// But if subscription was not successful, then it will recover
// during healthCheck which will restart subscription.
// From Shardconsumer point of view, initialization after the below subscribe call
// is complete
subscribe();
needsInitialization = false;
}
@ -276,6 +280,16 @@ public class ShardConsumer {
@VisibleForTesting
synchronized CompletableFuture<Boolean> initializeComplete() {
if (!needsInitialization) {
// initialization already complete, this must be a no-op.
// ShardConsumer must be in ProcessingState and
// any further activity will be driven by publisher pushing data to subscriber
// which invokes handleInput and that triggers ProcessTask.
// Scheduler is only meant to do health-checks to ensure the consumer
// is not stuck for any reason and to do shutdown handling.
return CompletableFuture.completedFuture(true);
}
if (taskOutcome != null) {
updateState(taskOutcome);
}

View file

@ -32,7 +32,9 @@ import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
@ -45,6 +47,7 @@ import java.util.List;
import java.util.Optional;
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
@ -53,6 +56,7 @@ import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
import org.junit.After;
@ -62,7 +66,9 @@ import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TestName;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.runners.MockitoJUnitRunner;
import org.reactivestreams.Subscriber;
@ -148,6 +154,7 @@ public class ShardConsumerTest {
@Before
public void before() {
MockitoAnnotations.initMocks(this);
shardInfo = new ShardInfo(shardId, concurrencyToken, null, ExtendedSequenceNumber.TRIM_HORIZON);
ThreadFactory factory = new ThreadFactoryBuilder().setNameFormat("test-" + testName.getMethodName() + "-%04d")
.setDaemon(true).build();
@ -848,6 +855,114 @@ public class ShardConsumerTest {
verifyNoMoreInteractions(taskExecutionListener);
}
@Test
public void testEmptyShardProcessingRaceCondition() throws Exception {
final RecordsPublisher mockPublisher = mock(RecordsPublisher.class);
final ExecutorService mockExecutor = mock(ExecutorService.class);
final ConsumerState mockState = mock(ConsumerState.class);
final ShardConsumer consumer = new ShardConsumer(mockPublisher, mockExecutor, shardInfo, Optional.of(1L),
shardConsumerArgument, mockState, Function.identity(), 1, taskExecutionListener, 0);
when(mockState.state()).thenReturn(ShardConsumerState.WAITING_ON_PARENT_SHARDS);
when(mockState.taskType()).thenReturn(TaskType.BLOCK_ON_PARENT_SHARDS);
final ConsumerTask mockTask = mock(ConsumerTask.class);
when(mockState.createTask(any(), any(), any())).thenReturn(mockTask);
// Simulate successful BlockedOnParent task execution
// and successful Initialize task execution
when(mockTask.call()).thenReturn(new TaskResult(false));
log.info("Scheduler Thread: Invoking ShardConsumer.executeLifecycle() to initiate async" +
" processing of blocked on parent task");
consumer.executeLifecycle();
final ArgumentCaptor<Runnable> taskToExecute = ArgumentCaptor.forClass(Runnable.class);
verify(mockExecutor, timeout(100)).execute(taskToExecute.capture());
taskToExecute.getValue().run();
log.info("RecordProcessor Thread: Simulated successful execution of Blocked on parent task");
reset(mockExecutor);
log.info("Scheduler Thread: Invoking ShardConsumer.executeLifecycle() to move to InitializingState" +
" and initiate async processing of initialize task");
when(mockState.successTransition()).thenReturn(mockState);
when(mockState.state()).thenReturn(ShardConsumerState.INITIALIZING);
when(mockState.taskType()).thenReturn(TaskType.INITIALIZE);
consumer.executeLifecycle();
verify(mockExecutor, timeout(100)).execute(taskToExecute.capture());
log.info("RecordProcessor Thread: Simulated successful execution of Initialize task");
taskToExecute.getValue().run();
log.info("Scheduler Thread: Invoking ShardConsumer.executeLifecycle() to move to ProcessingState" +
" and mark initialization future as complete");
when(mockState.state()).thenReturn(ShardConsumerState.PROCESSING);
consumer.executeLifecycle();
// Simulate the race where
// scheduler invokes executeLifecycle which performs Publisher.subscribe(subscriber)
// on recordProcessor thread
// but before scheduler thread finishes initialization, handleInput is invoked
// on record processor thread.
// Since ShardConsumer creates its own instance of subscriber that cannot be mocked
// this test sequence will appear a little odd.
// In order to control the order in which execution occurs, lets first invoke
// handleInput, although this will never happen, since there isn't a way
// to control the precise timing of the thread execution, this is the best way
final CountDownLatch processTaskLatch = new CountDownLatch(1);
new Thread(() -> {
reset(mockState);
when(mockState.taskType()).thenReturn(TaskType.PROCESS);
final ConsumerTask mockProcessTask = mock(ConsumerTask.class);
when(mockState.createTask(any(), any(), any())).thenReturn(mockProcessTask);
when(mockProcessTask.call()).then(input -> {
// first we want to wait for subscribe to be called,
// but we cannot control the timing, so wait for 10 seconds
// to let the main thread invoke executeLifecyle which
// will perform subscribe
processTaskLatch.countDown();
log.info("Record Processor Thread: Holding shardConsumer lock, waiting for 10 seconds to" +
" let subscribe be called by scheduler thread");
Thread.sleep(10 * 1000);
log.info("RecordProcessor Thread: Done waiting");
// then return shard end result
log.info("RecordProcessor Thread: Simulating execution of ProcessTask and returning shard-end result");
return new TaskResult(true);
});
final Subscription mockSubscription = mock(Subscription.class);
consumer.handleInput(ProcessRecordsInput.builder().isAtShardEnd(true).build(), mockSubscription);
}).start();
processTaskLatch.await();
// invoke executeLifecycle, which should invoke subscribe
// meanwhile if scheduler tries to acquire the ShardConsumer lock it will
// be blocked during initialization processing because handleInput was
// already invoked and will be holding the lock. Thereby creating the
// race condition we want.
reset(mockState);
AtomicBoolean successTransitionCalled = new AtomicBoolean(false);
when(mockState.successTransition()).then(input -> {
successTransitionCalled.set(true);
return mockState;
});
AtomicBoolean shutdownTransitionCalled = new AtomicBoolean(false);
when(mockState.shutdownTransition(any())).then(input -> {
shutdownTransitionCalled.set(true);
return mockState;
});
when(mockState.state()).then(input -> {
if (successTransitionCalled.get() && shutdownTransitionCalled.get()) {
return ShardConsumerState.SHUTTING_DOWN;
}
return ShardConsumerState.PROCESSING;
});
log.info("Scheduler Thread: Invoking ShardConsumer.executeLifecycle() to invoke subscribe and" +
" complete initialization");
consumer.executeLifecycle();
log.info("Scheduler Thread: Done initializing the ShardConsumer");
log.info("Verifying scheduler did not perform shutdown transition during initialization");
verify(mockState, times(0)).shutdownTransition(any());
}
private void mockSuccessfulShutdown(CyclicBarrier taskCallBarrier) {
mockSuccessfulShutdown(taskCallBarrier, null);
}