From c4204002af350465c6633e96d5c7bf1950f89bf6 Mon Sep 17 00:00:00 2001
From: stair <123031771+stair-aws@users.noreply.github.com>
Date: Mon, 13 Mar 2023 19:57:01 -0400
Subject: [PATCH] Fixed retry storm in `PrefetchRecordsPublisher`. (#1062)

+ DRY in `PrefetchRecordsPublisherTest`
---
 .../polling/PrefetchRecordsPublisher.java     |   3 +
 .../polling/PrefetchRecordsPublisherTest.java | 157 +++++++++---------
 2 files changed, 77 insertions(+), 83 deletions(-)

diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/polling/PrefetchRecordsPublisher.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/polling/PrefetchRecordsPublisher.java
index ab406244..07f4aaac 100644
--- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/polling/PrefetchRecordsPublisher.java
+++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/polling/PrefetchRecordsPublisher.java
@@ -563,6 +563,9 @@ public class PrefetchRecordsPublisher implements RecordsPublisher {
             if (timeSinceLastCall < idleMillisBetweenCalls) {
                 Thread.sleep(idleMillisBetweenCalls - timeSinceLastCall);
             }
+
+            // avoid immediate-retry storms
+            lastSuccessfulCall = null;
         }
     }
 
diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/polling/PrefetchRecordsPublisherTest.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/polling/PrefetchRecordsPublisherTest.java
index 55d76432..74707eb4 100644
--- a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/polling/PrefetchRecordsPublisherTest.java
+++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/polling/PrefetchRecordsPublisherTest.java
@@ -21,7 +21,6 @@ import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNotEquals;
 import static org.junit.Assert.assertNotNull;
-import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 import static org.mockito.Matchers.any;
@@ -32,13 +31,12 @@ import static org.mockito.Mockito.atLeast;
 import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.doNothing;
 import static org.mockito.Mockito.doThrow;
-import static org.mockito.Mockito.reset;
+import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 import static software.amazon.kinesis.utils.BlockingUtils.blockUntilConditionSatisfied;
-import static software.amazon.kinesis.utils.BlockingUtils.blockUntilRecordsAvailable;
 import static software.amazon.kinesis.utils.ProcessRecordsInputMatcher.eqProcessRecordsInput;
 
 import java.time.Duration;
@@ -83,7 +81,6 @@ import software.amazon.awssdk.services.kinesis.model.ExpiredIteratorException;
 import software.amazon.awssdk.services.kinesis.model.GetRecordsResponse;
 import software.amazon.awssdk.services.kinesis.model.Record;
 import software.amazon.kinesis.common.InitialPositionInStreamExtended;
-import software.amazon.kinesis.common.RequestDetails;
 import software.amazon.kinesis.leases.ShardObjectHelper;
 import software.amazon.kinesis.common.StreamIdentifier;
 import software.amazon.kinesis.lifecycle.ShardConsumerNotifyingSubscriber;
@@ -95,6 +92,7 @@ import software.amazon.kinesis.retrieval.RecordsPublisher;
 import software.amazon.kinesis.retrieval.RecordsRetrieved;
 import software.amazon.kinesis.retrieval.RetryableRetrievalException;
 import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber;
+import software.amazon.kinesis.utils.BlockingUtils;
 /**
  * Test class for the PrefetchRecordsPublisher class.
  */
