Correcting the behavior of gracefulShutdown (#1302)

* modify ShutdownTask to call shutdownComplete for graceful shutdown

* add test to verify ShutdownTask succeeds regardless of shutdownNotification

* change access level for finalShutdownLatch to NONE

* remove unused variable in GracefulShutdownCoordinator

* make comment more concise

* move waitForFinalShutdown method into GracefulShutdownCoordinator class

* cleanup call method of GracefulShutdownCoordinator

* modify waitForFinalShutdown to throw InterruptedException
This commit is contained in:
vincentvilo-aws 2024-04-03 12:42:26 -07:00 committed by GitHub
parent 581d713815
commit 7f1f243676
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 142 additions and 30 deletions

View file

@ -14,22 +14,22 @@
*/
package software.amazon.kinesis.coordinator;
import lombok.Builder;
import lombok.Data;
import lombok.experimental.Accessors;
import java.util.concurrent.CountDownLatch;
@Data
@Builder
@Accessors(fluent = true)
class GracefulShutdownContext {
private final CountDownLatch shutdownCompleteLatch;
private final CountDownLatch notificationCompleteLatch;
private final CountDownLatch finalShutdownLatch;
private final Scheduler scheduler;
static GracefulShutdownContext SHUTDOWN_ALREADY_COMPLETED = new GracefulShutdownContext(null, null, null);
boolean isShutdownAlreadyCompleted() {
boolean isRecordProcessorShutdownComplete() {
return shutdownCompleteLatch == null && notificationCompleteLatch == null && scheduler == null;
}
}

View file

@ -23,6 +23,11 @@ import java.util.concurrent.TimeUnit;
class GracefulShutdownCoordinator {
/**
* arbitrary wait time for worker's finalShutdown
*/
private static final long FINAL_SHUTDOWN_WAIT_TIME_SECONDS = 60L;
CompletableFuture<Boolean> startGracefulShutdown(Callable<Boolean> shutdownCallable) {
CompletableFuture<Boolean> cf = new CompletableFuture<>();
CompletableFuture.runAsync(() -> {
@ -62,7 +67,18 @@ class GracefulShutdownCoordinator {
return String.format("Waiting for %d record processors to complete final shutdown", outstanding);
}
/**
* used to wait for the worker's final shutdown to complete before returning the future for graceful shutdown
* @return true if the final shutdown is successful, false otherwise.
*/
private boolean waitForFinalShutdown(GracefulShutdownContext context) throws InterruptedException {
return context.finalShutdownLatch().await(FINAL_SHUTDOWN_WAIT_TIME_SECONDS, TimeUnit.SECONDS);
}
private boolean waitForRecordProcessors(GracefulShutdownContext context) {
if (context.isRecordProcessorShutdownComplete()) {
return true;
}
//
// Awaiting for all ShardConsumer/RecordProcessors to be notified that a shutdown has been requested.
@ -148,14 +164,13 @@ class GracefulShutdownCoordinator {
@Override
public Boolean call() throws Exception {
GracefulShutdownContext context;
try {
context = startWorkerShutdown.call();
final GracefulShutdownContext context = startWorkerShutdown.call();
return waitForRecordProcessors(context) && waitForFinalShutdown(context);
} catch (Exception ex) {
log.warn("Caught exception while requesting initial worker shutdown.", ex);
throw ex;
}
return context.isShutdownAlreadyCompleted() || waitForRecordProcessors(context);
}
}
}

View file

@ -191,6 +191,14 @@ public class Scheduler implements Runnable {
* Used to ensure that only one requestedShutdown is in progress at a time.
*/
private CompletableFuture<Boolean> gracefulShutdownFuture;
/**
* CountDownLatch used by the GracefulShutdownCoordinator. Reaching zero means that
* the scheduler's finalShutdown() call has completed.
*/
@Getter(AccessLevel.NONE)
private final CountDownLatch finalShutdownLatch = new CountDownLatch(1);
@VisibleForTesting
protected boolean gracefuleShutdownStarted = false;
@ -797,7 +805,7 @@ public class Scheduler implements Runnable {
// If there are no leases notification is already completed, but we still need to shutdown the worker.
//
this.shutdown();
return GracefulShutdownContext.SHUTDOWN_ALREADY_COMPLETED;
return GracefulShutdownContext.builder().finalShutdownLatch(finalShutdownLatch).build();
}
CountDownLatch shutdownCompleteLatch = new CountDownLatch(leases.size());
CountDownLatch notificationCompleteLatch = new CountDownLatch(leases.size());
@ -818,7 +826,12 @@ public class Scheduler implements Runnable {
shutdownCompleteLatch.countDown();
}
}
return new GracefulShutdownContext(shutdownCompleteLatch, notificationCompleteLatch, this);
return GracefulShutdownContext.builder()
.shutdownCompleteLatch(shutdownCompleteLatch)
.notificationCompleteLatch(notificationCompleteLatch)
.finalShutdownLatch(finalShutdownLatch)
.scheduler(this)
.build();
};
}
@ -878,6 +891,7 @@ public class Scheduler implements Runnable {
((CloudWatchMetricsFactory) metricsFactory).shutdown();
}
shutdownComplete = true;
finalShutdownLatch.countDown();
}
private List<ShardInfo> getShardInfoForAssignments() {

View file

@ -478,6 +478,7 @@ class ConsumerStates {
argument.shardRecordProcessor(),
argument.recordProcessorCheckpointer(),
consumer.shutdownReason(),
consumer.shutdownNotification(),
argument.initialPositionInStream(),
argument.cleanupLeasesOfCompletedShards(),
argument.ignoreUnexpectedChildShards(),
@ -557,9 +558,6 @@ class ConsumerStates {
@Override
public ConsumerTask createTask(ShardConsumerArgument argument, ShardConsumer consumer, ProcessRecordsInput input) {
if (consumer.shutdownNotification() != null) {
consumer.shutdownNotification().shutdownComplete();
}
return null;
}

View file

@ -87,6 +87,7 @@ public class ShutdownTask implements ConsumerTask {
private final ShardRecordProcessorCheckpointer recordProcessorCheckpointer;
@NonNull
private final ShutdownReason reason;
private final ShutdownNotification shutdownNotification;
@NonNull
private final InitialPositionInStreamExtended initialPositionInStream;
private final boolean cleanupLeasesOfCompletedShards;
@ -149,6 +150,12 @@ public class ShutdownTask implements ConsumerTask {
log.debug("Shutting down retrieval strategy for shard {}.", leaseKey);
recordsPublisher.shutdown();
// shutdownNotification is only set and used when gracefulShutdown starts
if (shutdownNotification != null) {
shutdownNotification.shutdownComplete();
}
log.debug("Record processor completed shutdown() for shard {}", leaseKey);
return new TaskResult(null);

View file

@ -45,6 +45,8 @@ public class GracefulShutdownCoordinatorTest {
@Mock
private CountDownLatch notificationCompleteLatch;
@Mock
private CountDownLatch finalShutdownLatch;
@Mock
private Scheduler scheduler;
@Mock
private Callable<GracefulShutdownContext> contextCallable;
@ -57,6 +59,7 @@ public class GracefulShutdownCoordinatorTest {
when(shutdownCompleteLatch.await(anyLong(), any(TimeUnit.class))).thenReturn(true);
when(notificationCompleteLatch.await(anyLong(), any(TimeUnit.class))).thenReturn(true);
when(finalShutdownLatch.await(anyLong(), any(TimeUnit.class))).thenReturn(true);
assertThat(requestedShutdownCallable.call(), equalTo(true));
verify(shutdownCompleteLatch).await(anyLong(), any(TimeUnit.class));
@ -72,6 +75,7 @@ public class GracefulShutdownCoordinatorTest {
when(notificationCompleteLatch.getCount()).thenReturn(1L, 0L);
mockLatchAwait(shutdownCompleteLatch, true);
when(shutdownCompleteLatch.getCount()).thenReturn(1L, 1L, 0L);
when(finalShutdownLatch.await(anyLong(), any(TimeUnit.class))).thenReturn(true);
when(scheduler.shutdownComplete()).thenReturn(false, true);
mockShardInfoConsumerMap(1, 0);
@ -93,6 +97,7 @@ public class GracefulShutdownCoordinatorTest {
mockLatchAwait(notificationCompleteLatch, true);
mockLatchAwait(shutdownCompleteLatch, false, true);
when(shutdownCompleteLatch.getCount()).thenReturn(1L, 0L);
when(finalShutdownLatch.await(anyLong(), any(TimeUnit.class))).thenReturn(true);
when(scheduler.shutdownComplete()).thenReturn(false, true);
mockShardInfoConsumerMap(1, 0);
@ -117,6 +122,8 @@ public class GracefulShutdownCoordinatorTest {
mockLatchAwait(shutdownCompleteLatch, true);
when(shutdownCompleteLatch.getCount()).thenReturn(2L, 2L, 1L, 1L, 0L);
when(finalShutdownLatch.await(anyLong(), any(TimeUnit.class))).thenReturn(true);
when(scheduler.shutdownComplete()).thenReturn(false, false, false, true);
mockShardInfoConsumerMap(2, 1, 0);
@ -286,6 +293,44 @@ public class GracefulShutdownCoordinatorTest {
requestedShutdownCallable.call();
}
@Test
public void testShutdownFailsDueToRecordProcessors() throws Exception {
Callable<Boolean> requestedShutdownCallable = buildRequestedShutdownCallable();
when(notificationCompleteLatch.await(anyLong(), any(TimeUnit.class))).thenReturn(true);
when(shutdownCompleteLatch.await(anyLong(), any(TimeUnit.class))).thenReturn(false);
when(shutdownCompleteLatch.getCount()).thenReturn(1L);
when(scheduler.shutdownComplete()).thenReturn(true);
mockShardInfoConsumerMap(1);
assertThat(requestedShutdownCallable.call(), equalTo(false));
verifyLatchAwait(shutdownCompleteLatch);
}
@Test
public void testShutdownFailsDueToWorker() throws Exception {
Callable<Boolean> requestedShutdownCallable = buildRequestedShutdownCallable();
when(notificationCompleteLatch.await(anyLong(), any(TimeUnit.class))).thenReturn(true);
when(shutdownCompleteLatch.await(anyLong(), any(TimeUnit.class))).thenReturn(true);
when(finalShutdownLatch.await(anyLong(), any(TimeUnit.class))).thenReturn(false);
assertThat(requestedShutdownCallable.call(), equalTo(false));
verifyLatchAwait(finalShutdownLatch);
}
/**
* tests that shutdown still succeeds in the case where there are no leases returned by the lease coordinator
*/
@Test
public void testShutdownSuccessWithNoLeases() throws Exception {
Callable<Boolean> requestedShutdownCallable = buildRequestedShutdownCallableWithNullLatches();
when(finalShutdownLatch.await(anyLong(), any(TimeUnit.class))).thenReturn(true);
assertThat(requestedShutdownCallable.call(), equalTo(true));
verifyLatchAwait(finalShutdownLatch);
}
private void verifyLatchAwait(CountDownLatch latch) throws Exception {
verifyLatchAwait(latch, times(1));
}
@ -303,8 +348,24 @@ public class GracefulShutdownCoordinatorTest {
}
private Callable<Boolean> buildRequestedShutdownCallable() throws Exception {
GracefulShutdownContext context = new GracefulShutdownContext(shutdownCompleteLatch,
notificationCompleteLatch, scheduler);
GracefulShutdownContext context = GracefulShutdownContext.builder()
.shutdownCompleteLatch(shutdownCompleteLatch)
.notificationCompleteLatch(notificationCompleteLatch)
.finalShutdownLatch(finalShutdownLatch)
.scheduler(scheduler)
.build();
when(contextCallable.call()).thenReturn(context);
return new GracefulShutdownCoordinator().createGracefulShutdownCallable(contextCallable);
}
/**
* finalShutdownLatch will always be initialized, but shutdownCompleteLatch and notificationCompleteLatch are not
* initialized in the case where there are no leases returned by the lease coordinator
*/
private Callable<Boolean> buildRequestedShutdownCallableWithNullLatches() throws Exception {
GracefulShutdownContext context = GracefulShutdownContext.builder()
.finalShutdownLatch(finalShutdownLatch)
.build();
when(contextCallable.call()).thenReturn(context);
return new GracefulShutdownCoordinator().createGracefulShutdownCallable(contextCallable);
}
@ -319,4 +380,5 @@ public class GracefulShutdownCoordinatorTest {
when(shardInfoConsumerMap.isEmpty()).thenReturn(initialItemCount == 0, additionalEmptyStates);
}
}

View file

@ -17,10 +17,7 @@ package software.amazon.kinesis.lifecycle;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.nullValue;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static software.amazon.kinesis.lifecycle.ConsumerStates.ShardConsumerState;
@ -355,28 +352,17 @@ public class ConsumerStatesTest {
ConsumerState state = ShardConsumerState.SHUTDOWN_COMPLETE.consumerState();
assertThat(state.createTask(argument, consumer, null), nullValue());
verify(consumer, times(2)).shutdownNotification();
verify(shutdownNotification).shutdownComplete();
assertThat(state.successTransition(), equalTo(state));
for (ShutdownReason reason : ShutdownReason.values()) {
assertThat(state.shutdownTransition(reason), equalTo(state));
}
assertThat(state.isTerminal(), equalTo(true));
assertThat(state.state(), equalTo(ShardConsumerState.SHUTDOWN_COMPLETE));
assertThat(state.taskType(), equalTo(TaskType.SHUTDOWN_COMPLETE));
}
@Test
public void shutdownCompleteStateNullNotificationTest() {
ConsumerState state = ShardConsumerState.SHUTDOWN_COMPLETE.consumerState();
when(consumer.shutdownNotification()).thenReturn(null);
assertThat(state.createTask(argument, consumer, null), nullValue());
verify(consumer).shutdownNotification();
verify(shutdownNotification, never()).shutdownComplete();
}
static <ValueType> ReflectionPropertyMatcher<ShutdownTask, ValueType> shutdownTask(Class<ValueType> valueTypeClass,
String propertyName, Matcher<ValueType> matcher) {

View file

@ -114,6 +114,8 @@ public class ShutdownTaskTest {
private ShardRecordProcessor shardRecordProcessor;
@Mock
private LeaseCleanupManager leaseCleanupManager;
@Mock
private ShutdownNotification shutdownNotification;
@Before
public void setUp() throws Exception {
@ -308,6 +310,26 @@ public class ShutdownTaskTest {
verify(leaseRefresher, never()).createLeaseIfNotExists(any(Lease.class));
}
/**
* shutdownNotification is only set when ShardConsumer.gracefulShutdown() is called and should be null otherwise.
* The task should still call recordsPublisher.shutdown() regardless of the notification
*/
@Test
public void testCallWhenShutdownNotificationIsSet() {
final TaskResult result = createShutdownTaskWithNotification(LEASE_LOST, Collections.emptyList()).call();
assertNull(result.getException());
verify(recordsPublisher).shutdown();
verify(shutdownNotification).shutdownComplete();
}
@Test
public void testCallWhenShutdownNotificationIsNull() {
final TaskResult result = createShutdownTask(LEASE_LOST, Collections.emptyList()).call();
assertNull(result.getException());
verify(recordsPublisher).shutdown();
verify(shutdownNotification, never()).shutdownComplete();
}
/**
* Test method for {@link ShutdownTask#taskType()}.
*/
@ -372,7 +394,15 @@ public class ShutdownTaskTest {
private ShutdownTask createShutdownTask(final ShutdownReason reason, final List<ChildShard> childShards,
final ShardInfo shardInfo) {
return new ShutdownTask(shardInfo, shardDetector, shardRecordProcessor, recordProcessorCheckpointer,
reason, INITIAL_POSITION_TRIM_HORIZON, false, false,
reason, null, INITIAL_POSITION_TRIM_HORIZON, false, false,
leaseCoordinator, TASK_BACKOFF_TIME_MILLIS, recordsPublisher, hierarchicalShardSyncer,
NULL_METRICS_FACTORY, childShards, STREAM_IDENTIFIER, leaseCleanupManager);
}
private ShutdownTask createShutdownTaskWithNotification(final ShutdownReason reason,
final List<ChildShard> childShards) {
return new ShutdownTask(SHARD_INFO, shardDetector, shardRecordProcessor, recordProcessorCheckpointer,
reason, shutdownNotification, INITIAL_POSITION_TRIM_HORIZON, false, false,
leaseCoordinator, TASK_BACKOFF_TIME_MILLIS, recordsPublisher, hierarchicalShardSyncer,
NULL_METRICS_FACTORY, childShards, STREAM_IDENTIFIER, leaseCleanupManager);
}