Let healthchecks happen after initialization is complete

Also add a unit test to test the changes
This commit is contained in:
Aravinda Kidambi Srinivasan 2024-04-29 21:32:26 -07:00
parent bde5ae9dac
commit 940f93bdeb
2 changed files with 115 additions and 0 deletions

View file

@ -122,6 +122,13 @@
</dependency>
<!-- Test -->
<dependency>
<groupId>org.awaitility</groupId>
<artifactId>awaitility</artifactId>
<version>3.0.0</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>

View file

@ -32,7 +32,9 @@ import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
@ -45,6 +47,7 @@ import java.util.List;
import java.util.Optional;
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
@ -53,8 +56,10 @@ import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
import org.awaitility.Awaitility;
import org.junit.After;
import org.junit.Before;
import org.junit.Ignore;
@ -62,7 +67,9 @@ import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TestName;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.runners.MockitoJUnitRunner;
import org.reactivestreams.Subscriber;
@ -148,6 +155,7 @@ public class ShardConsumerTest {
@Before
public void before() {
MockitoAnnotations.initMocks(this);
shardInfo = new ShardInfo(shardId, concurrencyToken, null, ExtendedSequenceNumber.TRIM_HORIZON);
ThreadFactory factory = new ThreadFactoryBuilder().setNameFormat("test-" + testName.getMethodName() + "-%04d")
.setDaemon(true).build();
@ -848,6 +856,106 @@ public class ShardConsumerTest {
verifyNoMoreInteractions(taskExecutionListener);
}
@Test
public void testEmptyShardProcessingRaceCondition() throws Exception {
RecordsPublisher mockPublisher = mock(RecordsPublisher.class);
ExecutorService mockExecutor = mock(ExecutorService.class);
ConsumerState mockState = mock(ConsumerState.class);
ShardConsumer consumer = new ShardConsumer(mockPublisher, mockExecutor, shardInfo, Optional.of(1L),
shardConsumerArgument, mockState, Function.identity(), 1, taskExecutionListener, 0);
when(mockState.state()).thenReturn(ShardConsumerState.WAITING_ON_PARENT_SHARDS);
when(mockState.taskType()).thenReturn(TaskType.BLOCK_ON_PARENT_SHARDS);
ConsumerTask mockTask = mock(ConsumerTask.class);
when(mockState.createTask(any(), any(), any())).thenReturn(mockTask);
when(mockTask.call()).thenReturn(new TaskResult(false));
// Invoke async processing of blocked on parent task
consumer.executeLifecycle();
ArgumentCaptor<Runnable> taskToExecute = ArgumentCaptor.forClass(Runnable.class);
verify(mockExecutor, timeout(100)).execute(taskToExecute.capture());
taskToExecute.getValue().run();
reset(mockExecutor);
// move to initializing state and
// Invoke async processing of initialize state
when(mockState.successTransition()).thenReturn(mockState);
when(mockState.state()).thenReturn(ShardConsumerState.INITIALIZING);
when(mockState.taskType()).thenReturn(TaskType.INITIALIZE);
consumer.executeLifecycle();
verify(mockExecutor, timeout(100)).execute(taskToExecute.capture());
taskToExecute.getValue().run();
// Move to processing state
// and complete initialization future successfully
when(mockState.state()).thenReturn(ShardConsumerState.PROCESSING);
consumer.executeLifecycle();
// Simulate the race where
// scheduler invokes executeLifecycle which performs Publisher.subscribe(subscriber)
// on recordProcessor thread
// but before scheduler thread finishes initialization, handleInput is invoked
// on record processor thread.
// Since ShardConsumer creates its own instance of subscriber that cannot be mocked
// this test sequence will appear a little odd.
// In order to control the order in which execution occurs, lets first invoke
// handleInput, although this will never happen, since there isn't a way
// to control the precise timing of the thread execution, this is the best way
CountDownLatch processTaskLatch = new CountDownLatch(1);
new Thread(() -> {
reset(mockState);
when(mockState.taskType()).thenReturn(TaskType.PROCESS);
ConsumerTask mockProcessTask = mock(ConsumerTask.class);
when(mockState.createTask(any(), any(), any())).thenReturn(mockProcessTask);
CountDownLatch waitForSubscribeLatch = new CountDownLatch(1);
when(mockProcessTask.call()).then(input -> {
// first we want to wait for subscribe to be called,
// but we cannot control the timing, so wait for 10 seconds
// to let the main thread invoke executeLifecyle which
// will perform subscribe
processTaskLatch.countDown();
log.info("Waiting for countdown latch");
waitForSubscribeLatch.await(10, TimeUnit.SECONDS);
log.info("Waiting for countdown latch - DONE");
// then return shard end result
return new TaskResult(true);
});
Subscription mockSubscription = mock(Subscription.class);
consumer.handleInput(ProcessRecordsInput.builder().isAtShardEnd(true).build(), mockSubscription);
}).start();
processTaskLatch.await();
// now invoke lifecycle which should invoke subscribe
// but since we cannot countdown the latch, the latch will timeout
// meanwhile if scheduler tries to acquire the ShardConsumer lock it will
// be blocked during initialization processing. Thereby creating the
// race condition we want.
reset(mockState);
AtomicBoolean successTransitionCalled = new AtomicBoolean(false);
when(mockState.successTransition()).then(input -> {
successTransitionCalled.set(true);
return mockState;
});
AtomicBoolean shutdownTransitionCalled = new AtomicBoolean(false);
when(mockState.shutdownTransition(any())).then(input -> {
shutdownTransitionCalled.set(true);
return mockState;
});
when(mockState.state()).then(input -> {
if (successTransitionCalled.get() && shutdownTransitionCalled.get()) {
return ShardConsumerState.SHUTTING_DOWN;
}
return ShardConsumerState.PROCESSING;
});
consumer.executeLifecycle();
// initialization should be done by now, make sure shard consumer did not
// perform shutdown processing yet.
verify(mockState, times(0)).shutdownTransition(any());
}
private void mockSuccessfulShutdown(CyclicBarrier taskCallBarrier) {
mockSuccessfulShutdown(taskCallBarrier, null);
}