@@ -107,10 +105,10 @@ public class PrefetchRecordsPublisherTest {
     private static final int MAX_RECORDS_PER_CALL = 10000;
     private static final int MAX_SIZE = 5;
     private static final int MAX_RECORDS_COUNT = 15000;
-    private static final long IDLE_MILLIS_BETWEEN_CALLS = 0L;
-    private static final long AWAIT_TERMINATION_TIMEOUT = 1L;
     private static final String NEXT_SHARD_ITERATOR = "testNextShardIterator";
 
+    private static final long DEFAULT_TIMEOUT_MILLIS = Duration.ofSeconds(1).toMillis();
+
     @Mock
     private GetRecordsRetrievalStrategy getRecordsRetrievalStrategy;
     @Mock
@@ -124,28 +122,15 @@
     private ExecutorService executorService;
     private LinkedBlockingQueue spyQueue;
     private PrefetchRecordsPublisher getRecordsCache;
-    private String operation = "ProcessTask";
     private GetRecordsResponse getRecordsResponse;
     private Record record;
-    private RequestDetails requestDetails;
 
     @Before
     public void setup() {
         when(getRecordsRetrievalStrategy.dataFetcher()).thenReturn(dataFetcher);
         when(dataFetcher.getStreamIdentifier()).thenReturn(StreamIdentifier.singleStreamInstance("testStream"));
         executorService = spy(Executors.newFixedThreadPool(1));
-        getRecordsCache = new PrefetchRecordsPublisher(
-                MAX_SIZE,
-                3 * SIZE_1_MB,
-                MAX_RECORDS_COUNT,
-                MAX_RECORDS_PER_CALL,
-                getRecordsRetrievalStrategy,
-                executorService,
-                IDLE_MILLIS_BETWEEN_CALLS,
-                new NullMetricsFactory(),
-                operation,
-                "shardId",
-                AWAIT_TERMINATION_TIMEOUT);
+        getRecordsCache = createPrefetchRecordsPublisher(0L);
         spyQueue = spy(getRecordsCache.getPublisherSession().prefetchRecordsQueue());
         records = spy(new ArrayList<>());
         getRecordsResponse = GetRecordsResponse.builder().records(records).nextShardIterator(NEXT_SHARD_ITERATOR).childShards(new ArrayList<>()).build();
@@ -158,7 +143,7 @@
         getRecordsCache.start(sequenceNumber, initialPosition);
         getRecordsCache.start(sequenceNumber, initialPosition);
         getRecordsCache.start(sequenceNumber, initialPosition);
-        verify(dataFetcher, times(1)).initialize(any(ExtendedSequenceNumber.class), any());
+        verify(dataFetcher).initialize(any(ExtendedSequenceNumber.class), any());
     }
 
     @Test
@@ -189,7 +174,7 @@
     }
 
     private void verifyInternalState(int queueSize) {
-        Assert.assertTrue(getRecordsCache.getPublisherSession().prefetchRecordsQueue().size() == queueSize);
+        assertEquals(queueSize, getRecordsCache.getPublisherSession().prefetchRecordsQueue().size());
     }
 
     @Test
@@ -202,9 +187,7 @@
                 .map(KinesisClientRecord::fromRecord).collect(Collectors.toList());
 
         getRecordsCache.start(sequenceNumber, initialPosition);
-        ProcessRecordsInput result = blockUntilRecordsAvailable(() -> evictPublishedEvent(getRecordsCache,
-                "shardId"), 1000L)
-                .processRecordsInput();
+        ProcessRecordsInput result = blockUntilRecordsAvailable().processRecordsInput();
 
         assertEquals(expectedRecords, result.records());
         assertEquals(new ArrayList<>(), result.childShards());
