diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/checkpoint/SequenceNumberValidator.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/checkpoint/SequenceNumberValidator.java new file mode 100644 index 00000000..e18da9ec --- /dev/null +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/checkpoint/SequenceNumberValidator.java @@ -0,0 +1,188 @@ +/* + * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Amazon Software License (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/asl/ + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.kinesis.checkpoint; + +import java.math.BigInteger; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +import org.apache.commons.lang3.StringUtils; + +import lombok.Data; +import lombok.experimental.Accessors; + +/** + * This supports extracting the shardId from a sequence number. + * + *

Warning

+ * Sequence numbers are an opaque value used by Kinesis, and may be changed at any time. Should validation stop + * working you may need to update your version of the KCL + * + */ +public class SequenceNumberValidator { + + @Data + @Accessors(fluent = true) + private static class SequenceNumberComponents { + final int version; + final int shardId; + } + + private interface SequenceNumberReader { + Optional read(String sequenceNumber); + } + + /** + * Reader for the v2 sequence number format. v1 sequence numbers are no longer used or available. + */ + private static class V2SequenceNumberReader implements SequenceNumberReader { + + private static final int VERSION = 2; + + private static final int EXPECTED_BIT_LENGTH = 186; + + private static final int VERSION_OFFSET = 184; + private static final long VERSION_MASK = (1 << 4) - 1; + + private static final int SHARD_ID_OFFSET = 4; + private static final long SHARD_ID_MASK = (1L << 32) - 1; + + @Override + public Optional read(String sequenceNumberString) { + BigInteger sequenceNumber = new BigInteger(sequenceNumberString, 10); + + // + // If the bit length of the sequence number isn't 186 it's impossible for the version numbers + // to be where we expect them. We treat this the same as an unknown version of the sequence number + // + // If the sequence number length isn't what we expect it's due to a new version of the sequence number or + // an invalid sequence number. In either case this reader returns an empty Optional. + // + if (sequenceNumber.bitLength() != EXPECTED_BIT_LENGTH) { + return Optional.empty(); + } + + // + // Read the 4 most significant bits of the sequence number, the 2 most significant bits are implicitly 0 + // (2 == 0b0010). 
If the version number doesn't match we give up and say we can't parse the sequence number + // + int version = readOffset(sequenceNumber, VERSION_OFFSET, VERSION_MASK); + if (version != VERSION) { + return Optional.empty(); + } + + // + // If we get here the sequence number is big enough, and the version matches so the shardId should be valid. + // + int shardId = readOffset(sequenceNumber, SHARD_ID_OFFSET, SHARD_ID_MASK); + return Optional.of(new SequenceNumberComponents(version, shardId)); + } + + private int readOffset(BigInteger sequenceNumber, int offset, long mask) { + long value = sequenceNumber.shiftRight(offset).longValue() & mask; + return (int) value; + } + } + + private static final List SEQUENCE_NUMBER_READERS = Collections + .singletonList(new V2SequenceNumberReader()); + + private Optional retrieveComponentsFor(String sequenceNumber) { + return SEQUENCE_NUMBER_READERS.stream().map(r -> r.read(sequenceNumber)).filter(Optional::isPresent).map(Optional::get).findFirst(); + } + + /** + * Attempts to retrieve the version for a sequence number. If no reader can be found for the sequence number this + * will return an empty Optional. + * + *

+ * This will return an empty Optional if it's unable to extract the version number. This can occur for + * multiple reasons including: + *

    + *
  • Kinesis has started using a new version of sequence numbers
  • + *
  • The provided sequence number isn't a valid Kinesis sequence number.
  • + *
+ * + *

+ * + * @param sequenceNumber + * the sequence number to extract the version from + * @return an Optional containing the version if a compatible sequence number reader can be found, an empty Optional + * otherwise. + */ + public Optional versionFor(String sequenceNumber) { + return retrieveComponentsFor(sequenceNumber).map(SequenceNumberComponents::version); + } + + /** + * Attempts to retrieve the shardId from a sequence number. If the version of the sequence number is unsupported + * this will return an empty optional. + * + * This will return an empty Optional if the sequence number isn't recognized. This can occur for multiple + * reasons including: + *
    + *
  • Kinesis has started using a new version of sequence numbers
  • + *
  • The provided sequence number isn't a valid Kinesis sequence number.
  • + *
+ *
+ *

+ * This should always return a value if {@link #versionFor(String)} returns a value + *

+ * + * @param sequenceNumber + * the sequence number to extract the shardId from + * @return an Optional containing the shardId if the version is supported, an empty Optional otherwise. + */ + public Optional shardIdFor(String sequenceNumber) { + return retrieveComponentsFor(sequenceNumber).map(s -> String.format("shardId-%012d", s.shardId())); + } + + /** + * Validates that the sequence number provided contains the given shardId. If the sequence number is unsupported + * this will return an empty Optional. + * + *

+ * Validation of a sequence number will only occur if the sequence number can be parsed. It's possible to use + * {@link #versionFor(String)} to verify that the given sequence number is supported by this class. There are 3 + * possible validation states: + *

+ *
Some(True)
+ *
The sequence number can be parsed, and the shardId matches the one in the sequence number
+ *
Some(False)
+ *
The sequence number can be parsed, and the shardId doesn't match the one in the sequence number
+ *
None
+ *
It wasn't possible to parse the sequence number so the validity of the sequence number is unknown
+ *
+ *

+ * + *

+ * Handling unknown validation causes is application specific, and no specific handling is + * provided. + *

+ * + * @param sequenceNumber + * the sequence number to verify the shardId + * @param shardId + * the shardId that the sequence is expected to contain + * @return true if the sequence number contains the shardId, false if it doesn't. If the sequence number version is + * unsupported this will return an empty Optional + */ + public Optional validateSequenceNumberForShard(String sequenceNumber, String shardId) { + return shardIdFor(sequenceNumber).map(s -> StringUtils.equalsIgnoreCase(s, shardId)); + } + +} diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutConfig.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutConfig.java index 33f519f9..173a871e 100644 --- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutConfig.java +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutConfig.java @@ -80,9 +80,15 @@ public class FanOutConfig implements RetrievalSpecificConfig { */ private long retryBackoffMillis = 1000; + /** + * Controls whether the {@link FanOutRecordsPublisher} will validate that all the records are from the shard it's + * processing. 
+ */ + private boolean validateRecordsAreForShard = false; + @Override public RetrievalFactory retrievalFactory() { - return new FanOutRetrievalFactory(kinesisClient, getOrCreateConsumerArn()); + return new FanOutRetrievalFactory(kinesisClient, getOrCreateConsumerArn()).validateRecordsAreForShard(validateRecordsAreForShard); } private String getOrCreateConsumerArn() { diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisher.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisher.java index 728aafc0..bb9c0c75 100644 --- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisher.java +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisher.java @@ -18,15 +18,20 @@ package software.amazon.kinesis.retrieval.fanout; import java.time.Instant; import java.util.Collections; import java.util.List; +import java.util.Map; +import java.util.Optional; import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Function; import java.util.stream.Collectors; +import org.apache.commons.lang3.StringUtils; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import software.amazon.awssdk.core.async.SdkPublisher; +import software.amazon.awssdk.http.SdkHttpResponse; import software.amazon.awssdk.services.kinesis.KinesisAsyncClient; import software.amazon.awssdk.services.kinesis.model.ResourceNotFoundException; import software.amazon.awssdk.services.kinesis.model.SubscribeToShardEvent; @@ -35,6 +40,7 @@ import software.amazon.awssdk.services.kinesis.model.SubscribeToShardRequest; import software.amazon.awssdk.services.kinesis.model.SubscribeToShardResponse; import software.amazon.awssdk.services.kinesis.model.SubscribeToShardResponseHandler; import 
software.amazon.kinesis.annotations.KinesisClientInternalApi; +import software.amazon.kinesis.checkpoint.SequenceNumberValidator; import software.amazon.kinesis.common.InitialPositionInStreamExtended; import software.amazon.kinesis.common.KinesisRequestsBuilder; import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput; @@ -43,7 +49,6 @@ import software.amazon.kinesis.retrieval.KinesisClientRecord; import software.amazon.kinesis.retrieval.RecordsPublisher; import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber; -@RequiredArgsConstructor @Slf4j @KinesisClientInternalApi public class FanOutRecordsPublisher implements RecordsPublisher { @@ -51,6 +56,34 @@ public class FanOutRecordsPublisher implements RecordsPublisher { private final KinesisAsyncClient kinesis; private final String shardId; private final String consumerArn; + private final boolean validateRecordShardMatching; + + /** + * Creates a new FanOutRecordsPublisher. + *

+ * This is deprecated and will be removed in a later release. Use + * {@link #FanOutRecordsPublisher(KinesisAsyncClient, String, String, boolean)} instead + *

+ * + * @param kinesis + * the kinesis client to use for requests + * @param shardId + * the shardId to retrieve records for + * @param consumerArn + * the consumer to use when retrieving records + */ + @Deprecated + public FanOutRecordsPublisher(KinesisAsyncClient kinesis, String shardId, String consumerArn) { + this(kinesis, shardId, consumerArn, false); + } + + public FanOutRecordsPublisher(KinesisAsyncClient kinesis, String shardId, String consumerArn, + boolean validateRecordShardMatching) { + this.kinesis = kinesis; + this.shardId = shardId; + this.consumerArn = consumerArn; + this.validateRecordShardMatching = validateRecordShardMatching; + } private final Object lockObject = new Object(); @@ -451,8 +484,15 @@ public class FanOutRecordsPublisher implements RecordsPublisher { @Override public void responseReceived(SubscribeToShardResponse response) { - log.debug("{}: [SubscriptionLifetime]: (RecordFlow#responseReceived) @ {} id: {} -- Response received", - parent.shardId, connectionStartedAt, subscribeToShardId); + Optional sdkHttpResponse = Optional.ofNullable(response) + .flatMap(r -> Optional.ofNullable(r.sdkHttpResponse())); + Optional requestId = sdkHttpResponse.flatMap(s -> s.firstMatchingHeader("x-amz-requestid")); + Optional requestId2 = sdkHttpResponse.flatMap(s -> s.firstMatchingHeader("x-amz-id-2")); + + log.debug( + "{}: [SubscriptionLifetime]: (RecordFlow#responseReceived) @ {} id: {} -- Response received -- rid: {} -- rid2: {}", + parent.shardId, connectionStartedAt, subscribeToShardId, requestId.orElse("None"), + requestId2.orElse("None")); } @Override @@ -548,6 +588,7 @@ public class FanOutRecordsPublisher implements RecordsPublisher { private final RecordFlow flow; private final Instant connectionStartedAt; private final String subscribeToShardId; + private final SequenceNumberValidator sequenceNumberValidator = new SequenceNumberValidator(); private Subscription subscription; @@ -594,8 +635,9 @@ public class FanOutRecordsPublisher implements 
RecordsPublisher { cancel(); } log.debug( - "{}: [SubscriptionLifetime]: (RecordSubscription#onSubscribe) @ {} id: {} -- Outstanding: {} items so requesting an item", - parent.shardId, connectionStartedAt, subscribeToShardId, parent.availableQueueSpace); + "{}: [SubscriptionLifetime]: (RecordSubscription#onSubscribe) @ {} id: {} (Subscription ObjectId: {}) -- Outstanding: {} items so requesting an item", + parent.shardId, connectionStartedAt, subscribeToShardId, System.identityHashCode(subscription), + parent.availableQueueSpace); if (parent.availableQueueSpace > 0) { request(1); } @@ -615,12 +657,54 @@ public class FanOutRecordsPublisher implements RecordsPublisher { recordBatchEvent.accept(new SubscribeToShardResponseHandler.Visitor() { @Override public void visit(SubscribeToShardEvent event) { + if (parent.validateRecordShardMatching && !areRecordsValid(event)) { + return; + } flow.recordsReceived(event); } }); } } + private boolean areRecordsValid(SubscribeToShardEvent event) { + try { + Map mismatchedRecords = recordsNotForShard(event); + if (mismatchedRecords.size() > 0) { + String mismatchReport = mismatchedRecords.entrySet().stream() + .map(e -> String.format("(%s -> %d)", e.getKey(), e.getValue())) + .collect(Collectors.joining(", ")); + log.debug( + "{}: [SubscriptionLifetime]: (RecordSubscription#onNext#vistor) @ {} id: {} (Subscription ObjectId: {}) -- Failing subscription due to mismatches: [ {} ]", + parent.shardId, connectionStartedAt, subscribeToShardId, + System.identityHashCode(subscription), mismatchReport); + parent.errorOccurred(flow, new IllegalArgumentException( + "Received records destined for different shards: " + mismatchReport)); + return false; + } + } catch (IllegalArgumentException iae) { + log.debug( + "{}: [SubscriptionLifetime]: (RecordSubscription#onNext#vistor) @ {} id: {} (Subscription ObjectId: {}) -- " + + "A problem occurred while validating sequence numbers: {} on subscription {}", + parent.shardId, connectionStartedAt, 
subscribeToShardId, System.identityHashCode(subscription), + iae.getMessage(), iae); + parent.errorOccurred(flow, iae); + return false; + } + return true; + } + + private Map recordsNotForShard(SubscribeToShardEvent event) { + return event.records().stream().map(r -> { + Optional res = sequenceNumberValidator.shardIdFor(r.sequenceNumber()); + if (!res.isPresent()) { + throw new IllegalArgumentException("Unable to validate sequence number of " + r.sequenceNumber()); + } + return res.get(); + }).filter(s -> !StringUtils.equalsIgnoreCase(s, parent.shardId)) + .collect(Collectors.groupingBy(Function.identity())).entrySet().stream() + .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().size())); + } + @Override public void onError(Throwable t) { log.debug("{}: [SubscriptionLifetime]: (RecordSubscription#onError) @ {} id: {} -- {}: {}", parent.shardId, diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutRetrievalFactory.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutRetrievalFactory.java index eea61250..46be077f 100644 --- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutRetrievalFactory.java +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutRetrievalFactory.java @@ -15,8 +15,11 @@ package software.amazon.kinesis.retrieval.fanout; +import lombok.Getter; import lombok.NonNull; import lombok.RequiredArgsConstructor; +import lombok.Setter; +import lombok.experimental.Accessors; import software.amazon.awssdk.services.kinesis.KinesisAsyncClient; import software.amazon.kinesis.annotations.KinesisClientInternalApi; import software.amazon.kinesis.leases.ShardInfo; @@ -27,10 +30,14 @@ import software.amazon.kinesis.retrieval.RetrievalFactory; @RequiredArgsConstructor @KinesisClientInternalApi +@Accessors(fluent = true) public class FanOutRetrievalFactory implements RetrievalFactory { private final 
KinesisAsyncClient kinesisClient; private final String consumerArn; + @Getter + @Setter + private boolean validateRecordsAreForShard = false; @Override public GetRecordsRetrievalStrategy createGetRecordsRetrievalStrategy(final ShardInfo shardInfo, @@ -41,6 +48,6 @@ public class FanOutRetrievalFactory implements RetrievalFactory { @Override public RecordsPublisher createGetRecordsCache(@NonNull final ShardInfo shardInfo, final MetricsFactory metricsFactory) { - return new FanOutRecordsPublisher(kinesisClient, shardInfo.shardId(), consumerArn); + return new FanOutRecordsPublisher(kinesisClient, shardInfo.shardId(), consumerArn, validateRecordsAreForShard); } } diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/checkpoint/SequenceNumberValidatorTest.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/checkpoint/SequenceNumberValidatorTest.java index 4e8f69d1..a2ed3208 100644 --- a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/checkpoint/SequenceNumberValidatorTest.java +++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/checkpoint/SequenceNumberValidatorTest.java @@ -14,113 +14,88 @@ */ package software.amazon.kinesis.checkpoint; -//@RunWith(MockitoJUnitRunner.class) +import org.junit.Before; +import org.junit.Test; + +import java.util.Optional; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.not; +import static org.junit.Assert.assertThat; + public class SequenceNumberValidatorTest { - /*private final String streamName = "testStream"; - private final boolean validateWithGetIterator = true; - private final String shardId = "shardid-123"; - @Mock - private AmazonKinesis amazonKinesis; + private SequenceNumberValidator validator; - @Test (expected = IllegalArgumentException.class) - public final void testSequenceNumberValidator() { - Checkpoint.SequenceNumberValidator validator = new Checkpoint.SequenceNumberValidator(amazonKinesis, streamName, - shardId, 
validateWithGetIterator); + @Before + public void begin() { + validator = new SequenceNumberValidator(); + } - String goodSequence = "456"; - String iterator = "happyiterator"; - String badSequence = "789"; - ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(GetShardIteratorRequest.class); + @Test + public void matchingSequenceNumberTest() { + String sequenceNumber = "49587497311274533994574834252742144236107130636007899138"; + String expectedShardId = "shardId-000000000000"; - when(amazonKinesis.getShardIterator(requestCaptor.capture())) - .thenReturn(new GetShardIteratorResult().withShardIterator(iterator)) - .thenThrow(new InvalidArgumentException("")); + Optional version = validator.versionFor(sequenceNumber); + assertThat(version, equalTo(Optional.of(2))); - validator.validateSequenceNumber(goodSequence); - try { - validator.validateSequenceNumber(badSequence); - } finally { - final List requests = requestCaptor.getAllValues(); - assertEquals(2, requests.size()); + Optional shardId = validator.shardIdFor(sequenceNumber); + assertThat(shardId, equalTo(Optional.of(expectedShardId))); - final GetShardIteratorRequest goodRequest = requests.get(0); - final GetShardIteratorRequest badRequest = requests.get(0); - - assertEquals(streamName, goodRequest.getStreamName()); - assertEquals(shardId, goodRequest.shardId()); - assertEquals(ShardIteratorType.AFTER_SEQUENCE_NUMBER.toString(), goodRequest.getShardIteratorType()); - assertEquals(goodSequence, goodRequest.getStartingSequenceNumber()); - - assertEquals(streamName, badRequest.getStreamName()); - assertEquals(shardId, badRequest.shardId()); - assertEquals(ShardIteratorType.AFTER_SEQUENCE_NUMBER.toString(), badRequest.getShardIteratorType()); - assertEquals(goodSequence, badRequest.getStartingSequenceNumber()); - } + assertThat(validator.validateSequenceNumberForShard(sequenceNumber, expectedShardId), equalTo(Optional.of(true))); } @Test - public final void testNoValidation() { - 
Checkpoint.SequenceNumberValidator validator = new Checkpoint.SequenceNumberValidator(amazonKinesis, streamName, - shardId, !validateWithGetIterator); - String sequenceNumber = "456"; + public void shardMismatchTest() { + String sequenceNumber = "49585389983312162443796657944872008114154899568972529698"; + String invalidShardId = "shardId-000000000001"; - // Just checking that the false flag for validating against getIterator is honored - validator.validateSequenceNumber(sequenceNumber); + Optional version = validator.versionFor(sequenceNumber); + assertThat(version, equalTo(Optional.of(2))); - verify(amazonKinesis, never()).getShardIterator(any(GetShardIteratorRequest.class)); + Optional shardId = validator.shardIdFor(sequenceNumber); + assertThat(shardId, not(equalTo(invalidShardId))); + + assertThat(validator.validateSequenceNumberForShard(sequenceNumber, invalidShardId), equalTo(Optional.of(false))); } @Test - public void nonNumericValueValidationTest() { - Checkpoint.SequenceNumberValidator validator = new Checkpoint.SequenceNumberValidator(amazonKinesis, streamName, - shardId, validateWithGetIterator); + public void versionMismatchTest() { + String sequenceNumber = "74107425965128755728308386687147091174006956590945533954"; + String expectedShardId = "shardId-000000000000"; - String[] nonNumericStrings = {null, - "bogus-sequence-number", - SentinelCheckpoint.LATEST.toString(), - SentinelCheckpoint.TRIM_HORIZON.toString(), - SentinelCheckpoint.AT_TIMESTAMP.toString()}; + Optional version = validator.versionFor(sequenceNumber); + assertThat(version, equalTo(Optional.empty())); - Arrays.stream(nonNumericStrings).forEach(sequenceNumber -> { - try { - validator.validateSequenceNumber(sequenceNumber); - fail("Validator should not consider " + sequenceNumber + " a valid sequence number"); - } catch (IllegalArgumentException e) { - // Do nothing - } - }); + Optional shardId = validator.shardIdFor(sequenceNumber); + assertThat(shardId, equalTo(Optional.empty())); - 
verify(amazonKinesis, never()).getShardIterator(any(GetShardIteratorRequest.class)); + assertThat(validator.validateSequenceNumberForShard(sequenceNumber, expectedShardId), equalTo(Optional.empty())); } @Test - public final void testIsDigits() { - // Check things that are all digits - String[] stringsOfDigits = {"0", "12", "07897803434", "12324456576788"}; + public void sequenceNumberToShortTest() { + String sequenceNumber = "4958538998331216244379665794487200811415489956897252969"; + String expectedShardId = "shardId-000000000000"; + + assertThat(validator.versionFor(sequenceNumber), equalTo(Optional.empty())); + assertThat(validator.shardIdFor(sequenceNumber), equalTo(Optional.empty())); + + assertThat(validator.validateSequenceNumberForShard(sequenceNumber, expectedShardId), equalTo(Optional.empty())); + } + + @Test + public void sequenceNumberToLongTest() { + String sequenceNumber = "495874973112745339945748342527421442361071306360078991381"; + String expectedShardId = "shardId-000000000000"; + + assertThat(validator.versionFor(sequenceNumber), equalTo(Optional.empty())); + assertThat(validator.shardIdFor(sequenceNumber), equalTo(Optional.empty())); + + assertThat(validator.validateSequenceNumberForShard(sequenceNumber, expectedShardId), equalTo(Optional.empty())); + } + - for (String digits : stringsOfDigits) { - assertTrue("Expected that " + digits + " would be considered a string of digits.", - Checkpoint.SequenceNumberValidator.isDigits(digits)); - } - // Check things that are not all digits - String[] stringsWithNonDigits = { - null, - "", - " ", // white spaces - "6 4", - "\t45", - "5242354235234\n", - "7\n6\n5\n", - "12s", // last character - "c07897803434", // first character - "1232445wef6576788", // interior - "no-digits", - }; - for (String notAllDigits : stringsWithNonDigits) { - assertFalse("Expected that " + notAllDigits + " would not be considered a string of digits.", - Checkpoint.SequenceNumberValidator.isDigits(notAllDigits)); - } - }*/ } diff 
--git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisherTest.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisherTest.java index 33045149..6284c6d5 100644 --- a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisherTest.java +++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisherTest.java @@ -1,6 +1,8 @@ package software.amazon.kinesis.retrieval.fanout; +import static org.hamcrest.CoreMatchers.containsString; import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.instanceOf; import static org.junit.Assert.assertThat; import static org.junit.Assert.fail; import static org.mockito.Matchers.any; @@ -11,6 +13,7 @@ import static org.mockito.Mockito.reset; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import java.math.BigInteger; import java.nio.ByteBuffer; import java.time.Instant; import java.util.ArrayList; @@ -50,8 +53,9 @@ import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber; @Slf4j public class FanOutRecordsPublisherTest { - private static final String SHARD_ID = "Shard-001"; + private static final String SHARD_ID = "shardId-000000000001"; private static final String CONSUMER_ARN = "arn:consumer"; + private static final boolean VALIDATE_RECORD_SHARD_MATCHING = true; @Mock private KinesisAsyncClient kinesisClient; @@ -66,7 +70,8 @@ public class FanOutRecordsPublisherTest { @Test public void simpleTest() throws Exception { - FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN); + FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN, + VALIDATE_RECORD_SHARD_MATCHING); ArgumentCaptor captor = ArgumentCaptor.forClass(FanOutRecordsPublisher.RecordSubscription.class); ArgumentCaptor flowCaptor = 
ArgumentCaptor @@ -133,7 +138,8 @@ public class FanOutRecordsPublisherTest { @Test public void largeRequestTest() throws Exception { - FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN); + FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN, + VALIDATE_RECORD_SHARD_MATCHING); ArgumentCaptor captor = ArgumentCaptor.forClass(FanOutRecordsPublisher.RecordSubscription.class); ArgumentCaptor flowCaptor = ArgumentCaptor @@ -200,7 +206,8 @@ public class FanOutRecordsPublisherTest { @Test public void testResourceNotFoundForShard() { - FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN); + FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN, + VALIDATE_RECORD_SHARD_MATCHING); ArgumentCaptor flowCaptor = ArgumentCaptor .forClass(FanOutRecordsPublisher.RecordFlow.class); @@ -224,7 +231,8 @@ public class FanOutRecordsPublisherTest { @Test public void testContinuesAfterSequence() { - FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN); + FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN, + VALIDATE_RECORD_SHARD_MATCHING); ArgumentCaptor captor = ArgumentCaptor .forClass(FanOutRecordsPublisher.RecordSubscription.class); @@ -297,6 +305,66 @@ public class FanOutRecordsPublisherTest { } + @Test + public void mismatchedShardIdTest() { + FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN, true); + + ArgumentCaptor captor = ArgumentCaptor + .forClass(FanOutRecordsPublisher.RecordSubscription.class); + ArgumentCaptor flowCaptor = ArgumentCaptor + .forClass(FanOutRecordsPublisher.RecordFlow.class); + + doNothing().when(publisher).subscribe(captor.capture()); + + source.start(ExtendedSequenceNumber.LATEST, + 
InitialPositionInStreamExtended.newInitialPosition(InitialPositionInStream.LATEST)); + + List errorsHandled = new ArrayList<>(); + List inputsReceived = new ArrayList<>(); + + source.subscribe(new Subscriber() { + Subscription subscription; + + @Override + public void onSubscribe(Subscription s) { + subscription = s; + subscription.request(1); + } + + @Override + public void onNext(ProcessRecordsInput input) { + inputsReceived.add(input); + } + + @Override + public void onError(Throwable t) { + errorsHandled.add(t); + + } + + @Override + public void onComplete() { + fail("OnComplete called when not expected"); + } + }); + + verify(kinesisClient).subscribeToShard(any(SubscribeToShardRequest.class), flowCaptor.capture()); + flowCaptor.getValue().onEventStream(publisher); + captor.getValue().onSubscribe(subscription); + + List records = Stream.of(1, 2, 3).map(seq -> makeRecord(seq, seq)).collect(Collectors.toList()); + + batchEvent = SubscribeToShardEvent.builder().millisBehindLatest(100L).records(records).build(); + + captor.getValue().onNext(batchEvent); + + verify(subscription, times(1)).request(1); + assertThat(inputsReceived.size(), equalTo(0)); + assertThat(errorsHandled.size(), equalTo(1)); + assertThat(errorsHandled.get(0), instanceOf(IllegalArgumentException.class)); + assertThat(errorsHandled.get(0).getMessage(), containsString("Received records destined for different shards")); + } + private void verifyRecords(List clientRecordsList, List matchers) { assertThat(clientRecordsList.size(), equalTo(matchers.size())); for (int i = 0; i < clientRecordsList.size(); ++i) { @@ -333,9 +401,17 @@ public class FanOutRecordsPublisherTest { } private Record makeRecord(int sequenceNumber) { + return makeRecord(sequenceNumber, 1); + } + + private Record makeRecord(int sequenceNumber, int shardId) { + BigInteger version = BigInteger.valueOf(2).shiftLeft(184); + BigInteger shard = BigInteger.valueOf(shardId).shiftLeft(4); + BigInteger seq = 
version.add(shard).add(BigInteger.valueOf(sequenceNumber)); + SdkBytes buffer = SdkBytes.fromByteArray(new byte[] { 1, 2, 3 }); return Record.builder().data(buffer).approximateArrivalTimestamp(Instant.now()) - .sequenceNumber(Integer.toString(sequenceNumber)).partitionKey("A").build(); + .sequenceNumber(seq.toString()).partitionKey("A").build(); } private static class KinesisClientRecordMatcher extends TypeSafeDiagnosingMatcher {