From e8d2190162e76f58112e05505c6c09a5cf856ac5 Mon Sep 17 00:00:00 2001 From: Justin Pfifer Date: Thu, 16 Aug 2018 13:37:52 -0700 Subject: [PATCH] Use AFTER_SEQUENCE_NUMBER when reconnecting (#371) Subscribe to shard ends periodically and the KCL needs to reconnect at the last continuation sequence number. If the continuation sequence number happens to be the last record returned using AT_SEQUENCE_NUMBER will cause the record to be returned again. --- .../kinesis/retrieval/IteratorBuilder.java | 26 +++- .../fanout/FanOutRecordsPublisher.java | 16 ++- .../retrieval/IteratorBuilderTest.java | 20 ++- .../fanout/FanOutRecordsPublisherTest.java | 114 ++++++++++++++++++ 4 files changed, 164 insertions(+), 12 deletions(-) diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/IteratorBuilder.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/IteratorBuilder.java index e6ed3f27..2b49e031 100644 --- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/IteratorBuilder.java +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/IteratorBuilder.java @@ -21,16 +21,31 @@ public class IteratorBuilder { return builder.startingPosition(request(StartingPosition.builder(), sequenceNumber, initialPosition).build()); } + public static SubscribeToShardRequest.Builder reconnectRequest(SubscribeToShardRequest.Builder builder, + String sequenceNumber, InitialPositionInStreamExtended initialPosition) { + return builder.startingPosition( + reconnectRequest(StartingPosition.builder(), sequenceNumber, initialPosition).build()); + } + public static StartingPosition.Builder request(StartingPosition.Builder builder, String sequenceNumber, InitialPositionInStreamExtended initialPosition) { return apply(builder, StartingPosition.Builder::type, StartingPosition.Builder::timestamp, - StartingPosition.Builder::sequenceNumber, initialPosition, sequenceNumber); + StartingPosition.Builder::sequenceNumber, initialPosition, sequenceNumber, + ShardIteratorType.AT_SEQUENCE_NUMBER); + } + + public static StartingPosition.Builder reconnectRequest(StartingPosition.Builder builder, String sequenceNumber, + InitialPositionInStreamExtended initialPosition) { + return apply(builder, StartingPosition.Builder::type, StartingPosition.Builder::timestamp, + StartingPosition.Builder::sequenceNumber, initialPosition, sequenceNumber, + ShardIteratorType.AFTER_SEQUENCE_NUMBER); } public static GetShardIteratorRequest.Builder request(GetShardIteratorRequest.Builder builder, String sequenceNumber, InitialPositionInStreamExtended initialPosition) { return apply(builder, GetShardIteratorRequest.Builder::shardIteratorType, GetShardIteratorRequest.Builder::timestamp, - GetShardIteratorRequest.Builder::startingSequenceNumber, initialPosition, sequenceNumber); + GetShardIteratorRequest.Builder::startingSequenceNumber, initialPosition, sequenceNumber, + ShardIteratorType.AT_SEQUENCE_NUMBER); } private final static Map SHARD_ITERATOR_MAPPING; @@ -51,15 +66,16 @@ public class IteratorBuilder { private static R apply(R initial, UpdatingFunction shardIterFunc, UpdatingFunction dateFunc, UpdatingFunction sequenceFunction, - InitialPositionInStreamExtended initialPositionInStreamExtended, - String sequenceNumber) { + InitialPositionInStreamExtended initialPositionInStreamExtended, String sequenceNumber, + ShardIteratorType defaultIteratorType) { ShardIteratorType iteratorType = SHARD_ITERATOR_MAPPING.getOrDefault( - sequenceNumber, ShardIteratorType.AT_SEQUENCE_NUMBER); + sequenceNumber, defaultIteratorType); R result = shardIterFunc.apply(initial, iteratorType); switch (iteratorType) { case AT_TIMESTAMP: return dateFunc.apply(result, initialPositionInStreamExtended.getTimestamp().toInstant()); case AT_SEQUENCE_NUMBER: + case AFTER_SEQUENCE_NUMBER: return sequenceFunction.apply(result, sequenceNumber); default: return result; diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisher.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisher.java index 5803fa16..8d9ede58 100644 --- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisher.java +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisher.java @@ -60,6 +60,7 @@ public class FanOutRecordsPublisher implements RecordsPublisher { private String currentSequenceNumber; private InitialPositionInStreamExtended initialPositionInStreamExtended; + private boolean isFirstConnection = true; private Subscriber subscriber; private long outstandingRequests = 0; @@ -70,6 +71,7 @@ public class FanOutRecordsPublisher implements RecordsPublisher { synchronized (lockObject) { this.initialPositionInStreamExtended = initialPositionInStreamExtended; this.currentSequenceNumber = extendedSequenceNumber.sequenceNumber(); + this.isFirstConnection = true; } } @@ -92,8 +94,13 @@ public class FanOutRecordsPublisher implements RecordsPublisher { synchronized (lockObject) { SubscribeToShardRequest.Builder builder = KinesisRequestsBuilder.subscribeToShardRequestBuilder() .shardId(shardId).consumerARN(consumerArn); - SubscribeToShardRequest request = IteratorBuilder - .request(builder, sequenceNumber, initialPositionInStreamExtended).build(); + SubscribeToShardRequest request; + if (isFirstConnection) { + request = IteratorBuilder.request(builder, sequenceNumber, initialPositionInStreamExtended).build(); + } else { + request = IteratorBuilder.reconnectRequest(builder, sequenceNumber, initialPositionInStreamExtended) + .build(); + } Instant connectionStart = Instant.now(); int subscribeInvocationId = subscribeToShardId.incrementAndGet(); @@ -398,6 +405,11 @@ public class FanOutRecordsPublisher implements RecordsPublisher { parent.shardId, connectionStartedAt, subscribeToShardId); subscription = new RecordSubscription(parent, this, connectionStartedAt, subscribeToShardId); publisher.subscribe(subscription); + + // + // Only flip this once we succeed + // + parent.isFirstConnection = false; } catch (Throwable t) { log.debug( "{}: [SubscriptionLifetime]: (RecordFlow#onEventStream) @ {} id: {} -- throwable during record subscription: {}", diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/IteratorBuilderTest.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/IteratorBuilderTest.java index 071dd661..5b04bf8d 100644 --- a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/IteratorBuilderTest.java +++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/IteratorBuilderTest.java @@ -11,14 +11,12 @@ import java.util.function.Supplier; import org.junit.Test; -import software.amazon.awssdk.services.kinesis.model.StartingPosition; -import software.amazon.kinesis.common.InitialPositionInStream; -import software.amazon.kinesis.common.InitialPositionInStreamExtended; - import software.amazon.awssdk.services.kinesis.model.GetShardIteratorRequest; import software.amazon.awssdk.services.kinesis.model.ShardIteratorType; import software.amazon.awssdk.services.kinesis.model.SubscribeToShardRequest; import software.amazon.kinesis.checkpoint.SentinelCheckpoint; +import software.amazon.kinesis.common.InitialPositionInStream; +import software.amazon.kinesis.common.InitialPositionInStreamExtended; public class IteratorBuilderTest { @@ -53,6 +51,12 @@ public class IteratorBuilderTest { sequenceNumber(this::stsBase, this::verifyStsBase, IteratorBuilder::request, WrappedRequest::wrapped); } + @Test + public void subscribeReconnectTest() { + sequenceNumber(this::stsBase, this::verifyStsBase, IteratorBuilder::reconnectRequest, WrappedRequest::wrapped, + ShardIteratorType.AFTER_SEQUENCE_NUMBER); + } + @Test public void getShardSequenceNumberTest() { sequenceNumber(this::gsiBase, this::verifyGsiBase, IteratorBuilder::request, WrappedRequest::wrapped); @@ -68,6 +72,7 @@ public class IteratorBuilderTest { timeStampTest(this::gsiBase, this::verifyGsiBase, IteratorBuilder::request, WrappedRequest::wrapped); } + private interface IteratorApply { T apply(T base, String sequenceNumber, InitialPositionInStreamExtended initialPositionInStreamExtended); } @@ -92,10 +97,15 @@ public class IteratorBuilderTest { private void sequenceNumber(Supplier supplier, Consumer baseVerifier, IteratorApply iteratorRequest, Function> toRequest) { + sequenceNumber(supplier, baseVerifier, iteratorRequest, toRequest, ShardIteratorType.AT_SEQUENCE_NUMBER); + } + + private void sequenceNumber(Supplier supplier, Consumer baseVerifier, IteratorApply iteratorRequest, + Function> toRequest, ShardIteratorType shardIteratorType) { InitialPositionInStreamExtended initialPosition = InitialPositionInStreamExtended .newInitialPosition(InitialPositionInStream.TRIM_HORIZON); updateTest(supplier, baseVerifier, iteratorRequest, toRequest, SEQUENCE_NUMBER, initialPosition, - ShardIteratorType.AT_SEQUENCE_NUMBER, "1234", null); + shardIteratorType, "1234", null); } private void timeStampTest(Supplier supplier, Consumer baseVerifier, IteratorApply iteratorRequest, diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisherTest.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisherTest.java index ebdd5f40..33045149 100644 --- a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisherTest.java +++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisherTest.java @@ -4,8 +4,10 @@ import static org.hamcrest.CoreMatchers.equalTo; import static org.junit.Assert.assertThat; import static org.junit.Assert.fail; import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.never; +import static org.mockito.Mockito.reset; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -33,6 +35,8 @@ import software.amazon.awssdk.core.async.SdkPublisher; import software.amazon.awssdk.services.kinesis.KinesisAsyncClient; import software.amazon.awssdk.services.kinesis.model.Record; import software.amazon.awssdk.services.kinesis.model.ResourceNotFoundException; +import software.amazon.awssdk.services.kinesis.model.ShardIteratorType; +import software.amazon.awssdk.services.kinesis.model.StartingPosition; import software.amazon.awssdk.services.kinesis.model.SubscribeToShardEvent; import software.amazon.awssdk.services.kinesis.model.SubscribeToShardEventStream; import software.amazon.awssdk.services.kinesis.model.SubscribeToShardRequest; @@ -218,6 +222,116 @@ public class FanOutRecordsPublisherTest { assertThat(input.records().isEmpty(), equalTo(true)); } + @Test + public void testContinuesAfterSequence() { + FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN); + + ArgumentCaptor captor = ArgumentCaptor + .forClass(FanOutRecordsPublisher.RecordSubscription.class); + ArgumentCaptor flowCaptor = ArgumentCaptor + .forClass(FanOutRecordsPublisher.RecordFlow.class); + + doNothing().when(publisher).subscribe(captor.capture()); + + source.start(new ExtendedSequenceNumber("0"), + InitialPositionInStreamExtended.newInitialPosition(InitialPositionInStream.LATEST)); + + NonFailingSubscriber nonFailingSubscriber = new NonFailingSubscriber(); + + source.subscribe(nonFailingSubscriber); + + SubscribeToShardRequest expected = SubscribeToShardRequest.builder().consumerARN(CONSUMER_ARN).shardId(SHARD_ID) + .startingPosition(StartingPosition.builder().sequenceNumber("0") + .type(ShardIteratorType.AT_SEQUENCE_NUMBER).build()) + .build(); + + verify(kinesisClient).subscribeToShard(eq(expected), flowCaptor.capture()); + + flowCaptor.getValue().onEventStream(publisher); + captor.getValue().onSubscribe(subscription); + + List records = Stream.of(1, 2, 3).map(this::makeRecord).collect(Collectors.toList()); + List matchers = records.stream().map(KinesisClientRecordMatcher::new) + .collect(Collectors.toList()); + + batchEvent = SubscribeToShardEvent.builder().millisBehindLatest(100L).records(records) + .continuationSequenceNumber("3").build(); + + captor.getValue().onNext(batchEvent); + captor.getValue().onComplete(); + flowCaptor.getValue().complete(); + + ArgumentCaptor nextSubscribeCaptor = ArgumentCaptor + .forClass(FanOutRecordsPublisher.RecordSubscription.class); + ArgumentCaptor nextFlowCaptor = ArgumentCaptor + .forClass(FanOutRecordsPublisher.RecordFlow.class); + + + SubscribeToShardRequest nextExpected = SubscribeToShardRequest.builder().consumerARN(CONSUMER_ARN).shardId(SHARD_ID) + .startingPosition(StartingPosition.builder().sequenceNumber("3") + .type(ShardIteratorType.AFTER_SEQUENCE_NUMBER).build()) + .build(); + + verify(kinesisClient).subscribeToShard(eq(nextExpected), nextFlowCaptor.capture()); + reset(publisher); + doNothing().when(publisher).subscribe(nextSubscribeCaptor.capture()); + + nextFlowCaptor.getValue().onEventStream(publisher); + nextSubscribeCaptor.getValue().onSubscribe(subscription); + + + List nextRecords = Stream.of(4, 5, 6).map(this::makeRecord).collect(Collectors.toList()); + List nextMatchers = nextRecords.stream().map(KinesisClientRecordMatcher::new) + .collect(Collectors.toList()); + + batchEvent = SubscribeToShardEvent.builder().millisBehindLatest(100L).records(nextRecords) + .continuationSequenceNumber("6").build(); + nextSubscribeCaptor.getValue().onNext(batchEvent); + + verify(subscription, times(4)).request(1); + + assertThat(nonFailingSubscriber.received.size(), equalTo(2)); + + verifyRecords(nonFailingSubscriber.received.get(0).records(), matchers); + verifyRecords(nonFailingSubscriber.received.get(1).records(), nextMatchers); + + } + + private void verifyRecords(List clientRecordsList, List matchers) { + assertThat(clientRecordsList.size(), equalTo(matchers.size())); + for (int i = 0; i < clientRecordsList.size(); ++i) { + assertThat(clientRecordsList.get(i), matchers.get(i)); + } + } + + private static class NonFailingSubscriber implements Subscriber { + final List received = new ArrayList<>(); + Subscription subscription; + + @Override + public void onSubscribe(Subscription s) { + subscription = s; + subscription.request(1); + } + + @Override + public void onNext(ProcessRecordsInput input) { + received.add(input); + subscription.request(1); + } + + @Override + public void onError(Throwable t) { + log.error("Caught throwable in subscriber", t); + fail("Caught throwable in subscriber"); + } + + @Override + public void onComplete() { + fail("OnComplete called when not expected"); + } + } + private Record makeRecord(int sequenceNumber) { SdkBytes buffer = SdkBytes.fromByteArray(new byte[] { 1, 2, 3 }); return Record.builder().data(buffer).approximateArrivalTimestamp(Instant.now())