@@ -215,19 +198,7 @@
 
     @Test(expected = RuntimeException.class)
     public void testGetRecordsWithInitialFailures_LessThanRequiredWait_Throws() {
-        // Create a new PrefetchRecordsPublisher with 1s idle time between get calls
-        getRecordsCache = new PrefetchRecordsPublisher(
-                MAX_SIZE,
-                3 * SIZE_1_MB,
-                MAX_RECORDS_COUNT,
-                MAX_RECORDS_PER_CALL,
-                getRecordsRetrievalStrategy,
-                executorService,
-                1000,
-                new NullMetricsFactory(),
-                operation,
-                "shardId",
-                AWAIT_TERMINATION_TIMEOUT);
+        getRecordsCache = createPrefetchRecordsPublisher(Duration.ofSeconds(1).toMillis());
 
         // Setup the retrieval strategy to fail initial calls before succeeding
         when(getRecordsRetrievalStrategy.getRecords(eq(MAX_RECORDS_PER_CALL))).thenThrow(new RetryableRetrievalException("Timed out")).thenThrow(new
@@ -236,33 +207,15 @@
         when(records.size()).thenReturn(1000);
 
-        final List expectedRecords = records.stream()
-                .map(KinesisClientRecord::fromRecord).collect(Collectors.toList());
 
         getRecordsCache.start(sequenceNumber, initialPosition);
 
-        ProcessRecordsInput result = null;
         // Setup timeout to be less than what the PrefetchRecordsPublisher will need based on the idle time between
         // get calls to validate exception is thrown
-        result = blockUntilRecordsAvailable(() -> evictPublishedEvent(getRecordsCache,
-                "shardId"), 1000L)
-                .processRecordsInput();
+        blockUntilRecordsAvailable();
     }
 
     @Test
     public void testGetRecordsWithInitialFailures_AdequateWait_Success() {
-        // Create a new PrefetchRecordsPublisher with 1s idle time between get calls
-        getRecordsCache = new PrefetchRecordsPublisher(
-                MAX_SIZE,
-                3 * SIZE_1_MB,
-                MAX_RECORDS_COUNT,
-                MAX_RECORDS_PER_CALL,
-                getRecordsRetrievalStrategy,
-                executorService,
-                1000,
-                new NullMetricsFactory(),
-                operation,
-                "shardId",
-                AWAIT_TERMINATION_TIMEOUT);
+        getRecordsCache = createPrefetchRecordsPublisher(Duration.ofSeconds(1).toMillis());
 
         // Setup the retrieval strategy to fail initial calls before succeeding
         when(getRecordsRetrievalStrategy.getRecords(eq(MAX_RECORDS_PER_CALL))).thenThrow(new RetryableRetrievalException("Timed out")).thenThrow(new
@@ -278,8 +231,7 @@
         ProcessRecordsInput result = null;
         // Setup timeout to be more than what the PrefetchRecordsPublisher will need based on the idle time between
         // get calls and then validate the mocks later
-        result = blockUntilRecordsAvailable(() -> evictPublishedEvent(getRecordsCache,
-                "shardId"), 4000L)
+        result = BlockingUtils.blockUntilRecordsAvailable(this::evictPublishedEvent, 4000L)
                 .processRecordsInput();
 
         assertEquals(expectedRecords, result.records());
@@ -303,8 +255,7 @@
         getRecordsCache.start(sequenceNumber, initialPosition);
 
         try {
-            ProcessRecordsInput result = blockUntilRecordsAvailable(() -> evictPublishedEvent(getRecordsCache, "shardId"), 1000L)
-                    .processRecordsInput();
+            blockUntilRecordsAvailable();
         } catch (Exception e) {
             assertEquals("No records found", e.getMessage());
         }
@@ -337,8 +288,7 @@
         when(dataFetcher.isShardEndReached()).thenReturn(true);
 
         getRecordsCache.start(sequenceNumber, initialPosition);
-        ProcessRecordsInput result = blockUntilRecordsAvailable(() -> evictPublishedEvent(getRecordsCache, "shardId"), 1000L)
-                .processRecordsInput();
+        ProcessRecordsInput result = blockUntilRecordsAvailable().processRecordsInput();
 
         assertEquals(expectedRecords, result.records());
         assertEquals(childShards, result.childShards());
@@ -406,7 +356,7 @@
                 .map(KinesisClientRecord::fromRecord).collect(Collectors.toList());
 
         getRecordsCache.start(sequenceNumber, initialPosition);
-        ProcessRecordsInput processRecordsInput = evictPublishedEvent(getRecordsCache, "shardId").processRecordsInput();
+        ProcessRecordsInput processRecordsInput = evictPublishedEvent().processRecordsInput();
 
         verify(executorService).execute(any());
         assertEquals(expectedRecords, processRecordsInput.records());
@@ -415,7 +365,7 @@
 
         sleep(2000);
 
-        ProcessRecordsInput processRecordsInput2 = evictPublishedEvent(getRecordsCache, "shardId").processRecordsInput();
+        ProcessRecordsInput processRecordsInput2 = evictPublishedEvent().processRecordsInput();
         assertNotEquals(processRecordsInput, processRecordsInput2);
         assertEquals(expectedRecords, processRecordsInput2.records());
         assertNotEquals(processRecordsInput2.timeSpentInCache(), Duration.ZERO);
@@ -425,7 +375,7 @@
 
     @Test(expected = IllegalStateException.class)
     public void testGetNextRecordsWithoutStarting() {
-        verify(executorService, times(0)).execute(any());
+        verify(executorService, never()).execute(any());
         getRecordsCache.drainQueueForRequests();
     }
 
@@ -437,7 +387,6 @@
 
     @Test
     public void testExpiredIteratorException() {
-        log.info("Starting tests");
         when(getRecordsRetrievalStrategy.getRecords(MAX_RECORDS_PER_CALL)).thenThrow(ExpiredIteratorException.class)
                 .thenReturn(getRecordsResponse);
 
         getRecordsCache.start(sequenceNumber, initialPosition);
 
         doNothing().when(dataFetcher).restartIterator();
 
-        blockUntilRecordsAvailable(() -> evictPublishedEvent(getRecordsCache, "shardId"), 1000L);
+        blockUntilRecordsAvailable();
 
         sleep(1000);
 
@@ -456,7 +405,6 @@
     public void testExpiredIteratorExceptionWithIllegalStateException() {
         // This test validates that the daemon thread doesn't die when ExpiredIteratorException occurs with an
         // IllegalStateException.
-
         when(getRecordsRetrievalStrategy.getRecords(MAX_RECORDS_PER_CALL))
                 .thenThrow(ExpiredIteratorException.builder().build())
                 .thenReturn(getRecordsResponse)
@@ -474,14 +422,13 @@
     @Test
     public void testRetryableRetrievalExceptionContinues() {
-
         GetRecordsResponse response = GetRecordsResponse.builder().millisBehindLatest(100L).records(Collections.emptyList()).nextShardIterator(NEXT_SHARD_ITERATOR).build();
 
         when(getRecordsRetrievalStrategy.getRecords(anyInt())).thenThrow(new RetryableRetrievalException("Timeout", new TimeoutException("Timeout"))).thenReturn(response);
 
         getRecordsCache.start(sequenceNumber, initialPosition);
 
-        RecordsRetrieved records = blockUntilRecordsAvailable(() -> evictPublishedEvent(getRecordsCache, "shardId"), 1000);
-        assertThat(records.processRecordsInput().millisBehindLatest(), equalTo(response.millisBehindLatest()));
+        RecordsRetrieved records = blockUntilRecordsAvailable();
+        assertEquals(records.processRecordsInput().millisBehindLatest(), response.millisBehindLatest());
     }
 
     @Test(timeout = 10000L)
@@ -493,7 +440,6 @@
         // If the test times out before starting the subscriber it means something went wrong while filling the queue.
         // After the subscriber is started one of the things that can trigger a timeout is a deadlock.
         //
-
         final int[] sequenceNumberInResponse = { 0 };
         when(getRecordsRetrievalStrategy.getRecords(anyInt())).thenAnswer(
                 i -> GetRecordsResponse.builder().records(
@@ -681,14 +627,14 @@
 
         getRecordsCache.start(sequenceNumber, initialPosition);
 
-        RecordsRetrieved lastProcessed = blockUntilRecordsAvailable(() -> evictPublishedEvent(getRecordsCache, "shardId"), 1000);
-        RecordsRetrieved expected = blockUntilRecordsAvailable(() -> evictPublishedEvent(getRecordsCache, "shardId"), 1000);
+        RecordsRetrieved lastProcessed = blockUntilRecordsAvailable();
+        RecordsRetrieved expected = blockUntilRecordsAvailable();
 
         //
         // Skip some of the records the cache
         //
-        blockUntilRecordsAvailable(() -> evictPublishedEvent(getRecordsCache, "shardId"), 1000);
-        blockUntilRecordsAvailable(() -> evictPublishedEvent(getRecordsCache, "shardId"), 1000);
+        blockUntilRecordsAvailable();
+        blockUntilRecordsAvailable();
 
         verify(getRecordsRetrievalStrategy, atLeast(2)).getRecords(anyInt());
 
@@ -697,16 +643,46 @@
         }
         getRecordsCache.restartFrom(lastProcessed);
-        RecordsRetrieved postRestart = blockUntilRecordsAvailable(() -> evictPublishedEvent(getRecordsCache, "shardId"), 1000);
+        RecordsRetrieved postRestart = blockUntilRecordsAvailable();
 
         assertThat(postRestart.processRecordsInput(), eqProcessRecordsInput(expected.processRecordsInput()));
 
         verify(dataFetcher).resetIterator(eq(responses.get(0).nextShardIterator()),
                 eq(responses.get(0).records().get(0).sequenceNumber()), any());
-
     }
 
-    private RecordsRetrieved evictPublishedEvent(PrefetchRecordsPublisher publisher, String shardId) {
-        return publisher.getPublisherSession().evictPublishedRecordAndUpdateDemand(shardId);
+    /**
+     * Tests that a thrown {@link SdkException} doesn't cause a retry storm.
+     */
+    @Test(expected = RuntimeException.class)
+    public void testRepeatSdkExceptionLoop() {
+        final int expectedFailedCalls = 4;
+        getRecordsCache = createPrefetchRecordsPublisher(DEFAULT_TIMEOUT_MILLIS / expectedFailedCalls);
+        getRecordsCache.start(sequenceNumber, initialPosition);
+
+        try {
+            // return a valid response to cause `lastSuccessfulCall` to initialize
+            when(getRecordsRetrievalStrategy.getRecords(anyInt())).thenReturn(GetRecordsResponse.builder().build());
+            blockUntilRecordsAvailable();
+        } catch (RuntimeException re) {
+            Assert.fail("first call should succeed");
+        }
+
+        try {
+            when(getRecordsRetrievalStrategy.getRecords(anyInt()))
+                    .thenThrow(SdkException.builder().message("lose yourself to dance").build());
+            blockUntilRecordsAvailable();
+        } finally {
+            // the successful call is the +1
+            verify(getRecordsRetrievalStrategy, times(expectedFailedCalls + 1)).getRecords(anyInt());
+        }
+    }
+
+    private RecordsRetrieved blockUntilRecordsAvailable() {
+        return BlockingUtils.blockUntilRecordsAvailable(this::evictPublishedEvent, DEFAULT_TIMEOUT_MILLIS);
+    }
+
+    private RecordsRetrieved evictPublishedEvent() {
+        return getRecordsCache.getPublisherSession().evictPublishedRecordAndUpdateDemand("shardId");
     }
 
     private static class RetrieverAnswer implements Answer {
@@ -736,7 +712,7 @@
         }
         @Override
-        public GetRecordsResponse answer(InvocationOnMock invocation) throws Throwable {
+        public GetRecordsResponse answer(InvocationOnMock invocation) {
            GetRecordsResponse response = iterator.next();
            if (!iterator.hasNext()) {
                iterator = responses.iterator();
            }
@@ -787,4 +763,19 @@
         return SdkBytes.fromByteArray(new byte[size]);
     }
 
+    private PrefetchRecordsPublisher createPrefetchRecordsPublisher(final long idleMillisBetweenCalls) {
+        return new PrefetchRecordsPublisher(
+                MAX_SIZE,
+                3 * SIZE_1_MB,
+                MAX_RECORDS_COUNT,
+                MAX_RECORDS_PER_CALL,
+                getRecordsRetrievalStrategy,
+                executorService,
+                idleMillisBetweenCalls,
+                new NullMetricsFactory(),
+                PrefetchRecordsPublisherTest.class.getSimpleName(),
+                "shardId",
+                1L);
+    }
+
 }