Fixing bug where initial subscription failure causes shard consumer to get stuck. (#562)
* Fixing bug where an initial subscription failure causes the shard consumer to get stuck. * Adding some comments for the changes and simplifying the unit test. * Adding unit tests for handling restart in case of a RejectedExecutionException from the executor service.
This commit is contained in:
parent
9e2d6fa497
commit
a150402e9c
2 changed files with 163 additions and 11 deletions
|
|
@ -44,6 +44,8 @@ class ShardConsumerSubscriber implements Subscriber<RecordsRetrieved> {
|
|||
|
||||
@VisibleForTesting
|
||||
final Object lockObject = new Object();
|
||||
// This holds the last time a request to the upstream service was attempted, including the first try to
|
||||
// establish the subscription.
|
||||
private Instant lastRequestTime = null;
|
||||
private RecordsRetrieved lastAccepted = null;
|
||||
|
||||
|
|
@ -73,6 +75,9 @@ class ShardConsumerSubscriber implements Subscriber<RecordsRetrieved> {
|
|||
|
||||
void startSubscriptions() {
|
||||
synchronized (lockObject) {
|
||||
// Setting the lastRequestTime to allow for health checks to restart subscriptions if they failed to
|
||||
// start during the initial try.
|
||||
lastRequestTime = Instant.now();
|
||||
if (lastAccepted != null) {
|
||||
recordsPublisher.restartFrom(lastAccepted);
|
||||
}
|
||||
|
|
@ -127,12 +132,8 @@ class ShardConsumerSubscriber implements Subscriber<RecordsRetrieved> {
|
|||
"{}: Last request was dispatched at {}, but no response as of {} ({}). Cancelling subscription, and restarting.",
|
||||
shardConsumer.shardInfo().shardId(), lastRequestTime, now, timeSinceLastResponse);
|
||||
cancel();
|
||||
//
|
||||
// Set the last request time to now, we specifically don't null it out since we want it to
|
||||
// trigger a
|
||||
// restart if the subscription still doesn't start producing.
|
||||
//
|
||||
lastRequestTime = Instant.now();
|
||||
|
||||
// Start the subscription again which will update the lastRequestTime as well.
|
||||
startSubscriptions();
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -22,6 +22,8 @@ import static org.mockito.Matchers.any;
|
|||
import static org.mockito.Matchers.argThat;
|
||||
import static org.mockito.Mockito.doAnswer;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.never;
|
||||
import static org.mockito.Mockito.spy;
|
||||
import static org.mockito.Mockito.times;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
|
@ -36,6 +38,7 @@ import java.util.List;
|
|||
import java.util.concurrent.CyclicBarrier;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.RejectedExecutionException;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
|
|
@ -47,7 +50,6 @@ import org.junit.Test;
|
|||
import org.junit.rules.TestName;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.Mockito;
|
||||
import org.mockito.invocation.InvocationOnMock;
|
||||
import org.mockito.runners.MockitoJUnitRunner;
|
||||
import org.mockito.stubbing.Answer;
|
||||
|
|
@ -64,7 +66,6 @@ import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput;
|
|||
import software.amazon.kinesis.retrieval.KinesisClientRecord;
|
||||
import software.amazon.kinesis.retrieval.RecordsPublisher;
|
||||
import software.amazon.kinesis.retrieval.RecordsRetrieved;
|
||||
import software.amazon.kinesis.retrieval.RetryableRetrievalException;
|
||||
import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber;
|
||||
|
||||
@Slf4j
|
||||
|
|
@ -311,6 +312,133 @@ public class ShardConsumerSubscriberTest {
|
|||
|
||||
}
|
||||
|
||||
/**
 * Verifies that the subscriber recovers when the first subscription attempt never delivers any records.
 *
 * The publisher used here is a RecordPublisherWithInitialFailureSubscription, so nothing is expected to reach
 * shardConsumer after the initial startSubscriptions() call. A subsequent healthCheck(1) is expected to
 * restart the subscription, after which shardConsumer.handleInput must have been invoked 100 times and the
 * inputs it received must match the publisher's queued responses in order.
 */
@Test
|
||||
public void restartAfterRequestTimerExpiresWhenNotGettingRecordsAfterInitialization() throws Exception {
|
||||
|
||||
executorService = Executors.newFixedThreadPool(1, new ThreadFactoryBuilder()
|
||||
.setNameFormat("test-" + testName.getMethodName() + "-%04d").setDaemon(true).build());
|
||||
|
||||
// Record publisher that publishes nothing on the first subscription attempt, simulating any scenario in which
|
||||
// the first subscription attempt fails.
|
||||
recordsPublisher = new RecordPublisherWithInitialFailureSubscription();
|
||||
subscriber = new ShardConsumerSubscriber(recordsPublisher, executorService, bufferSize, shardConsumer, 0);
|
||||
addUniqueItem(1);
|
||||
|
||||
List<ProcessRecordsInput> received = new ArrayList<>();
|
||||
// Record every input handed to shardConsumer and wake the test thread once the terminal marker shows up.
doAnswer(a -> {
|
||||
ProcessRecordsInput input = a.getArgumentAt(0, ProcessRecordsInput.class);
|
||||
received.add(input);
|
||||
if (input.records().stream().anyMatch(r -> StringUtils.startsWith(r.partitionKey(), TERMINAL_MARKER))) {
|
||||
synchronized (processedNotifier) {
|
||||
processedNotifier.notifyAll();
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}).when(shardConsumer).handleInput(any(ProcessRecordsInput.class), any(Subscription.class));
|
||||
|
||||
// First try to start subscriptions.
|
||||
synchronized (processedNotifier) {
|
||||
subscriber.startSubscriptions();
|
||||
}
|
||||
|
||||
// Verify that there are no interactions with the shardConsumer mock, indicating no records were sent back and
|
||||
// the subscription has not started correctly.
|
||||
verify(shardConsumer, never()).handleInput(argThat(eqProcessRecordsInput(processRecordsInput)),
|
||||
any(Subscription.class));
|
||||
|
||||
// Queue items 2..99 (98 more unique items) followed by the terminal marker.
Stream.iterate(2, i -> i + 1).limit(98).forEach(this::addUniqueItem);
|
||||
|
||||
addTerminalMarker(2);
|
||||
|
||||
// Doing the health check to allow the subscription to restart.
|
||||
assertThat(subscriber.healthCheck(1), nullValue());
|
||||
|
||||
// Allow time for the executor thread to finish processing the records; it calls notifyAll once it sees the
|
||||
// terminal record. The timeout is kept fairly high to avoid test failures on slow machines.
// NOTE(review): if notifyAll fires before this wait begins, the wait simply runs out its 1000 ms timeout.
|
||||
synchronized (processedNotifier) {
|
||||
processedNotifier.wait(1000);
|
||||
}
|
||||
|
||||
// Verify that shardConsumer mock was called 100 times and all 100 input records are processed.
|
||||
verify(shardConsumer, times(100)).handleInput(argThat(eqProcessRecordsInput(processRecordsInput)),
|
||||
any(Subscription.class));
|
||||
|
||||
// Verify that the records received by the subscriber equal the ones sent by the record publisher, in order.
|
||||
assertThat(received.size(), equalTo(recordsPublisher.responses.size()));
|
||||
Stream.iterate(0, i -> i + 1).limit(received.size()).forEach(i -> assertThat(received.get(i),
|
||||
eqProcessRecordsInput(recordsPublisher.responses.get(i).recordsRetrieved.processRecordsInput())));
|
||||
|
||||
}
|
||||
|
||||
/**
 * Verifies that the subscriber recovers when the executor service rejects a task during the first
 * subscription attempt.
 *
 * The executor handed to the subscriber is a Mockito spy whose execute() runs the first task inline, throws
 * RejectedExecutionException for the second, and delegates to the real pool afterwards. Nothing is expected to
 * reach shardConsumer after the initial startSubscriptions() call; a subsequent healthCheck(1) is expected to
 * restart the subscription, after which shardConsumer.handleInput must have been invoked 100 times and the
 * inputs it received must match the publisher's queued responses in order.
 */
@Test
|
||||
public void restartAfterRequestTimerExpiresWhenInitialTaskExecutionIsRejected() throws Exception {
|
||||
|
||||
executorService = Executors.newFixedThreadPool(1, new ThreadFactoryBuilder()
|
||||
.setNameFormat("test-" + testName.getMethodName() + "-%04d").setDaemon(true).build());
|
||||
|
||||
ExecutorService failingService = spy(executorService);
|
||||
|
||||
// Consecutive stubbing: 1st execute() runs the task inline on the calling thread, 2nd throws
// RejectedExecutionException, every later call goes to the real thread pool.
// NOTE(review): which subscriber task ends up being the rejected one is not visible in this hunk.
doAnswer(invocation -> directlyExecuteRunnable(invocation))
|
||||
.doThrow(new RejectedExecutionException())
|
||||
.doCallRealMethod()
|
||||
.when(failingService).execute(any());
|
||||
|
||||
subscriber = new ShardConsumerSubscriber(recordsPublisher, failingService, bufferSize, shardConsumer, 0);
|
||||
addUniqueItem(1);
|
||||
|
||||
List<ProcessRecordsInput> received = new ArrayList<>();
|
||||
// Record every input handed to shardConsumer and wake the test thread once the terminal marker shows up.
doAnswer(a -> {
|
||||
ProcessRecordsInput input = a.getArgumentAt(0, ProcessRecordsInput.class);
|
||||
received.add(input);
|
||||
if (input.records().stream().anyMatch(r -> StringUtils.startsWith(r.partitionKey(), TERMINAL_MARKER))) {
|
||||
synchronized (processedNotifier) {
|
||||
processedNotifier.notifyAll();
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}).when(shardConsumer).handleInput(any(ProcessRecordsInput.class), any(Subscription.class));
|
||||
|
||||
// First try to start subscriptions.
|
||||
synchronized (processedNotifier) {
|
||||
subscriber.startSubscriptions();
|
||||
}
|
||||
|
||||
// Verify that there are no interactions with the shardConsumer mock, indicating no records were sent back and
|
||||
// the subscription has not started correctly.
|
||||
verify(shardConsumer, never()).handleInput(argThat(eqProcessRecordsInput(processRecordsInput)),
|
||||
any(Subscription.class));
|
||||
|
||||
// Queue items 2..99 (98 more unique items) followed by the terminal marker.
Stream.iterate(2, i -> i + 1).limit(98).forEach(this::addUniqueItem);
|
||||
|
||||
addTerminalMarker(2);
|
||||
|
||||
// Doing the health check to allow the subscription to restart.
|
||||
assertThat(subscriber.healthCheck(1), nullValue());
|
||||
|
||||
// Allow time for the executor thread to finish processing the records; it calls notifyAll once it sees the
|
||||
// terminal record. The timeout is kept fairly high to avoid test failures on slow machines.
// NOTE(review): if notifyAll fires before this wait begins, the wait simply runs out its 1000 ms timeout.
|
||||
synchronized (processedNotifier) {
|
||||
processedNotifier.wait(1000);
|
||||
}
|
||||
|
||||
// Verify that shardConsumer mock was called 100 times and all 100 input records are processed.
|
||||
verify(shardConsumer, times(100)).handleInput(argThat(eqProcessRecordsInput(processRecordsInput)),
|
||||
any(Subscription.class));
|
||||
|
||||
// Verify that the records received by the subscriber equal the ones sent by the record publisher, in order.
|
||||
assertThat(received.size(), equalTo(recordsPublisher.responses.size()));
|
||||
Stream.iterate(0, i -> i + 1).limit(received.size()).forEach(i -> assertThat(received.get(i),
|
||||
eqProcessRecordsInput(recordsPublisher.responses.get(i).recordsRetrieved.processRecordsInput())));
|
||||
|
||||
}
|
||||
|
||||
/**
 * Mockito answer helper: treats the first argument of the intercepted call as a Runnable and runs it
 * synchronously on the calling thread instead of handing it to an executor.
 *
 * @param invocation the intercepted invocation; its first argument is cast to Runnable
 * @return always null
 */
private Object directlyExecuteRunnable(InvocationOnMock invocation) {
|
||||
Object[] args = invocation.getArguments();
|
||||
Runnable runnable = (Runnable) args[0];
|
||||
// Run inline, on the caller's thread.
runnable.run();
|
||||
return null;
|
||||
}
|
||||
|
||||
private void addUniqueItem(int id) {
|
||||
RecordsRetrieved r = mock(RecordsRetrieved.class, "Record-" + id);
|
||||
ProcessRecordsInput input = ProcessRecordsInput.builder().cacheEntryTime(Instant.now())
|
||||
|
|
@ -373,9 +501,9 @@ public class ShardConsumerSubscriberTest {
|
|||
private class TestPublisher implements RecordsPublisher {
|
||||
|
||||
private final LinkedList<ResponseItem> responses = new LinkedList<>();
|
||||
private volatile long requested = 0;
|
||||
protected volatile long requested = 0;
|
||||
private int currentIndex = 0;
|
||||
private Subscriber<? super RecordsRetrieved> subscriber;
|
||||
protected Subscriber<? super RecordsRetrieved> subscriber;
|
||||
private RecordsRetrieved restartedFrom;
|
||||
|
||||
void add(ResponseItem... toAdd) {
|
||||
|
|
@ -448,6 +576,29 @@ public class ShardConsumerSubscriberTest {
|
|||
}
|
||||
}
|
||||
|
||||
/**
 * TestPublisher variant whose first subscription is a dud: demand requested while only one subscribe() call
 * has been made is silently ignored, so no records are delivered until the publisher is subscribed to again.
 */
private class RecordPublisherWithInitialFailureSubscription extends TestPublisher {
|
||||
// Number of subscribe() calls seen so far.
private int subscriptionTryCount = 0;
|
||||
|
||||
/**
 * Stores the subscriber, counts the attempt, and hands the subscriber a Subscription that only forwards
 * demand once more than one subscription attempt has been made.
 */
@Override
|
||||
public void subscribe(Subscriber<? super RecordsRetrieved> s) {
|
||||
subscriber = s;
|
||||
++subscriptionTryCount;
|
||||
s.onSubscribe(new Subscription() {
|
||||
@Override
|
||||
public void request(long n) {
|
||||
// The counter is read at request time, not captured per subscription: demand is dropped only
// while exactly one subscribe() call has happened.
if (subscriptionTryCount != 1) {
|
||||
// NOTE(review): send(n) is presumably TestPublisher's delivery routine - not visible in this hunk.
send(n);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void cancel() {
|
||||
// Reset the outstanding demand tracked by the parent publisher.
requested = 0;
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
class TestShardConsumerSubscriber extends ShardConsumerSubscriber {
|
||||
|
||||
private int genericWarningLogged = 0;
|
||||
|
|
@ -685,4 +836,4 @@ public class ShardConsumerSubscriberTest {
|
|||
consumer.startSubscriptions();
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue