Use AFTER_SEQUENCE_NUMBER when reconnecting (#371)

SubscribeToShard subscriptions end periodically, and the KCL needs to
reconnect at the last continuation sequence number.  If the continuation
sequence number happens to be that of the last record returned, using
AT_SEQUENCE_NUMBER will cause that record to be returned again.
This commit is contained in:
Justin Pfifer 2018-08-16 13:37:52 -07:00 committed by Sahil Palvia
parent e694ab7724
commit e8d2190162
4 changed files with 164 additions and 12 deletions

View file

@ -21,16 +21,31 @@ public class IteratorBuilder {
return builder.startingPosition(request(StartingPosition.builder(), sequenceNumber, initialPosition).build());
}
/**
 * Attaches a reconnect starting position to the given SubscribeToShard request builder.
 *
 * <p>The starting position is produced by the {@code StartingPosition.Builder} overload of
 * {@code reconnectRequest}, built, and then set on {@code builder}.
 *
 * @param builder the request builder to update
 * @param sequenceNumber the sequence number handed to the starting-position overload
 * @param initialPosition the initial position handed to the starting-position overload
 * @return the value returned by {@code builder.startingPosition(...)}
 */
public static SubscribeToShardRequest.Builder reconnectRequest(SubscribeToShardRequest.Builder builder,
        String sequenceNumber, InitialPositionInStreamExtended initialPosition) {
    StartingPosition.Builder positionBuilder = reconnectRequest(StartingPosition.builder(), sequenceNumber,
            initialPosition);
    StartingPosition reconnectPosition = positionBuilder.build();
    return builder.startingPosition(reconnectPosition);
}
/**
 * Configures a {@code StartingPosition.Builder} by delegating to {@code apply}, wiring the builder's
 * {@code type}, {@code timestamp} and {@code sequenceNumber} setters.
 * {@code ShardIteratorType.AT_SEQUENCE_NUMBER} is passed as the default iterator type.
 *
 * NOTE(review): this span is taken from a diff view and appears to contain both the removed and the added
 * form of the trailing arguments to apply(); confirm which lines belong to the current revision.
 *
 * @param builder the starting-position builder to update
 * @param sequenceNumber the sequence number forwarded to {@code apply}
 * @param initialPosition the initial position forwarded to {@code apply}
 * @return the builder returned by {@code apply}
 */
public static StartingPosition.Builder request(StartingPosition.Builder builder, String sequenceNumber,
InitialPositionInStreamExtended initialPosition) {
return apply(builder, StartingPosition.Builder::type, StartingPosition.Builder::timestamp,
StartingPosition.Builder::sequenceNumber, initialPosition, sequenceNumber);
StartingPosition.Builder::sequenceNumber, initialPosition, sequenceNumber,
ShardIteratorType.AT_SEQUENCE_NUMBER);
}
/**
 * Configures a {@code StartingPosition.Builder} for a reconnect.
 *
 * <p>Delegates to {@code apply} with {@code ShardIteratorType.AFTER_SEQUENCE_NUMBER} as the default iterator
 * type, wiring the builder's {@code type}, {@code timestamp} and {@code sequenceNumber} setters.
 *
 * @param builder the starting-position builder to update
 * @param sequenceNumber the sequence number forwarded to {@code apply}
 * @param initialPosition the initial position forwarded to {@code apply}
 * @return the builder returned by {@code apply}
 */
public static StartingPosition.Builder reconnectRequest(StartingPosition.Builder builder, String sequenceNumber,
        InitialPositionInStreamExtended initialPosition) {
    // Used by apply() only when the sequence number has no entry in its iterator-type mapping.
    final ShardIteratorType reconnectDefault = ShardIteratorType.AFTER_SEQUENCE_NUMBER;
    return apply(
            builder,
            StartingPosition.Builder::type,
            StartingPosition.Builder::timestamp,
            StartingPosition.Builder::sequenceNumber,
            initialPosition,
            sequenceNumber,
            reconnectDefault);
}
/**
 * Configures a {@code GetShardIteratorRequest.Builder} by delegating to {@code apply}, wiring the builder's
 * {@code shardIteratorType}, {@code timestamp} and {@code startingSequenceNumber} setters.
 * {@code ShardIteratorType.AT_SEQUENCE_NUMBER} is passed as the default iterator type.
 *
 * NOTE(review): this span is taken from a diff view and appears to contain both the removed and the added
 * form of the trailing arguments to apply(); confirm which lines belong to the current revision.
 *
 * @param builder the request builder to update
 * @param sequenceNumber the sequence number forwarded to {@code apply}
 * @param initialPosition the initial position forwarded to {@code apply}
 * @return the builder returned by {@code apply}
 */
