Fixed retry storm in PrefetchRecordsPublisher. (#1062)

+ DRY in `PrefetchRecordsPublisherTest`
This commit is contained in:
stair 2023-03-13 19:57:01 -04:00 committed by GitHub
parent 504ea10859
commit c4204002af
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 77 additions and 83 deletions

View file

@ -563,6 +563,9 @@ public class PrefetchRecordsPublisher implements RecordsPublisher {
if (timeSinceLastCall < idleMillisBetweenCalls) {
Thread.sleep(idleMillisBetweenCalls - timeSinceLastCall);
}
// avoid immediate-retry storms
lastSuccessfulCall = null;
}
}

View file

@ -21,7 +21,6 @@ import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Matchers.any;
@ -32,13 +31,12 @@ import static org.mockito.Mockito.atLeast;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static software.amazon.kinesis.utils.BlockingUtils.blockUntilConditionSatisfied;
import static software.amazon.kinesis.utils.BlockingUtils.blockUntilRecordsAvailable;
import static software.amazon.kinesis.utils.ProcessRecordsInputMatcher.eqProcessRecordsInput;
import java.time.Duration;
@ -83,7 +81,6 @@ import software.amazon.awssdk.services.kinesis.model.ExpiredIteratorException;
import software.amazon.awssdk.services.kinesis.model.GetRecordsResponse;
import software.amazon.awssdk.services.kinesis.model.Record;
import software.amazon.kinesis.common.InitialPositionInStreamExtended;
import software.amazon.kinesis.common.RequestDetails;
import software.amazon.kinesis.leases.ShardObjectHelper;
import software.amazon.kinesis.common.StreamIdentifier;
import software.amazon.kinesis.lifecycle.ShardConsumerNotifyingSubscriber;
@ -95,6 +92,7 @@ import software.amazon.kinesis.retrieval.RecordsPublisher;
import software.amazon.kinesis.retrieval.RecordsRetrieved;
import software.amazon.kinesis.retrieval.RetryableRetrievalException;
import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber;
import software.amazon.kinesis.utils.BlockingUtils;
/**
* Test class for the PrefetchRecordsPublisher class.
@ -107,10 +105,10 @@ public class PrefetchRecordsPublisherTest {
private static final int MAX_RECORDS_PER_CALL = 10000;
private static final int MAX_SIZE = 5;
private static final int MAX_RECORDS_COUNT = 15000;
private static final long IDLE_MILLIS_BETWEEN_CALLS = 0L;
private static final long AWAIT_TERMINATION_TIMEOUT = 1L;
private static final String NEXT_SHARD_ITERATOR = "testNextShardIterator";
private static final long DEFAULT_TIMEOUT_MILLIS = Duration.ofSeconds(1).toMillis();
@Mock
private GetRecordsRetrievalStrategy getRecordsRetrievalStrategy;
@Mock
@ -124,28 +122,15 @@ public class PrefetchRecordsPublisherTest {
private ExecutorService executorService;
private LinkedBlockingQueue<PrefetchRecordsPublisher.PrefetchRecordsRetrieved> spyQueue;
private PrefetchRecordsPublisher getRecordsCache;
private String operation = "ProcessTask";
private GetRecordsResponse getRecordsResponse;
private Record record;
private RequestDetails requestDetails;
@Before
public void setup() {
when(getRecordsRetrievalStrategy.dataFetcher()).thenReturn(dataFetcher);
when(dataFetcher.getStreamIdentifier()).thenReturn(StreamIdentifier.singleStreamInstance("testStream"));
executorService = spy(Executors.newFixedThreadPool(1));
getRecordsCache = new PrefetchRecordsPublisher(
MAX_SIZE,
3 * SIZE_1_MB,
MAX_RECORDS_COUNT,
MAX_RECORDS_PER_CALL,
getRecordsRetrievalStrategy,
executorService,
IDLE_MILLIS_BETWEEN_CALLS,
new NullMetricsFactory(),
operation,
"shardId",
AWAIT_TERMINATION_TIMEOUT);
getRecordsCache = createPrefetchRecordsPublisher(0L);
spyQueue = spy(getRecordsCache.getPublisherSession().prefetchRecordsQueue());
records = spy(new ArrayList<>());
getRecordsResponse = GetRecordsResponse.builder().records(records).nextShardIterator(NEXT_SHARD_ITERATOR).childShards(new ArrayList<>()).build();
@ -158,7 +143,7 @@ public class PrefetchRecordsPublisherTest {
getRecordsCache.start(sequenceNumber, initialPosition);
getRecordsCache.start(sequenceNumber, initialPosition);
getRecordsCache.start(sequenceNumber, initialPosition);
verify(dataFetcher, times(1)).initialize(any(ExtendedSequenceNumber.class), any());
verify(dataFetcher).initialize(any(ExtendedSequenceNumber.class), any());
}
@Test
@ -189,7 +174,7 @@ public class PrefetchRecordsPublisherTest {
}
private void verifyInternalState(int queueSize) {
Assert.assertTrue(getRecordsCache.getPublisherSession().prefetchRecordsQueue().size() == queueSize);
assertEquals(queueSize, getRecordsCache.getPublisherSession().prefetchRecordsQueue().size());
}
@Test
@ -202,9 +187,7 @@ public class PrefetchRecordsPublisherTest {
.map(KinesisClientRecord::fromRecord).collect(Collectors.toList());
getRecordsCache.start(sequenceNumber, initialPosition);
ProcessRecordsInput result = blockUntilRecordsAvailable(() -> evictPublishedEvent(getRecordsCache,
"shardId"), 1000L)
.processRecordsInput();
ProcessRecordsInput result = blockUntilRecordsAvailable().processRecordsInput();
assertEquals(expectedRecords, result.records());
assertEquals(new ArrayList<>(), result.childShards());
@ -215,19 +198,7 @@ public class PrefetchRecordsPublisherTest {
@Test(expected = RuntimeException.class)
public void testGetRecordsWithInitialFailures_LessThanRequiredWait_Throws() {
// Create a new PrefetchRecordsPublisher with 1s idle time between get calls
getRecordsCache = new PrefetchRecordsPublisher(
MAX_SIZE,
3 * SIZE_1_MB,
MAX_RECORDS_COUNT,
MAX_RECORDS_PER_CALL,
getRecordsRetrievalStrategy,
executorService,
1000,
new NullMetricsFactory(),
operation,
"shardId",
AWAIT_TERMINATION_TIMEOUT);
getRecordsCache = createPrefetchRecordsPublisher(Duration.ofSeconds(1).toMillis());
// Setup the retrieval strategy to fail initial calls before succeeding
when(getRecordsRetrievalStrategy.getRecords(eq(MAX_RECORDS_PER_CALL))).thenThrow(new
RetryableRetrievalException("Timed out")).thenThrow(new
@ -236,33 +207,15 @@ public class PrefetchRecordsPublisherTest {
when(records.size()).thenReturn(1000);
final List<KinesisClientRecord> expectedRecords = records.stream()
.map(KinesisClientRecord::fromRecord).collect(Collectors.toList());
getRecordsCache.start(sequenceNumber, initialPosition);
ProcessRecordsInput result = null;
// Setup timeout to be less than what the PrefetchRecordsPublisher will need based on the idle time between
// get calls to validate exception is thrown
result = blockUntilRecordsAvailable(() -> evictPublishedEvent(getRecordsCache,
"shardId"), 1000L)
.processRecordsInput();
blockUntilRecordsAvailable();
}
@Test
public void testGetRecordsWithInitialFailures_AdequateWait_Success() {
// Create a new PrefetchRecordsPublisher with 1s idle time between get calls
getRecordsCache = new PrefetchRecordsPublisher(
MAX_SIZE,
3 * SIZE_1_MB,
MAX_RECORDS_COUNT,
MAX_RECORDS_PER_CALL,
getRecordsRetrievalStrategy,
executorService,
1000,
new NullMetricsFactory(),
operation,
"shardId",
AWAIT_TERMINATION_TIMEOUT);
getRecordsCache = createPrefetchRecordsPublisher(Duration.ofSeconds(1).toMillis());
// Setup the retrieval strategy to fail initial calls before succeeding
when(getRecordsRetrievalStrategy.getRecords(eq(MAX_RECORDS_PER_CALL))).thenThrow(new
RetryableRetrievalException("Timed out")).thenThrow(new
@ -278,8 +231,7 @@ public class PrefetchRecordsPublisherTest {
ProcessRecordsInput result = null;
// Setup timeout to be more than what the PrefetchRecordsPublisher will need based on the idle time between
// get calls and then validate the mocks later
result = blockUntilRecordsAvailable(() -> evictPublishedEvent(getRecordsCache,
"shardId"), 4000L)
result = BlockingUtils.blockUntilRecordsAvailable(this::evictPublishedEvent, 4000L)
.processRecordsInput();
assertEquals(expectedRecords, result.records());
@ -303,8 +255,7 @@ public class PrefetchRecordsPublisherTest {
getRecordsCache.start(sequenceNumber, initialPosition);
try {
ProcessRecordsInput result = blockUntilRecordsAvailable(() -> evictPublishedEvent(getRecordsCache, "shardId"), 1000L)
.processRecordsInput();
blockUntilRecordsAvailable();
} catch (Exception e) {
assertEquals("No records found", e.getMessage());
}
@ -337,8 +288,7 @@ public class PrefetchRecordsPublisherTest {
when(dataFetcher.isShardEndReached()).thenReturn(true);
getRecordsCache.start(sequenceNumber, initialPosition);
ProcessRecordsInput result = blockUntilRecordsAvailable(() -> evictPublishedEvent(getRecordsCache, "shardId"), 1000L)
.processRecordsInput();
ProcessRecordsInput result = blockUntilRecordsAvailable().processRecordsInput();
assertEquals(expectedRecords, result.records());
assertEquals(childShards, result.childShards());
@ -406,7 +356,7 @@ public class PrefetchRecordsPublisherTest {
.map(KinesisClientRecord::fromRecord).collect(Collectors.toList());
getRecordsCache.start(sequenceNumber, initialPosition);
ProcessRecordsInput processRecordsInput = evictPublishedEvent(getRecordsCache, "shardId").processRecordsInput();
ProcessRecordsInput processRecordsInput = evictPublishedEvent().processRecordsInput();
verify(executorService).execute(any());
assertEquals(expectedRecords, processRecordsInput.records());
@ -415,7 +365,7 @@ public class PrefetchRecordsPublisherTest {
sleep(2000);
ProcessRecordsInput processRecordsInput2 = evictPublishedEvent(getRecordsCache, "shardId").processRecordsInput();
ProcessRecordsInput processRecordsInput2 = evictPublishedEvent().processRecordsInput();
assertNotEquals(processRecordsInput, processRecordsInput2);
assertEquals(expectedRecords, processRecordsInput2.records());
assertNotEquals(processRecordsInput2.timeSpentInCache(), Duration.ZERO);
@ -425,7 +375,7 @@ public class PrefetchRecordsPublisherTest {
@Test(expected = IllegalStateException.class)
public void testGetNextRecordsWithoutStarting() {
verify(executorService, times(0)).execute(any());
verify(executorService, never()).execute(any());
getRecordsCache.drainQueueForRequests();
}
@ -437,7 +387,6 @@ public class PrefetchRecordsPublisherTest {
@Test
public void testExpiredIteratorException() {
log.info("Starting tests");
when(getRecordsRetrievalStrategy.getRecords(MAX_RECORDS_PER_CALL)).thenThrow(ExpiredIteratorException.class)
.thenReturn(getRecordsResponse);
@ -445,7 +394,7 @@ public class PrefetchRecordsPublisherTest {
doNothing().when(dataFetcher).restartIterator();
blockUntilRecordsAvailable(() -> evictPublishedEvent(getRecordsCache, "shardId"), 1000L);
blockUntilRecordsAvailable();
sleep(1000);
@ -456,7 +405,6 @@ public class PrefetchRecordsPublisherTest {
public void testExpiredIteratorExceptionWithIllegalStateException() {
// This test validates that the daemon thread doesn't die when ExpiredIteratorException occurs with an
// IllegalStateException.
when(getRecordsRetrievalStrategy.getRecords(MAX_RECORDS_PER_CALL))
.thenThrow(ExpiredIteratorException.builder().build())
.thenReturn(getRecordsResponse)
@ -474,14 +422,13 @@ public class PrefetchRecordsPublisherTest {
@Test
public void testRetryableRetrievalExceptionContinues() {
GetRecordsResponse response = GetRecordsResponse.builder().millisBehindLatest(100L).records(Collections.emptyList()).nextShardIterator(NEXT_SHARD_ITERATOR).build();
when(getRecordsRetrievalStrategy.getRecords(anyInt())).thenThrow(new RetryableRetrievalException("Timeout", new TimeoutException("Timeout"))).thenReturn(response);
getRecordsCache.start(sequenceNumber, initialPosition);
RecordsRetrieved records = blockUntilRecordsAvailable(() -> evictPublishedEvent(getRecordsCache, "shardId"), 1000);
assertThat(records.processRecordsInput().millisBehindLatest(), equalTo(response.millisBehindLatest()));
RecordsRetrieved records = blockUntilRecordsAvailable();
assertEquals(records.processRecordsInput().millisBehindLatest(), response.millisBehindLatest());
}
@Test(timeout = 10000L)
@ -493,7 +440,6 @@ public class PrefetchRecordsPublisherTest {
// If the test times out before starting the subscriber it means something went wrong while filling the queue.
// After the subscriber is started one of the things that can trigger a timeout is a deadlock.
//
final int[] sequenceNumberInResponse = { 0 };
when(getRecordsRetrievalStrategy.getRecords(anyInt())).thenAnswer( i -> GetRecordsResponse.builder().records(
@ -681,14 +627,14 @@ public class PrefetchRecordsPublisherTest {
getRecordsCache.start(sequenceNumber, initialPosition);
RecordsRetrieved lastProcessed = blockUntilRecordsAvailable(() -> evictPublishedEvent(getRecordsCache, "shardId"), 1000);
RecordsRetrieved expected = blockUntilRecordsAvailable(() -> evictPublishedEvent(getRecordsCache, "shardId"), 1000);
RecordsRetrieved lastProcessed = blockUntilRecordsAvailable();
RecordsRetrieved expected = blockUntilRecordsAvailable();
//
// Skip some of the records the cache
//
blockUntilRecordsAvailable(() -> evictPublishedEvent(getRecordsCache, "shardId"), 1000);
blockUntilRecordsAvailable(() -> evictPublishedEvent(getRecordsCache, "shardId"), 1000);
blockUntilRecordsAvailable();
blockUntilRecordsAvailable();
verify(getRecordsRetrievalStrategy, atLeast(2)).getRecords(anyInt());
@ -697,16 +643,46 @@ public class PrefetchRecordsPublisherTest {
}
getRecordsCache.restartFrom(lastProcessed);
RecordsRetrieved postRestart = blockUntilRecordsAvailable(() -> evictPublishedEvent(getRecordsCache, "shardId"), 1000);
RecordsRetrieved postRestart = blockUntilRecordsAvailable();
assertThat(postRestart.processRecordsInput(), eqProcessRecordsInput(expected.processRecordsInput()));
verify(dataFetcher).resetIterator(eq(responses.get(0).nextShardIterator()),
eq(responses.get(0).records().get(0).sequenceNumber()), any());
}
private RecordsRetrieved evictPublishedEvent(PrefetchRecordsPublisher publisher, String shardId) {
return publisher.getPublisherSession().evictPublishedRecordAndUpdateDemand(shardId);
/**
* Tests that a thrown {@link SdkException} doesn't cause a retry storm.
*/
@Test(expected = RuntimeException.class)
public void testRepeatSdkExceptionLoop() {
final int expectedFailedCalls = 4;
getRecordsCache = createPrefetchRecordsPublisher(DEFAULT_TIMEOUT_MILLIS / expectedFailedCalls);
getRecordsCache.start(sequenceNumber, initialPosition);
try {
// return a valid response to cause `lastSuccessfulCall` to initialize
when(getRecordsRetrievalStrategy.getRecords(anyInt())).thenReturn(GetRecordsResponse.builder().build());
blockUntilRecordsAvailable();
} catch (RuntimeException re) {
Assert.fail("first call should succeed");
}
try {
when(getRecordsRetrievalStrategy.getRecords(anyInt()))
.thenThrow(SdkException.builder().message("lose yourself to dance").build());
blockUntilRecordsAvailable();
} finally {
// the successful call is the +1
verify(getRecordsRetrievalStrategy, times(expectedFailedCalls + 1)).getRecords(anyInt());
}
}
private RecordsRetrieved blockUntilRecordsAvailable() {
return BlockingUtils.blockUntilRecordsAvailable(this::evictPublishedEvent, DEFAULT_TIMEOUT_MILLIS);
}
private RecordsRetrieved evictPublishedEvent() {
return getRecordsCache.getPublisherSession().evictPublishedRecordAndUpdateDemand("shardId");
}
private static class RetrieverAnswer implements Answer<GetRecordsResponse> {
@ -736,7 +712,7 @@ public class PrefetchRecordsPublisherTest {
}
@Override
public GetRecordsResponse answer(InvocationOnMock invocation) throws Throwable {
public GetRecordsResponse answer(InvocationOnMock invocation) {
GetRecordsResponse response = iterator.next();
if (!iterator.hasNext()) {
iterator = responses.iterator();
@ -787,4 +763,19 @@ public class PrefetchRecordsPublisherTest {
return SdkBytes.fromByteArray(new byte[size]);
}
private PrefetchRecordsPublisher createPrefetchRecordsPublisher(final long idleMillisBetweenCalls) {
return new PrefetchRecordsPublisher(
MAX_SIZE,
3 * SIZE_1_MB,
MAX_RECORDS_COUNT,
MAX_RECORDS_PER_CALL,
getRecordsRetrievalStrategy,
executorService,
idleMillisBetweenCalls,
new NullMetricsFactory(),
PrefetchRecordsPublisherTest.class.getSimpleName(),
"shardId",
1L);
}
}