From a150402e9cd8cbed08ae16dcb0069363eefffec4 Mon Sep 17 00:00:00 2001 From: yatins47 Date: Thu, 11 Jul 2019 07:16:04 -0700 Subject: [PATCH] Fixing bug where initial subscription failure causes shard consumer to get stuck. (#562) * Fixing bug where initial subscription fails cause shard consumer to get stuck. * Adding some comments for the changes and simplifying the unit test. * Adding unit tests for handling restart in case of rejection execution exception from executor service. --- .../lifecycle/ShardConsumerSubscriber.java | 13 +- .../ShardConsumerSubscriberTest.java | 161 +++++++++++++++++- 2 files changed, 163 insertions(+), 11 deletions(-) diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ShardConsumerSubscriber.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ShardConsumerSubscriber.java index afc22f70..102982f2 100644 --- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ShardConsumerSubscriber.java +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ShardConsumerSubscriber.java @@ -44,6 +44,8 @@ class ShardConsumerSubscriber implements Subscriber { @VisibleForTesting final Object lockObject = new Object(); + // This holds the last time an attempt of request to upstream service was made including the first try to + // establish subscription. private Instant lastRequestTime = null; private RecordsRetrieved lastAccepted = null; @@ -73,6 +75,9 @@ class ShardConsumerSubscriber implements Subscriber { void startSubscriptions() { synchronized (lockObject) { + // Setting the lastRequestTime to allow for health checks to restart subscriptions if they failed to + // during initial try. + lastRequestTime = Instant.now(); if (lastAccepted != null) { recordsPublisher.restartFrom(lastAccepted); } @@ -127,12 +132,8 @@ class ShardConsumerSubscriber implements Subscriber { "{}: Last request was dispatched at {}, but no response as of {} ({}). Cancelling subscription, and restarting.", shardConsumer.shardInfo().shardId(), lastRequestTime, now, timeSinceLastResponse); cancel(); - // - // Set the last request time to now, we specifically don't null it out since we want it to - // trigger a - // restart if the subscription still doesn't start producing. - // - lastRequestTime = Instant.now(); + + // Start the subscription again which will update the lastRequestTime as well. startSubscriptions(); } } diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ShardConsumerSubscriberTest.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ShardConsumerSubscriberTest.java index bfe508e6..494dd282 100644 --- a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ShardConsumerSubscriberTest.java +++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ShardConsumerSubscriberTest.java @@ -22,6 +22,8 @@ import static org.mockito.Matchers.any; import static org.mockito.Matchers.argThat; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -36,6 +38,7 @@ import java.util.List; import java.util.concurrent.CyclicBarrier; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.TimeUnit; import java.util.stream.Stream; @@ -47,7 +50,6 @@ import org.junit.Test; import org.junit.rules.TestName; import org.junit.runner.RunWith; import org.mockito.Mock; -import org.mockito.Mockito; import org.mockito.invocation.InvocationOnMock; import org.mockito.runners.MockitoJUnitRunner; import org.mockito.stubbing.Answer; @@ -64,7 +66,6 @@ import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput; import software.amazon.kinesis.retrieval.KinesisClientRecord; import software.amazon.kinesis.retrieval.RecordsPublisher; import software.amazon.kinesis.retrieval.RecordsRetrieved; -import software.amazon.kinesis.retrieval.RetryableRetrievalException; import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber; @Slf4j @@ -311,6 +312,133 @@ public class ShardConsumerSubscriberTest { } + @Test + public void restartAfterRequestTimerExpiresWhenNotGettingRecordsAfterInitialization() throws Exception { + + executorService = Executors.newFixedThreadPool(1, new ThreadFactoryBuilder() + .setNameFormat("test-" + testName.getMethodName() + "-%04d").setDaemon(true).build()); + + // Mock record publisher which doesn't publish any records on first try which simulates any scenario which + // causes first subscription try to fail. + recordsPublisher = new RecordPublisherWithInitialFailureSubscription(); + subscriber = new ShardConsumerSubscriber(recordsPublisher, executorService, bufferSize, shardConsumer, 0); + addUniqueItem(1); + + List received = new ArrayList<>(); + doAnswer(a -> { + ProcessRecordsInput input = a.getArgumentAt(0, ProcessRecordsInput.class); + received.add(input); + if (input.records().stream().anyMatch(r -> StringUtils.startsWith(r.partitionKey(), TERMINAL_MARKER))) { + synchronized (processedNotifier) { + processedNotifier.notifyAll(); + } + } + return null; + }).when(shardConsumer).handleInput(any(ProcessRecordsInput.class), any(Subscription.class)); + + // First try to start subscriptions. + synchronized (processedNotifier) { + subscriber.startSubscriptions(); + } + + // Verifying that there are no interactions with shardConsumer mock indicating no records were sent back and + // subscription has not started correctly. + verify(shardConsumer, never()).handleInput(argThat(eqProcessRecordsInput(processRecordsInput)), + any(Subscription.class)); + + Stream.iterate(2, i -> i + 1).limit(98).forEach(this::addUniqueItem); + + addTerminalMarker(2); + + // Doing the health check to allow the subscription to restart. + assertThat(subscriber.healthCheck(1), nullValue()); + + // Allow time for processing of the records to end in the executor thread which call notifyAll as it gets the + // terminal record. Keeping the timeout pretty high for avoiding test failures on slow machines. + synchronized (processedNotifier) { + processedNotifier.wait(1000); + } + + // Verify that shardConsumer mock was called 100 times and all 100 input records are processed. + verify(shardConsumer, times(100)).handleInput(argThat(eqProcessRecordsInput(processRecordsInput)), + any(Subscription.class)); + + // Verify that received records in the subscriber are equal to the ones sent by the record publisher. + assertThat(received.size(), equalTo(recordsPublisher.responses.size())); + Stream.iterate(0, i -> i + 1).limit(received.size()).forEach(i -> assertThat(received.get(i), + eqProcessRecordsInput(recordsPublisher.responses.get(i).recordsRetrieved.processRecordsInput()))); + + } + + @Test + public void restartAfterRequestTimerExpiresWhenInitialTaskExecutionIsRejected() throws Exception { + + executorService = Executors.newFixedThreadPool(1, new ThreadFactoryBuilder() + .setNameFormat("test-" + testName.getMethodName() + "-%04d").setDaemon(true).build()); + + ExecutorService failingService = spy(executorService); + + doAnswer(invocation -> directlyExecuteRunnable(invocation)) + .doThrow(new RejectedExecutionException()) + .doCallRealMethod() + .when(failingService).execute(any()); + + subscriber = new ShardConsumerSubscriber(recordsPublisher, failingService, bufferSize, shardConsumer, 0); + addUniqueItem(1); + + List received = new ArrayList<>(); + doAnswer(a -> { + ProcessRecordsInput input = a.getArgumentAt(0, ProcessRecordsInput.class); + received.add(input); + if (input.records().stream().anyMatch(r -> StringUtils.startsWith(r.partitionKey(), TERMINAL_MARKER))) { + synchronized (processedNotifier) { + processedNotifier.notifyAll(); + } + } + return null; + }).when(shardConsumer).handleInput(any(ProcessRecordsInput.class), any(Subscription.class)); + + // First try to start subscriptions. + synchronized (processedNotifier) { + subscriber.startSubscriptions(); + } + + // Verifying that there are no interactions with shardConsumer mock indicating no records were sent back and + // subscription has not started correctly. + verify(shardConsumer, never()).handleInput(argThat(eqProcessRecordsInput(processRecordsInput)), + any(Subscription.class)); + + Stream.iterate(2, i -> i + 1).limit(98).forEach(this::addUniqueItem); + + addTerminalMarker(2); + + // Doing the health check to allow the subscription to restart. + assertThat(subscriber.healthCheck(1), nullValue()); + + // Allow time for processing of the records to end in the executor thread which call notifyAll as it gets the + // terminal record. Keeping the timeout pretty high for avoiding test failures on slow machines. + synchronized (processedNotifier) { + processedNotifier.wait(1000); + } + + // Verify that shardConsumer mock was called 100 times and all 100 input records are processed. + verify(shardConsumer, times(100)).handleInput(argThat(eqProcessRecordsInput(processRecordsInput)), + any(Subscription.class)); + + // Verify that received records in the subscriber are equal to the ones sent by the record publisher. + assertThat(received.size(), equalTo(recordsPublisher.responses.size())); + Stream.iterate(0, i -> i + 1).limit(received.size()).forEach(i -> assertThat(received.get(i), + eqProcessRecordsInput(recordsPublisher.responses.get(i).recordsRetrieved.processRecordsInput()))); + + } + + private Object directlyExecuteRunnable(InvocationOnMock invocation) { + Object[] args = invocation.getArguments(); + Runnable runnable = (Runnable) args[0]; + runnable.run(); + return null; + } + private void addUniqueItem(int id) { RecordsRetrieved r = mock(RecordsRetrieved.class, "Record-" + id); ProcessRecordsInput input = ProcessRecordsInput.builder().cacheEntryTime(Instant.now()) @@ -373,9 +501,9 @@ public class ShardConsumerSubscriberTest { private class TestPublisher implements RecordsPublisher { private final LinkedList responses = new LinkedList<>(); - private volatile long requested = 0; + protected volatile long requested = 0; private int currentIndex = 0; - private Subscriber subscriber; + protected Subscriber subscriber; private RecordsRetrieved restartedFrom; void add(ResponseItem... toAdd) { @@ -448,6 +576,29 @@ public class ShardConsumerSubscriberTest { } } + private class RecordPublisherWithInitialFailureSubscription extends TestPublisher { + private int subscriptionTryCount = 0; + + @Override + public void subscribe(Subscriber s) { + subscriber = s; + ++subscriptionTryCount; + s.onSubscribe(new Subscription() { + @Override + public void request(long n) { + if (subscriptionTryCount != 1) { + send(n); + } + } + + @Override + public void cancel() { + requested = 0; + } + }); + } + } + class TestShardConsumerSubscriber extends ShardConsumerSubscriber { private int genericWarningLogged = 0; @@ -685,4 +836,4 @@ public class ShardConsumerSubscriberTest { consumer.startSubscriptions(); } -} \ No newline at end of file +}