public static GetShardIteratorRequest.Builder request(GetShardIteratorRequest.Builder builder,
String sequenceNumber, InitialPositionInStreamExtended initialPosition) {
return apply(builder, GetShardIteratorRequest.Builder::shardIteratorType, GetShardIteratorRequest.Builder::timestamp,
GetShardIteratorRequest.Builder::startingSequenceNumber, initialPosition, sequenceNumber);
GetShardIteratorRequest.Builder::startingSequenceNumber, initialPosition, sequenceNumber,
ShardIteratorType.AT_SEQUENCE_NUMBER);
}
private final static Map<String, ShardIteratorType> SHARD_ITERATOR_MAPPING;
@ -51,15 +66,16 @@ public class IteratorBuilder {
private static <R> R apply(R initial, UpdatingFunction<ShardIteratorType, R> shardIterFunc,
UpdatingFunction<Instant, R> dateFunc, UpdatingFunction<String, R> sequenceFunction,
InitialPositionInStreamExtended initialPositionInStreamExtended,
String sequenceNumber) {
InitialPositionInStreamExtended initialPositionInStreamExtended, String sequenceNumber,
ShardIteratorType defaultIteratorType) {
ShardIteratorType iteratorType = SHARD_ITERATOR_MAPPING.getOrDefault(
sequenceNumber, ShardIteratorType.AT_SEQUENCE_NUMBER);
sequenceNumber, defaultIteratorType);
R result = shardIterFunc.apply(initial, iteratorType);
switch (iteratorType) {
case AT_TIMESTAMP:
return dateFunc.apply(result, initialPositionInStreamExtended.getTimestamp().toInstant());
case AT_SEQUENCE_NUMBER:
case AFTER_SEQUENCE_NUMBER:
return sequenceFunction.apply(result, sequenceNumber);
default:
return result;

View file

@ -60,6 +60,7 @@ public class FanOutRecordsPublisher implements RecordsPublisher {
private String currentSequenceNumber;
private InitialPositionInStreamExtended initialPositionInStreamExtended;
private boolean isFirstConnection = true;
private Subscriber<? super ProcessRecordsInput> subscriber;
private long outstandingRequests = 0;
@ -70,6 +71,7 @@ public class FanOutRecordsPublisher implements RecordsPublisher {
synchronized (lockObject) {
this.initialPositionInStreamExtended = initialPositionInStreamExtended;
this.currentSequenceNumber = extendedSequenceNumber.sequenceNumber();
this.isFirstConnection = true;
}
}
@ -92,8 +94,13 @@ public class FanOutRecordsPublisher implements RecordsPublisher {
synchronized (lockObject) {
SubscribeToShardRequest.Builder builder = KinesisRequestsBuilder.subscribeToShardRequestBuilder()
.shardId(shardId).consumerARN(consumerArn);
SubscribeToShardRequest request = IteratorBuilder
.request(builder, sequenceNumber, initialPositionInStreamExtended).build();
SubscribeToShardRequest request;
if (isFirstConnection) {
request = IteratorBuilder.request(builder, sequenceNumber, initialPositionInStreamExtended).build();
} else {
request = IteratorBuilder.reconnectRequest(builder, sequenceNumber, initialPositionInStreamExtended)
.build();
}
Instant connectionStart = Instant.now();
int subscribeInvocationId = subscribeToShardId.incrementAndGet();
@ -398,6 +405,11 @@ public class FanOutRecordsPublisher implements RecordsPublisher {
parent.shardId, connectionStartedAt, subscribeToShardId);
subscription = new RecordSubscription(parent, this, connectionStartedAt, subscribeToShardId);
publisher.subscribe(subscription);
//
// Only flip this once we succeed
//
parent.isFirstConnection = false;
} catch (Throwable t) {
log.debug(
"{}: [SubscriptionLifetime]: (RecordFlow#onEventStream) @ {} id: {} -- throwable during record subscription: {}",

View file

@ -11,14 +11,12 @@ import java.util.function.Supplier;
import org.junit.Test;
import software.amazon.awssdk.services.kinesis.model.StartingPosition;
import software.amazon.kinesis.common.InitialPositionInStream;
import software.amazon.kinesis.common.InitialPositionInStreamExtended;
import software.amazon.awssdk.services.kinesis.model.GetShardIteratorRequest;
import software.amazon.awssdk.services.kinesis.model.ShardIteratorType;
import software.amazon.awssdk.services.kinesis.model.SubscribeToShardRequest;
import software.amazon.kinesis.checkpoint.SentinelCheckpoint;
import software.amazon.kinesis.common.InitialPositionInStream;
import software.amazon.kinesis.common.InitialPositionInStreamExtended;
public class IteratorBuilderTest {
@ -53,6 +51,12 @@ public class IteratorBuilderTest {
sequenceNumber(this::stsBase, this::verifyStsBase, IteratorBuilder::request, WrappedRequest::wrapped);
}
/**
 * Runs the shared sequence-number scenario against {@code IteratorBuilder::reconnectRequest}, passing
 * {@code AFTER_SEQUENCE_NUMBER} as the iterator type instead of the overload's default.
 */
@Test
public void subscribeReconnectTest() {
    ShardIteratorType reconnectType = ShardIteratorType.AFTER_SEQUENCE_NUMBER;
    sequenceNumber(this::stsBase, this::verifyStsBase, IteratorBuilder::reconnectRequest, WrappedRequest::wrapped,
            reconnectType);
}
@Test
public void getShardSequenceNumberTest() {
sequenceNumber(this::gsiBase, this::verifyGsiBase, IteratorBuilder::request, WrappedRequest::wrapped);
@ -68,6 +72,7 @@ public class IteratorBuilderTest {
timeStampTest(this::gsiBase, this::verifyGsiBase, IteratorBuilder::request, WrappedRequest::wrapped);
}
/**
 * Function shape used by the tests to apply a sequence number and initial position to a request builder.
 *
 * <p>It is only ever supplied as a method reference (for example {@code IteratorBuilder::request}), so it is
 * marked {@code @FunctionalInterface} to have the compiler enforce the single-abstract-method contract.
 *
 * @param <T> the type being updated and returned
 */
@FunctionalInterface
private interface IteratorApply<T> {
    /**
     * Applies the given position information to {@code base}.
     *
     * @param base the value to update
     * @param sequenceNumber the sequence number to apply
     * @param initialPositionInStreamExtended the initial position to apply
     * @return the updated value
     */
    T apply(T base, String sequenceNumber, InitialPositionInStreamExtended initialPositionInStreamExtended);
}
@ -92,10 +97,15 @@ public class IteratorBuilderTest {
/**
 * Convenience overload that runs the sequence-number scenario with
 * {@code ShardIteratorType.AT_SEQUENCE_NUMBER} as the iterator type.
 *
 * @param supplier forwarded unchanged to the five-argument overload
 * @param baseVerifier forwarded unchanged to the five-argument overload
 * @param iteratorRequest forwarded unchanged to the five-argument overload
 * @param toRequest forwarded unchanged to the five-argument overload
 */
private <T, R> void sequenceNumber(Supplier<T> supplier, Consumer<R> baseVerifier, IteratorApply<T> iteratorRequest,
        Function<T, WrappedRequest<R>> toRequest) {
    final ShardIteratorType defaultType = ShardIteratorType.AT_SEQUENCE_NUMBER;
    sequenceNumber(supplier, baseVerifier, iteratorRequest, toRequest, defaultType);
}
/**
 * Runs the sequence-number scenario: builds a TRIM_HORIZON initial position and delegates to
 * {@code updateTest} with {@code SEQUENCE_NUMBER}, the given {@code shardIteratorType}, and the literal
 * sequence number "1234".
 *
 * NOTE(review): {@code shardIteratorType} is presumably the iterator type the built request is expected to
 * carry; confirm against updateTest. This span is taken from a diff view and appears to contain both the
 * removed and the added form of the trailing arguments to updateTest(); confirm which line is current.
 *
 * @param supplier forwarded unchanged to {@code updateTest}
 * @param baseVerifier forwarded unchanged to {@code updateTest}
 * @param iteratorRequest forwarded unchanged to {@code updateTest}
 * @param toRequest forwarded unchanged to {@code updateTest}
 * @param shardIteratorType the iterator type forwarded to {@code updateTest}
 */
private <T, R> void sequenceNumber(Supplier<T> supplier, Consumer<R> baseVerifier, IteratorApply<T> iteratorRequest,
Function<T, WrappedRequest<R>> toRequest, ShardIteratorType shardIteratorType) {
InitialPositionInStreamExtended initialPosition = InitialPositionInStreamExtended
.newInitialPosition(InitialPositionInStream.TRIM_HORIZON);
updateTest(supplier, baseVerifier, iteratorRequest, toRequest, SEQUENCE_NUMBER, initialPosition,
ShardIteratorType.AT_SEQUENCE_NUMBER, "1234", null);
shardIteratorType, "1234", null);
}
private <T, R> void timeStampTest(Supplier<T> supplier, Consumer<R> baseVerifier, IteratorApply<T> iteratorRequest,

View file

@ -4,8 +4,10 @@ import static org.hamcrest.CoreMatchers.equalTo;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.fail;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
@ -33,6 +35,8 @@ import software.amazon.awssdk.core.async.SdkPublisher;
import software.amazon.awssdk.services.kinesis.KinesisAsyncClient;
import software.amazon.awssdk.services.kinesis.model.Record;
import software.amazon.awssdk.services.kinesis.model.ResourceNotFoundException;
import software.amazon.awssdk.services.kinesis.model.ShardIteratorType;
import software.amazon.awssdk.services.kinesis.model.StartingPosition;
import software.amazon.awssdk.services.kinesis.model.SubscribeToShardEvent;
import software.amazon.awssdk.services.kinesis.model.SubscribeToShardEventStream;
import software.amazon.awssdk.services.kinesis.model.SubscribeToShardRequest;
@ -218,6 +222,116 @@ public class FanOutRecordsPublisherTest {
assertThat(input.records().isEmpty(), equalTo(true));
}
//
// Checks that the first subscribeToShard request uses AT_SEQUENCE_NUMBER at the starting sequence number,
// and that the request issued after the first subscription completes uses AFTER_SEQUENCE_NUMBER at the
// last continuation sequence number.
//
@Test
public void testContinuesAfterSequence() {
FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN);
ArgumentCaptor<FanOutRecordsPublisher.RecordSubscription> captor = ArgumentCaptor
.forClass(FanOutRecordsPublisher.RecordSubscription.class);
ArgumentCaptor<FanOutRecordsPublisher.RecordFlow> flowCaptor = ArgumentCaptor
.forClass(FanOutRecordsPublisher.RecordFlow.class);
// Capture the RecordSubscription that gets passed to the mocked publisher.
doNothing().when(publisher).subscribe(captor.capture());
// Start at sequence number "0" and attach the test subscriber.
source.start(new ExtendedSequenceNumber("0"),
InitialPositionInStreamExtended.newInitialPosition(InitialPositionInStream.LATEST));
NonFailingSubscriber nonFailingSubscriber = new NonFailingSubscriber();
source.subscribe(nonFailingSubscriber);
// First connection: expected to be AT_SEQUENCE_NUMBER at "0".
SubscribeToShardRequest expected = SubscribeToShardRequest.builder().consumerARN(CONSUMER_ARN).shardId(SHARD_ID)
.startingPosition(StartingPosition.builder().sequenceNumber("0")
.type(ShardIteratorType.AT_SEQUENCE_NUMBER).build())
.build();
verify(kinesisClient).subscribeToShard(eq(expected), flowCaptor.capture());
// Drive the captured flow and subscription by hand.
flowCaptor.getValue().onEventStream(publisher);
captor.getValue().onSubscribe(subscription);
List<Record> records = Stream.of(1, 2, 3).map(this::makeRecord).collect(Collectors.toList());
List<KinesisClientRecordMatcher> matchers = records.stream().map(KinesisClientRecordMatcher::new)
.collect(Collectors.toList());
// Deliver the first batch with continuation sequence number "3".
batchEvent = SubscribeToShardEvent.builder().millisBehindLatest(100L).records(records)
.continuationSequenceNumber("3").build();
captor.getValue().onNext(batchEvent);
// Complete the first subscription and its flow; the next request is verified below.
captor.getValue().onComplete();
flowCaptor.getValue().complete();
ArgumentCaptor<FanOutRecordsPublisher.RecordSubscription> nextSubscribeCaptor = ArgumentCaptor
.forClass(FanOutRecordsPublisher.RecordSubscription.class);
ArgumentCaptor<FanOutRecordsPublisher.RecordFlow> nextFlowCaptor = ArgumentCaptor
.forClass(FanOutRecordsPublisher.RecordFlow.class);
// Reconnect: expected to be AFTER_SEQUENCE_NUMBER at the continuation sequence number "3".
SubscribeToShardRequest nextExpected = SubscribeToShardRequest.builder().consumerARN(CONSUMER_ARN).shardId(SHARD_ID)
.startingPosition(StartingPosition.builder().sequenceNumber("3")
.type(ShardIteratorType.AFTER_SEQUENCE_NUMBER).build())
.build();
verify(kinesisClient).subscribeToShard(eq(nextExpected), nextFlowCaptor.capture());
// Reset the publisher mock and re-stub it so the second subscription goes to the new captor.
reset(publisher);
doNothing().when(publisher).subscribe(nextSubscribeCaptor.capture());
nextFlowCaptor.getValue().onEventStream(publisher);
nextSubscribeCaptor.getValue().onSubscribe(subscription);
List<Record> nextRecords = Stream.of(4, 5, 6).map(this::makeRecord).collect(Collectors.toList());
List<KinesisClientRecordMatcher> nextMatchers = nextRecords.stream().map(KinesisClientRecordMatcher::new)
.collect(Collectors.toList());
// Deliver the second batch over the new subscription.
batchEvent = SubscribeToShardEvent.builder().millisBehindLatest(100L).records(nextRecords)
.continuationSequenceNumber("6").build();
nextSubscribeCaptor.getValue().onNext(batchEvent);
// request(1) is expected four times in total across both connections on the shared subscription mock.
verify(subscription, times(4)).request(1);
// Exactly two batches reached the subscriber, in order.
assertThat(nonFailingSubscriber.received.size(), equalTo(2));
verifyRecords(nonFailingSubscriber.received.get(0).records(), matchers);
verifyRecords(nonFailingSubscriber.received.get(1).records(), nextMatchers);
}
/**
 * Asserts that the records and matchers have the same size and that each record satisfies the matcher at the
 * same position.
 *
 * @param clientRecordsList the records to check
 * @param matchers the matchers, one per expected record, in order
 */
private void verifyRecords(List<KinesisClientRecord> clientRecordsList, List<KinesisClientRecordMatcher> matchers) {
    int actualCount = clientRecordsList.size();
    int expectedCount = matchers.size();
    assertThat(actualCount, equalTo(expectedCount));

    int position = 0;
    while (position < clientRecordsList.size()) {
        KinesisClientRecord actual = clientRecordsList.get(position);
        assertThat(actual, matchers.get(position));
        position++;
    }
}
/**
 * Test subscriber that records every {@code ProcessRecordsInput} it is given and requests one more item
 * after subscribing and after each delivery. An error or a completion signal fails the test.
 */
private static class NonFailingSubscriber implements Subscriber<ProcessRecordsInput> {
    /** Inputs delivered through {@code onNext}, in arrival order. */
    final List<ProcessRecordsInput> received = new ArrayList<>();
    /** The subscription handed over in {@code onSubscribe}. */
    Subscription subscription;

    /** Stores the subscription and asks for the first item. */
    @Override
    public void onSubscribe(Subscription s) {
        this.subscription = s;
        s.request(1);
    }

    /** Records the input, then asks for the next item. */
    @Override
    public void onNext(ProcessRecordsInput input) {
        this.received.add(input);
        this.subscription.request(1);
    }

    /** Logs the throwable and fails the test. */
    @Override
    public void onError(Throwable t) {
        log.error("Caught throwable in subscriber", t);
        fail("Caught throwable in subscriber");
    }

    /** Completion is not expected by tests using this subscriber, so it fails the test. */
    @Override
    public void onComplete() {
        fail("OnComplete called when not expected");
    }
}
private Record makeRecord(int sequenceNumber) {
SdkBytes buffer = SdkBytes.fromByteArray(new byte[] { 1, 2, 3 });
return Record.builder().data(buffer).approximateArrivalTimestamp(Instant.now())