diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/IteratorBuilder.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/IteratorBuilder.java index e6ed3f27..2b49e031 100644 --- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/IteratorBuilder.java +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/IteratorBuilder.java @@ -21,16 +21,31 @@ public class IteratorBuilder { return builder.startingPosition(request(StartingPosition.builder(), sequenceNumber, initialPosition).build()); } + public static SubscribeToShardRequest.Builder reconnectRequest(SubscribeToShardRequest.Builder builder, + String sequenceNumber, InitialPositionInStreamExtended initialPosition) { + return builder.startingPosition( + reconnectRequest(StartingPosition.builder(), sequenceNumber, initialPosition).build()); + } + public static StartingPosition.Builder request(StartingPosition.Builder builder, String sequenceNumber, InitialPositionInStreamExtended initialPosition) { return apply(builder, StartingPosition.Builder::type, StartingPosition.Builder::timestamp, - StartingPosition.Builder::sequenceNumber, initialPosition, sequenceNumber); + StartingPosition.Builder::sequenceNumber, initialPosition, sequenceNumber, + ShardIteratorType.AT_SEQUENCE_NUMBER); + } + + public static StartingPosition.Builder reconnectRequest(StartingPosition.Builder builder, String sequenceNumber, + InitialPositionInStreamExtended initialPosition) { + return apply(builder, StartingPosition.Builder::type, StartingPosition.Builder::timestamp, + StartingPosition.Builder::sequenceNumber, initialPosition, sequenceNumber, + ShardIteratorType.AFTER_SEQUENCE_NUMBER); } public static GetShardIteratorRequest.Builder request(GetShardIteratorRequest.Builder builder, String sequenceNumber, InitialPositionInStreamExtended initialPosition) { return apply(builder, GetShardIteratorRequest.Builder::shardIteratorType, GetShardIteratorRequest.Builder::timestamp, - GetShardIteratorRequest.Builder::startingSequenceNumber, initialPosition, sequenceNumber); + GetShardIteratorRequest.Builder::startingSequenceNumber, initialPosition, sequenceNumber, + ShardIteratorType.AT_SEQUENCE_NUMBER); } private final static Map SHARD_ITERATOR_MAPPING; @@ -51,15 +66,16 @@ public class IteratorBuilder { private static R apply(R initial, UpdatingFunction shardIterFunc, UpdatingFunction dateFunc, UpdatingFunction sequenceFunction, - InitialPositionInStreamExtended initialPositionInStreamExtended, - String sequenceNumber) { + InitialPositionInStreamExtended initialPositionInStreamExtended, String sequenceNumber, + ShardIteratorType defaultIteratorType) { ShardIteratorType iteratorType = SHARD_ITERATOR_MAPPING.getOrDefault( - sequenceNumber, ShardIteratorType.AT_SEQUENCE_NUMBER); + sequenceNumber, defaultIteratorType); R result = shardIterFunc.apply(initial, iteratorType); switch (iteratorType) { case AT_TIMESTAMP: return dateFunc.apply(result, initialPositionInStreamExtended.getTimestamp().toInstant()); case AT_SEQUENCE_NUMBER: + case AFTER_SEQUENCE_NUMBER: return sequenceFunction.apply(result, sequenceNumber); default: return result; diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisher.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisher.java index 5803fa16..8d9ede58 100644 --- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisher.java +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisher.java @@ -60,6 +60,7 @@ public class FanOutRecordsPublisher implements RecordsPublisher { private String currentSequenceNumber; private InitialPositionInStreamExtended initialPositionInStreamExtended; + private boolean isFirstConnection = true; private Subscriber subscriber; private long outstandingRequests = 0; @@ -70,6 +71,7 @@ public class FanOutRecordsPublisher implements RecordsPublisher { synchronized (lockObject) { this.initialPositionInStreamExtended = initialPositionInStreamExtended; this.currentSequenceNumber = extendedSequenceNumber.sequenceNumber(); + this.isFirstConnection = true; } } @@ -92,8 +94,13 @@ public class FanOutRecordsPublisher implements RecordsPublisher { synchronized (lockObject) { SubscribeToShardRequest.Builder builder = KinesisRequestsBuilder.subscribeToShardRequestBuilder() .shardId(shardId).consumerARN(consumerArn); - SubscribeToShardRequest request = IteratorBuilder - .request(builder, sequenceNumber, initialPositionInStreamExtended).build(); + SubscribeToShardRequest request; + if (isFirstConnection) { + request = IteratorBuilder.request(builder, sequenceNumber, initialPositionInStreamExtended).build(); + } else { + request = IteratorBuilder.reconnectRequest(builder, sequenceNumber, initialPositionInStreamExtended) + .build(); + } Instant connectionStart = Instant.now(); int subscribeInvocationId = subscribeToShardId.incrementAndGet(); @@ -398,6 +405,11 @@ public class FanOutRecordsPublisher implements RecordsPublisher { parent.shardId, connectionStartedAt, subscribeToShardId); subscription = new RecordSubscription(parent, this, connectionStartedAt, subscribeToShardId); publisher.subscribe(subscription); + + // + // Only flip this once we succeed + // + parent.isFirstConnection = false; } catch (Throwable t) { log.debug( "{}: [SubscriptionLifetime]: (RecordFlow#onEventStream) @ {} id: {} -- throwable during record subscription: {}", diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/IteratorBuilderTest.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/IteratorBuilderTest.java index 071dd661..5b04bf8d 100644 --- a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/IteratorBuilderTest.java +++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/IteratorBuilderTest.java @@ -11,14 +11,12 @@ import java.util.function.Supplier; import org.junit.Test; -import software.amazon.awssdk.services.kinesis.model.StartingPosition; -import software.amazon.kinesis.common.InitialPositionInStream; -import software.amazon.kinesis.common.InitialPositionInStreamExtended; - import software.amazon.awssdk.services.kinesis.model.GetShardIteratorRequest; import software.amazon.awssdk.services.kinesis.model.ShardIteratorType; import software.amazon.awssdk.services.kinesis.model.SubscribeToShardRequest; import software.amazon.kinesis.checkpoint.SentinelCheckpoint; +import software.amazon.kinesis.common.InitialPositionInStream; +import software.amazon.kinesis.common.InitialPositionInStreamExtended; public class IteratorBuilderTest { @@ -53,6 +51,12 @@ public class IteratorBuilderTest { sequenceNumber(this::stsBase, this::verifyStsBase, IteratorBuilder::request, WrappedRequest::wrapped); } + @Test + public void subscribeReconnectTest() { + sequenceNumber(this::stsBase, this::verifyStsBase, IteratorBuilder::reconnectRequest, WrappedRequest::wrapped, + ShardIteratorType.AFTER_SEQUENCE_NUMBER); + } + @Test public void getShardSequenceNumberTest() { sequenceNumber(this::gsiBase, this::verifyGsiBase, IteratorBuilder::request, WrappedRequest::wrapped); @@ -68,6 +72,7 @@ public class IteratorBuilderTest { timeStampTest(this::gsiBase, this::verifyGsiBase, IteratorBuilder::request, WrappedRequest::wrapped); } + private interface IteratorApply { T apply(T base, String sequenceNumber, InitialPositionInStreamExtended initialPositionInStreamExtended); } @@ -92,10 +97,15 @@ public class IteratorBuilderTest { private void sequenceNumber(Supplier supplier, Consumer baseVerifier, IteratorApply iteratorRequest, Function> toRequest) { + sequenceNumber(supplier, baseVerifier, iteratorRequest, toRequest, ShardIteratorType.AT_SEQUENCE_NUMBER); + } + + private void sequenceNumber(Supplier supplier, Consumer baseVerifier, IteratorApply iteratorRequest, + Function> toRequest, ShardIteratorType shardIteratorType) { InitialPositionInStreamExtended initialPosition = InitialPositionInStreamExtended .newInitialPosition(InitialPositionInStream.TRIM_HORIZON); updateTest(supplier, baseVerifier, iteratorRequest, toRequest, SEQUENCE_NUMBER, initialPosition, - ShardIteratorType.AT_SEQUENCE_NUMBER, "1234", null); + shardIteratorType, "1234", null); } private void timeStampTest(Supplier supplier, Consumer baseVerifier, IteratorApply iteratorRequest, diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisherTest.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisherTest.java index ebdd5f40..33045149 100644 --- a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisherTest.java +++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisherTest.java @@ -4,8 +4,10 @@ import static org.hamcrest.CoreMatchers.equalTo; import static org.junit.Assert.assertThat; import static org.junit.Assert.fail; import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.never; +import static org.mockito.Mockito.reset; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -33,6 +35,8 @@ import software.amazon.awssdk.core.async.SdkPublisher; import software.amazon.awssdk.services.kinesis.KinesisAsyncClient; import software.amazon.awssdk.services.kinesis.model.Record; import software.amazon.awssdk.services.kinesis.model.ResourceNotFoundException; +import software.amazon.awssdk.services.kinesis.model.ShardIteratorType; +import software.amazon.awssdk.services.kinesis.model.StartingPosition; import software.amazon.awssdk.services.kinesis.model.SubscribeToShardEvent; import software.amazon.awssdk.services.kinesis.model.SubscribeToShardEventStream; import software.amazon.awssdk.services.kinesis.model.SubscribeToShardRequest; @@ -218,6 +222,116 @@ public class FanOutRecordsPublisherTest { assertThat(input.records().isEmpty(), equalTo(true)); } + @Test + public void testContinuesAfterSequence() { + FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN); + + ArgumentCaptor captor = ArgumentCaptor + .forClass(FanOutRecordsPublisher.RecordSubscription.class); + ArgumentCaptor flowCaptor = ArgumentCaptor + .forClass(FanOutRecordsPublisher.RecordFlow.class); + + doNothing().when(publisher).subscribe(captor.capture()); + + source.start(new ExtendedSequenceNumber("0"), + InitialPositionInStreamExtended.newInitialPosition(InitialPositionInStream.LATEST)); + + NonFailingSubscriber nonFailingSubscriber = new NonFailingSubscriber(); + + source.subscribe(nonFailingSubscriber); + + SubscribeToShardRequest expected = SubscribeToShardRequest.builder().consumerARN(CONSUMER_ARN).shardId(SHARD_ID) + .startingPosition(StartingPosition.builder().sequenceNumber("0") + .type(ShardIteratorType.AT_SEQUENCE_NUMBER).build()) + .build(); + + verify(kinesisClient).subscribeToShard(eq(expected), flowCaptor.capture()); + + flowCaptor.getValue().onEventStream(publisher); + captor.getValue().onSubscribe(subscription); + + List records = Stream.of(1, 2, 3).map(this::makeRecord).collect(Collectors.toList()); + List matchers = records.stream().map(KinesisClientRecordMatcher::new) + .collect(Collectors.toList()); + + batchEvent = SubscribeToShardEvent.builder().millisBehindLatest(100L).records(records) + .continuationSequenceNumber("3").build(); + + captor.getValue().onNext(batchEvent); + captor.getValue().onComplete(); + flowCaptor.getValue().complete(); + + ArgumentCaptor nextSubscribeCaptor = ArgumentCaptor + .forClass(FanOutRecordsPublisher.RecordSubscription.class); + ArgumentCaptor nextFlowCaptor = ArgumentCaptor + .forClass(FanOutRecordsPublisher.RecordFlow.class); + + + SubscribeToShardRequest nextExpected = SubscribeToShardRequest.builder().consumerARN(CONSUMER_ARN).shardId(SHARD_ID) + .startingPosition(StartingPosition.builder().sequenceNumber("3") + .type(ShardIteratorType.AFTER_SEQUENCE_NUMBER).build()) + .build(); + + verify(kinesisClient).subscribeToShard(eq(nextExpected), nextFlowCaptor.capture()); + reset(publisher); + doNothing().when(publisher).subscribe(nextSubscribeCaptor.capture()); + + nextFlowCaptor.getValue().onEventStream(publisher); + nextSubscribeCaptor.getValue().onSubscribe(subscription); + + + List nextRecords = Stream.of(4, 5, 6).map(this::makeRecord).collect(Collectors.toList()); + List nextMatchers = nextRecords.stream().map(KinesisClientRecordMatcher::new) + .collect(Collectors.toList()); + + batchEvent = SubscribeToShardEvent.builder().millisBehindLatest(100L).records(nextRecords) + .continuationSequenceNumber("6").build(); + nextSubscribeCaptor.getValue().onNext(batchEvent); + + verify(subscription, times(4)).request(1); + + assertThat(nonFailingSubscriber.received.size(), equalTo(2)); + + verifyRecords(nonFailingSubscriber.received.get(0).records(), matchers); + verifyRecords(nonFailingSubscriber.received.get(1).records(), nextMatchers); + + } + + private void verifyRecords(List clientRecordsList, List matchers) { + assertThat(clientRecordsList.size(), equalTo(matchers.size())); + for (int i = 0; i < clientRecordsList.size(); ++i) { + assertThat(clientRecordsList.get(i), matchers.get(i)); + } + } + + private static class NonFailingSubscriber implements Subscriber { + final List received = new ArrayList<>(); + Subscription subscription; + + @Override + public void onSubscribe(Subscription s) { + subscription = s; + subscription.request(1); + } + + @Override + public void onNext(ProcessRecordsInput input) { + received.add(input); + subscription.request(1); + } + + @Override + public void onError(Throwable t) { + log.error("Caught throwable in subscriber", t); + fail("Caught throwable in subscriber"); + } + + @Override + public void onComplete() { + fail("OnComplete called when not expected"); + } + } + private Record makeRecord(int sequenceNumber) { SdkBytes buffer = SdkBytes.fromByteArray(new byte[] { 1, 2, 3 }); return Record.builder().data(buffer).approximateArrivalTimestamp(Instant.now())