diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/polling/PrefetchRecordsPublisher.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/polling/PrefetchRecordsPublisher.java index ef752f1b..9373aa42 100644 --- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/polling/PrefetchRecordsPublisher.java +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/polling/PrefetchRecordsPublisher.java @@ -42,6 +42,7 @@ import software.amazon.awssdk.core.exception.SdkException; import software.amazon.awssdk.services.cloudwatch.model.StandardUnit; import software.amazon.awssdk.services.kinesis.model.ExpiredIteratorException; import software.amazon.awssdk.services.kinesis.model.GetRecordsResponse; +import software.amazon.awssdk.services.kinesis.model.ProvisionedThroughputExceededException; import software.amazon.kinesis.annotations.KinesisClientInternalApi; import software.amazon.kinesis.common.InitialPositionInStreamExtended; import software.amazon.kinesis.common.RequestDetails; @@ -86,6 +87,7 @@ public class PrefetchRecordsPublisher implements RecordsPublisher { private final MetricsFactory metricsFactory; private final long idleMillisBetweenCalls; private Instant lastSuccessfulCall; + private boolean isFirstGetCallTry = true; private final DefaultGetRecordsCacheDaemon defaultGetRecordsCacheDaemon; private boolean started = false; private final String operation; @@ -459,6 +461,11 @@ public class PrefetchRecordsPublisher implements RecordsPublisher { scope.addData(EXPIRED_ITERATOR_METRIC, 1, StandardUnit.COUNT, MetricsLevel.SUMMARY); publisherSession.dataFetcher().restartIterator(); + } catch (ProvisionedThroughputExceededException e) { + // Update the lastSuccessfulCall if we get a throttling exception so that we back off idleMillis + // for the next call + lastSuccessfulCall = Instant.now(); + log.error("{} : Exception thrown while fetching records from Kinesis", streamAndShardId, 
e); } catch (SdkException e) { log.error("{} : Exception thrown while fetching records from Kinesis", streamAndShardId, e); } catch (Throwable e) { @@ -489,7 +496,13 @@ public class PrefetchRecordsPublisher implements RecordsPublisher { } private void sleepBeforeNextCall() throws InterruptedException { - if (lastSuccessfulCall == null) { + if (lastSuccessfulCall == null && isFirstGetCallTry) { + isFirstGetCallTry = false; + return; + } + // If lastSuccessfulCall is still null but this is not the first try, sleep to avoid a retry storm + if(lastSuccessfulCall == null) { + Thread.sleep(idleMillisBetweenCalls); return; } long timeSinceLastCall = Duration.between(lastSuccessfulCall, Instant.now()).abs().toMillis(); diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/polling/PrefetchRecordsPublisherTest.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/polling/PrefetchRecordsPublisherTest.java index f12e2310..b51b08df 100644 --- a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/polling/PrefetchRecordsPublisherTest.java +++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/polling/PrefetchRecordsPublisherTest.java @@ -32,6 +32,7 @@ import static org.mockito.Mockito.atLeast; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.reset; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -76,6 +77,7 @@ import io.reactivex.Flowable; import io.reactivex.schedulers.Schedulers; import lombok.extern.slf4j.Slf4j; import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.core.exception.SdkException; import software.amazon.awssdk.services.kinesis.model.ChildShard; import software.amazon.awssdk.services.kinesis.model.ExpiredIteratorException; import 
software.amazon.awssdk.services.kinesis.model.GetRecordsResponse; @@ -198,7 +200,8 @@ public class PrefetchRecordsPublisherTest { .map(KinesisClientRecord::fromRecord).collect(Collectors.toList()); getRecordsCache.start(sequenceNumber, initialPosition); - ProcessRecordsInput result = blockUntilRecordsAvailable(() -> evictPublishedEvent(getRecordsCache, "shardId"), 1000L) + ProcessRecordsInput result = blockUntilRecordsAvailable(() -> evictPublishedEvent(getRecordsCache, + "shardId"), 1000L) .processRecordsInput(); assertEquals(expectedRecords, result.records()); @@ -208,6 +211,81 @@ public class PrefetchRecordsPublisherTest { verify(getRecordsRetrievalStrategy, atLeast(1)).getRecords(eq(MAX_RECORDS_PER_CALL)); } + @Test(expected = RuntimeException.class) + public void testGetRecordsWithInitialFailures_LessThanRequiredWait_Throws() { + // Create a new PrefetchRecordsPublisher with 1s idle time between get calls + getRecordsCache = new PrefetchRecordsPublisher( + MAX_SIZE, + 3 * SIZE_1_MB, + MAX_RECORDS_COUNT, + MAX_RECORDS_PER_CALL, + getRecordsRetrievalStrategy, + executorService, + 1000, + new NullMetricsFactory(), + operation, + "shardId"); + // Setup the retrieval strategy to fail initial calls before succeeding + when(getRecordsRetrievalStrategy.getRecords(eq(MAX_RECORDS_PER_CALL))).thenThrow(new + RetryableRetrievalException("Timed out")).thenThrow(new + RetryableRetrievalException("Timed out again")).thenReturn(getRecordsResponse); + record = Record.builder().data(createByteBufferWithSize(SIZE_512_KB)).build(); + + when(records.size()).thenReturn(1000); + + final List expectedRecords = records.stream() + .map(KinesisClientRecord::fromRecord).collect(Collectors.toList()); + + getRecordsCache.start(sequenceNumber, initialPosition); + ProcessRecordsInput result = null; + // Setup timeout to be less than what the PrefetchRecordsPublisher will need based on the idle time between + // get calls to validate exception is thrown + result = 
blockUntilRecordsAvailable(() -> evictPublishedEvent(getRecordsCache, + "shardId"), 1000L) + .processRecordsInput(); + } + + @Test + public void testGetRecordsWithInitialFailures_AdequateWait_Success() { + // Create a new PrefetchRecordsPublisher with 1s idle time between get calls + getRecordsCache = new PrefetchRecordsPublisher( + MAX_SIZE, + 3 * SIZE_1_MB, + MAX_RECORDS_COUNT, + MAX_RECORDS_PER_CALL, + getRecordsRetrievalStrategy, + executorService, + 1000, + new NullMetricsFactory(), + operation, + "shardId"); + // Setup the retrieval strategy to fail initial calls before succeeding + when(getRecordsRetrievalStrategy.getRecords(eq(MAX_RECORDS_PER_CALL))).thenThrow(new + RetryableRetrievalException("Timed out")).thenThrow(new + RetryableRetrievalException("Timed out again")).thenReturn(getRecordsResponse); + record = Record.builder().data(createByteBufferWithSize(SIZE_512_KB)).build(); + + when(records.size()).thenReturn(1000); + + final List expectedRecords = records.stream() + .map(KinesisClientRecord::fromRecord).collect(Collectors.toList()); + + getRecordsCache.start(sequenceNumber, initialPosition); + ProcessRecordsInput result = null; + // Setup timeout to be more than what the PrefetchRecordsPublisher will need based on the idle time between + // get calls and then validate the mocks later + result = blockUntilRecordsAvailable(() -> evictPublishedEvent(getRecordsCache, + "shardId"), 4000L) + .processRecordsInput(); + + assertEquals(expectedRecords, result.records()); + assertEquals(new ArrayList<>(), result.childShards()); + + verify(executorService).execute(any()); + // Validate at least 3 calls were made, including the 2 failed ones + verify(getRecordsRetrievalStrategy, atLeast(3)).getRecords(eq(MAX_RECORDS_PER_CALL)); + } + @Test public void testGetRecordsWithInvalidResponse() { record = Record.builder().data(createByteBufferWithSize(SIZE_512_KB)).build(); @@ -238,15 +316,15 @@ public class PrefetchRecordsPublisherTest { List parentShards = new 
ArrayList<>(); parentShards.add("shardId"); ChildShard leftChild = ChildShard.builder() - .shardId("shardId-000000000001") - .parentShards(parentShards) - .hashKeyRange(ShardObjectHelper.newHashKeyRange("0", "49")) - .build(); + .shardId("shardId-000000000001") + .parentShards(parentShards) + .hashKeyRange(ShardObjectHelper.newHashKeyRange("0", "49")) + .build(); ChildShard rightChild = ChildShard.builder() - .shardId("shardId-000000000002") - .parentShards(parentShards) - .hashKeyRange(ShardObjectHelper.newHashKeyRange("50", "99")) - .build(); + .shardId("shardId-000000000002") + .parentShards(parentShards) + .hashKeyRange(ShardObjectHelper.newHashKeyRange("50", "99")) + .build(); childShards.add(leftChild); childShards.add(rightChild); @@ -292,9 +370,9 @@ public class PrefetchRecordsPublisherTest { sleep(2000); int callRate = (int) Math.ceil((double) MAX_RECORDS_COUNT/recordsSize); -// TODO: fix this verification -// verify(getRecordsRetrievalStrategy, times(callRate)).getRecords(MAX_RECORDS_PER_CALL); -// assertEquals(spyQueue.size(), callRate); + // TODO: fix this verification + // verify(getRecordsRetrievalStrategy, times(callRate)).getRecords(MAX_RECORDS_PER_CALL); + // assertEquals(spyQueue.size(), callRate); assertTrue("Call Rate is "+callRate,callRate < MAX_SIZE); } @@ -410,7 +488,7 @@ public class PrefetchRecordsPublisherTest { log.info("Queue is currently at {} starting subscriber", getRecordsCache.getPublisherSession().prefetchRecordsQueue().size()); AtomicInteger receivedItems = new AtomicInteger(0); - + final int expectedItems = MAX_SIZE * 10; Object lock = new Object();