From 01f5db8049f59c010d4cf6d87bb93de9877c35d9 Mon Sep 17 00:00:00 2001
From: Justin Pfifer
Date: Mon, 17 Sep 2018 14:33:46 -0700
Subject: [PATCH] Support Validating Records are From the Expected Shard
(#400)
* SequenceNumberValidator for verifying shardId's
Added a SequenceNumberValidator that can extract and verify
shardIds from a v2 sequence number.
* Added documentation and bit length test
Added documentation for the public methods.
Added a bit length test for the reader that will reject sequence
numbers that don't fit the expectations.
* Added more comments and further document public operations
Added comments to the only SequenceNumberReader explaining how things
are expected to work.
Further documented the class and operations with expectations and outcomes.
* Added configuration to allow failing on mismatched records
Allow configuration which will cause the FanOutRecordsPublisher to
throw an exception when it detects records that aren't for the shard
it's processing.
---
.../checkpoint/SequenceNumberValidator.java | 188 ++++++++++++++++++
.../retrieval/fanout/FanOutConfig.java | 8 +-
.../fanout/FanOutRecordsPublisher.java | 94 ++++++++-
.../fanout/FanOutRetrievalFactory.java | 9 +-
.../SequenceNumberValidatorTest.java | 147 ++++++--------
.../fanout/FanOutRecordsPublisherTest.java | 88 +++++++-
6 files changed, 435 insertions(+), 99 deletions(-)
create mode 100644 amazon-kinesis-client/src/main/java/software/amazon/kinesis/checkpoint/SequenceNumberValidator.java
diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/checkpoint/SequenceNumberValidator.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/checkpoint/SequenceNumberValidator.java
new file mode 100644
index 00000000..e18da9ec
--- /dev/null
+++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/checkpoint/SequenceNumberValidator.java
@@ -0,0 +1,188 @@
+/*
+ * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Amazon Software License (the "License").
+ * You may not use this file except in compliance with the License.
+ * A copy of the License is located at
+ *
+ * http://aws.amazon.com/asl/
+ *
+ * or in the "license" file accompanying this file. This file is distributed
+ * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+ * express or implied. See the License for the specific language governing
+ * permissions and limitations under the License.
+ */
+
+package software.amazon.kinesis.checkpoint;
+
+import java.math.BigInteger;
+import java.util.Collections;
+import java.util.List;
+import java.util.Optional;
+
+import org.apache.commons.lang3.StringUtils;
+
+import lombok.Data;
+import lombok.experimental.Accessors;
+
+/**
+ * This supports extracting the shardId from a sequence number.
+ *
+ * Warning
+ * Sequence numbers are an opaque value used by Kinesis, and may be changed at any time. Should validation stop
+ * working you may need to update your version of the KCL.
+ *
+ */
+public class SequenceNumberValidator {
+
+ @Data
+ @Accessors(fluent = true)
+ private static class SequenceNumberComponents {
+ final int version;
+ final int shardId;
+ }
+
+ private interface SequenceNumberReader {
+ Optional read(String sequenceNumber);
+ }
+
+ /**
+ * Reader for the v2 sequence number format. v1 sequence numbers are no longer used or available.
+ */
+ private static class V2SequenceNumberReader implements SequenceNumberReader {
+
+ private static final int VERSION = 2;
+
+ private static final int EXPECTED_BIT_LENGTH = 186;
+
+ private static final int VERSION_OFFSET = 184;
+ private static final long VERSION_MASK = (1 << 4) - 1;
+
+ private static final int SHARD_ID_OFFSET = 4;
+ private static final long SHARD_ID_MASK = (1L << 32) - 1;
+
+ @Override
+ public Optional read(String sequenceNumberString) {
+ BigInteger sequenceNumber = new BigInteger(sequenceNumberString, 10);
+
+ //
+ // If the bit length of the sequence number isn't 186 it's impossible for the version numbers
+ // to be where we expect them. We treat this the same as an unknown version of the sequence number
+ //
+ // If the sequence number length isn't what we expect it's due to a new version of the sequence number or
+ // an invalid sequence number. Either way this reader can't parse it, so it reports nothing.
+ //
+ if (sequenceNumber.bitLength() != EXPECTED_BIT_LENGTH) {
+ return Optional.empty();
+ }
+
+ //
+ // Read the 4 most significant bits of the sequence number, the 2 most significant bits are implicitly 0
+ // (2 == 0b0010). If the version number doesn't match we give up and say we can't parse the sequence number
+ //
+ int version = readOffset(sequenceNumber, VERSION_OFFSET, VERSION_MASK);
+ if (version != VERSION) {
+ return Optional.empty();
+ }
+
+ //
+ // If we get here the sequence number is big enough, and the version matches so the shardId should be valid.
+ //
+ int shardId = readOffset(sequenceNumber, SHARD_ID_OFFSET, SHARD_ID_MASK);
+ return Optional.of(new SequenceNumberComponents(version, shardId));
+ }
+
+ private int readOffset(BigInteger sequenceNumber, int offset, long mask) {
+ long value = sequenceNumber.shiftRight(offset).longValue() & mask;
+ return (int) value;
+ }
+ }
+
+ private static final List SEQUENCE_NUMBER_READERS = Collections
+ .singletonList(new V2SequenceNumberReader());
+
+ private Optional retrieveComponentsFor(String sequenceNumber) {
+ return SEQUENCE_NUMBER_READERS.stream().map(r -> r.read(sequenceNumber)).filter(Optional::isPresent).map(Optional::get).findFirst();
+ }
+
+ /**
+ * Attempts to retrieve the version for a sequence number. If no reader can be found for the sequence number this
+ * will return an empty Optional.
+ *
+ *
+ * This will return an empty Optional if it's unable to extract the version number. This can occur for
+ * multiple reasons including:
+ *
+ * - Kinesis has started using a new version of sequence numbers
+ * - The provided sequence number isn't a valid Kinesis sequence number.
+ *
+ *
+ *
+ *
+ * @param sequenceNumber
+ * the sequence number to extract the version from
+ * @return an Optional containing the version if a compatible sequence number reader can be found, an empty Optional
+ * otherwise.
+ */
+ public Optional versionFor(String sequenceNumber) {
+ return retrieveComponentsFor(sequenceNumber).map(SequenceNumberComponents::version);
+ }
+
+ /**
+ * Attempts to retrieve the shardId from a sequence number. If the version of the sequence number is unsupported
+ * this will return an empty optional.
+ *
+ * This will return an empty Optional if the sequence number isn't recognized. This can occur for multiple
+ * reasons including:
+ *
+ * - Kinesis has started using a new version of sequence numbers
+ * - The provided sequence number isn't a valid Kinesis sequence number.
+ *
+ *
+ *
+ * This should always return a value if {@link #versionFor(String)} returns a value
+ *
+ *
+ * @param sequenceNumber
+ * the sequence number to extract the shardId from
+ * @return an Optional containing the shardId if the version is supported, an empty Optional otherwise.
+ */
+ public Optional shardIdFor(String sequenceNumber) {
+ return retrieveComponentsFor(sequenceNumber).map(s -> String.format("shardId-%012d", s.shardId()));
+ }
+
+ /**
+ * Validates that the sequence number provided contains the given shardId. If the sequence number is unsupported
+ * this will return an empty Optional.
+ *
+ *
+ * Validation of a sequence number will only occur if the sequence number can be parsed. It's possible to use
+ * {@link #versionFor(String)} to verify that the given sequence number is supported by this class. There are 3
+ * possible validation states:
+ *
+ * - Some(True)
+ * - The sequence number can be parsed, and the shardId matches the one in the sequence number
+ * - Some(False)
+ * - The sequence number can be parsed, and the shardId doesn't match the one in the sequence number
+ * - None
+ * - It wasn't possible to parse the sequence number so the validity of the sequence number is unknown
+ *
+ *
+ *
+ *
+ * Handling an unknown validation result is application specific, and no specific handling is
+ * provided.
+ *
+ *
+ * @param sequenceNumber
+ * the sequence number to verify the shardId
+ * @param shardId
+ * the shardId that the sequence is expected to contain
+ * @return true if the sequence number contains the shardId, false if it doesn't. If the sequence number version is
+ * unsupported this will return an empty Optional
+ */
+ public Optional validateSequenceNumberForShard(String sequenceNumber, String shardId) {
+ return shardIdFor(sequenceNumber).map(s -> StringUtils.equalsIgnoreCase(s, shardId));
+ }
+
+}
diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutConfig.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutConfig.java
index 33f519f9..173a871e 100644
--- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutConfig.java
+++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutConfig.java
@@ -80,9 +80,15 @@ public class FanOutConfig implements RetrievalSpecificConfig {
*/
private long retryBackoffMillis = 1000;
+ /**
+ * Controls whether the {@link FanOutRecordsPublisher} will validate that all the records are from the shard it's
+ * processing.
+ */
+ private boolean validateRecordsAreForShard = false;
+
@Override
public RetrievalFactory retrievalFactory() {
- return new FanOutRetrievalFactory(kinesisClient, getOrCreateConsumerArn());
+ return new FanOutRetrievalFactory(kinesisClient, getOrCreateConsumerArn()).validateRecordsAreForShard(validateRecordsAreForShard);
}
private String getOrCreateConsumerArn() {
diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisher.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisher.java
index 728aafc0..bb9c0c75 100644
--- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisher.java
+++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisher.java
@@ -18,15 +18,20 @@ package software.amazon.kinesis.retrieval.fanout;
import java.time.Instant;
import java.util.Collections;
import java.util.List;
+import java.util.Map;
+import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.Function;
import java.util.stream.Collectors;
+import org.apache.commons.lang3.StringUtils;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import software.amazon.awssdk.core.async.SdkPublisher;
+import software.amazon.awssdk.http.SdkHttpResponse;
import software.amazon.awssdk.services.kinesis.KinesisAsyncClient;
import software.amazon.awssdk.services.kinesis.model.ResourceNotFoundException;
import software.amazon.awssdk.services.kinesis.model.SubscribeToShardEvent;
@@ -35,6 +40,7 @@ import software.amazon.awssdk.services.kinesis.model.SubscribeToShardRequest;
import software.amazon.awssdk.services.kinesis.model.SubscribeToShardResponse;
import software.amazon.awssdk.services.kinesis.model.SubscribeToShardResponseHandler;
import software.amazon.kinesis.annotations.KinesisClientInternalApi;
+import software.amazon.kinesis.checkpoint.SequenceNumberValidator;
import software.amazon.kinesis.common.InitialPositionInStreamExtended;
import software.amazon.kinesis.common.KinesisRequestsBuilder;
import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput;
@@ -43,7 +49,6 @@ import software.amazon.kinesis.retrieval.KinesisClientRecord;
import software.amazon.kinesis.retrieval.RecordsPublisher;
import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber;
-@RequiredArgsConstructor
@Slf4j
@KinesisClientInternalApi
public class FanOutRecordsPublisher implements RecordsPublisher {
@@ -51,6 +56,34 @@ public class FanOutRecordsPublisher implements RecordsPublisher {
private final KinesisAsyncClient kinesis;
private final String shardId;
private final String consumerArn;
+ private final boolean validateRecordShardMatching;
+
+ /**
+ * Creates a new FanOutRecordsPublisher.
+ *
+ * This is deprecated and will be removed in a later release. Use
+ * {@link #FanOutRecordsPublisher(KinesisAsyncClient, String, String, boolean)} instead
+ *
+ *
+ * @param kinesis
+ * the kinesis client to use for requests
+ * @param shardId
+ * the shardId to retrieve records for
+ * @param consumerArn
+ * the consumer to use when retrieving records
+ */
+ @Deprecated
+ public FanOutRecordsPublisher(KinesisAsyncClient kinesis, String shardId, String consumerArn) {
+ this(kinesis, shardId, consumerArn, false);
+ }
+
+ public FanOutRecordsPublisher(KinesisAsyncClient kinesis, String shardId, String consumerArn,
+ boolean validateRecordShardMatching) {
+ this.kinesis = kinesis;
+ this.shardId = shardId;
+ this.consumerArn = consumerArn;
+ this.validateRecordShardMatching = validateRecordShardMatching;
+ }
private final Object lockObject = new Object();
@@ -451,8 +484,15 @@ public class FanOutRecordsPublisher implements RecordsPublisher {
@Override
public void responseReceived(SubscribeToShardResponse response) {
- log.debug("{}: [SubscriptionLifetime]: (RecordFlow#responseReceived) @ {} id: {} -- Response received",
- parent.shardId, connectionStartedAt, subscribeToShardId);
+ Optional sdkHttpResponse = Optional.ofNullable(response)
+ .flatMap(r -> Optional.ofNullable(r.sdkHttpResponse()));
+ Optional requestId = sdkHttpResponse.flatMap(s -> s.firstMatchingHeader("x-amz-requestid"));
+ Optional requestId2 = sdkHttpResponse.flatMap(s -> s.firstMatchingHeader("x-amz-id-2"));
+
+ log.debug(
+ "{}: [SubscriptionLifetime]: (RecordFlow#responseReceived) @ {} id: {} -- Response received -- rid: {} -- rid2: {}",
+ parent.shardId, connectionStartedAt, subscribeToShardId, requestId.orElse("None"),
+ requestId2.orElse("None"));
}
@Override
@@ -548,6 +588,7 @@ public class FanOutRecordsPublisher implements RecordsPublisher {
private final RecordFlow flow;
private final Instant connectionStartedAt;
private final String subscribeToShardId;
+ private final SequenceNumberValidator sequenceNumberValidator = new SequenceNumberValidator();
private Subscription subscription;
@@ -594,8 +635,9 @@ public class FanOutRecordsPublisher implements RecordsPublisher {
cancel();
}
log.debug(
- "{}: [SubscriptionLifetime]: (RecordSubscription#onSubscribe) @ {} id: {} -- Outstanding: {} items so requesting an item",
- parent.shardId, connectionStartedAt, subscribeToShardId, parent.availableQueueSpace);
+ "{}: [SubscriptionLifetime]: (RecordSubscription#onSubscribe) @ {} id: {} (Subscription ObjectId: {}) -- Outstanding: {} items so requesting an item",
+ parent.shardId, connectionStartedAt, subscribeToShardId, System.identityHashCode(subscription),
+ parent.availableQueueSpace);
if (parent.availableQueueSpace > 0) {
request(1);
}
@@ -615,12 +657,54 @@ public class FanOutRecordsPublisher implements RecordsPublisher {
recordBatchEvent.accept(new SubscribeToShardResponseHandler.Visitor() {
@Override
public void visit(SubscribeToShardEvent event) {
+ if (parent.validateRecordShardMatching && !areRecordsValid(event)) {
+ return;
+ }
flow.recordsReceived(event);
}
});
}
}
+ private boolean areRecordsValid(SubscribeToShardEvent event) {
+ try {
+ Map mismatchedRecords = recordsNotForShard(event);
+ if (mismatchedRecords.size() > 0) {
+ String mismatchReport = mismatchedRecords.entrySet().stream()
+ .map(e -> String.format("(%s -> %d)", e.getKey(), e.getValue()))
+ .collect(Collectors.joining(", "));
+ log.debug(
+ "{}: [SubscriptionLifetime]: (RecordSubscription#onNext#vistor) @ {} id: {} (Subscription ObjectId: {}) -- Failing subscription due to mismatches: [ {} ]",
+ parent.shardId, connectionStartedAt, subscribeToShardId,
+ System.identityHashCode(subscription), mismatchReport);
+ parent.errorOccurred(flow, new IllegalArgumentException(
+ "Received records destined for different shards: " + mismatchReport));
+ return false;
+ }
+ } catch (IllegalArgumentException iae) {
+ log.debug(
+ "{}: [SubscriptionLifetime]: (RecordSubscription#onNext#vistor) @ {} id: {} (Subscription ObjectId: {}) -- "
+ + "A problem occurred while validating sequence numbers: {} on subscription {}",
+ parent.shardId, connectionStartedAt, subscribeToShardId, System.identityHashCode(subscription),
+ iae.getMessage(), iae);
+ parent.errorOccurred(flow, iae);
+ return false;
+ }
+ return true;
+ }
+
+ private Map recordsNotForShard(SubscribeToShardEvent event) {
+ return event.records().stream().map(r -> {
+ Optional res = sequenceNumberValidator.shardIdFor(r.sequenceNumber());
+ if (!res.isPresent()) {
+ throw new IllegalArgumentException("Unable to validate sequence number of " + r.sequenceNumber());
+ }
+ return res.get();
+ }).filter(s -> !StringUtils.equalsIgnoreCase(s, parent.shardId))
+ .collect(Collectors.groupingBy(Function.identity())).entrySet().stream()
+ .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().size()));
+ }
+
@Override
public void onError(Throwable t) {
log.debug("{}: [SubscriptionLifetime]: (RecordSubscription#onError) @ {} id: {} -- {}: {}", parent.shardId,
diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutRetrievalFactory.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutRetrievalFactory.java
index eea61250..46be077f 100644
--- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutRetrievalFactory.java
+++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutRetrievalFactory.java
@@ -15,8 +15,11 @@
package software.amazon.kinesis.retrieval.fanout;
+import lombok.Getter;
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
+import lombok.Setter;
+import lombok.experimental.Accessors;
import software.amazon.awssdk.services.kinesis.KinesisAsyncClient;
import software.amazon.kinesis.annotations.KinesisClientInternalApi;
import software.amazon.kinesis.leases.ShardInfo;
@@ -27,10 +30,14 @@ import software.amazon.kinesis.retrieval.RetrievalFactory;
@RequiredArgsConstructor
@KinesisClientInternalApi
+@Accessors(fluent = true)
public class FanOutRetrievalFactory implements RetrievalFactory {
private final KinesisAsyncClient kinesisClient;
private final String consumerArn;
+ @Getter
+ @Setter
+ private boolean validateRecordsAreForShard = false;
@Override
public GetRecordsRetrievalStrategy createGetRecordsRetrievalStrategy(final ShardInfo shardInfo,
@@ -41,6 +48,6 @@ public class FanOutRetrievalFactory implements RetrievalFactory {
@Override
public RecordsPublisher createGetRecordsCache(@NonNull final ShardInfo shardInfo,
final MetricsFactory metricsFactory) {
- return new FanOutRecordsPublisher(kinesisClient, shardInfo.shardId(), consumerArn);
+ return new FanOutRecordsPublisher(kinesisClient, shardInfo.shardId(), consumerArn, validateRecordsAreForShard);
}
}
diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/checkpoint/SequenceNumberValidatorTest.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/checkpoint/SequenceNumberValidatorTest.java
index 4e8f69d1..a2ed3208 100644
--- a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/checkpoint/SequenceNumberValidatorTest.java
+++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/checkpoint/SequenceNumberValidatorTest.java
@@ -14,113 +14,88 @@
*/
package software.amazon.kinesis.checkpoint;
-//@RunWith(MockitoJUnitRunner.class)
+import org.junit.Before;
+import org.junit.Test;
+
+import java.util.Optional;
+
+import static org.hamcrest.CoreMatchers.equalTo;
+import static org.hamcrest.CoreMatchers.not;
+import static org.junit.Assert.assertThat;
+
public class SequenceNumberValidatorTest {
- /*private final String streamName = "testStream";
- private final boolean validateWithGetIterator = true;
- private final String shardId = "shardid-123";
- @Mock
- private AmazonKinesis amazonKinesis;
+ private SequenceNumberValidator validator;
- @Test (expected = IllegalArgumentException.class)
- public final void testSequenceNumberValidator() {
- Checkpoint.SequenceNumberValidator validator = new Checkpoint.SequenceNumberValidator(amazonKinesis, streamName,
- shardId, validateWithGetIterator);
+ @Before
+ public void begin() {
+ validator = new SequenceNumberValidator();
+ }
- String goodSequence = "456";
- String iterator = "happyiterator";
- String badSequence = "789";
- ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(GetShardIteratorRequest.class);
+ @Test
+ public void matchingSequenceNumberTest() {
+ String sequenceNumber = "49587497311274533994574834252742144236107130636007899138";
+ String expectedShardId = "shardId-000000000000";
- when(amazonKinesis.getShardIterator(requestCaptor.capture()))
- .thenReturn(new GetShardIteratorResult().withShardIterator(iterator))
- .thenThrow(new InvalidArgumentException(""));
+ Optional version = validator.versionFor(sequenceNumber);
+ assertThat(version, equalTo(Optional.of(2)));
- validator.validateSequenceNumber(goodSequence);
- try {
- validator.validateSequenceNumber(badSequence);
- } finally {
- final List requests = requestCaptor.getAllValues();
- assertEquals(2, requests.size());
+ Optional shardId = validator.shardIdFor(sequenceNumber);
+ assertThat(shardId, equalTo(Optional.of(expectedShardId)));
- final GetShardIteratorRequest goodRequest = requests.get(0);
- final GetShardIteratorRequest badRequest = requests.get(0);
-
- assertEquals(streamName, goodRequest.getStreamName());
- assertEquals(shardId, goodRequest.shardId());
- assertEquals(ShardIteratorType.AFTER_SEQUENCE_NUMBER.toString(), goodRequest.getShardIteratorType());
- assertEquals(goodSequence, goodRequest.getStartingSequenceNumber());
-
- assertEquals(streamName, badRequest.getStreamName());
- assertEquals(shardId, badRequest.shardId());
- assertEquals(ShardIteratorType.AFTER_SEQUENCE_NUMBER.toString(), badRequest.getShardIteratorType());
- assertEquals(goodSequence, badRequest.getStartingSequenceNumber());
- }
+ assertThat(validator.validateSequenceNumberForShard(sequenceNumber, expectedShardId), equalTo(Optional.of(true)));
}
@Test
- public final void testNoValidation() {
- Checkpoint.SequenceNumberValidator validator = new Checkpoint.SequenceNumberValidator(amazonKinesis, streamName,
- shardId, !validateWithGetIterator);
- String sequenceNumber = "456";
+ public void shardMismatchTest() {
+ String sequenceNumber = "49585389983312162443796657944872008114154899568972529698";
+ String invalidShardId = "shardId-000000000001";
- // Just checking that the false flag for validating against getIterator is honored
- validator.validateSequenceNumber(sequenceNumber);
+ Optional version = validator.versionFor(sequenceNumber);
+ assertThat(version, equalTo(Optional.of(2)));
- verify(amazonKinesis, never()).getShardIterator(any(GetShardIteratorRequest.class));
+ Optional shardId = validator.shardIdFor(sequenceNumber);
+ assertThat(shardId, not(equalTo(invalidShardId)));
+
+ assertThat(validator.validateSequenceNumberForShard(sequenceNumber, invalidShardId), equalTo(Optional.of(false)));
}
@Test
- public void nonNumericValueValidationTest() {
- Checkpoint.SequenceNumberValidator validator = new Checkpoint.SequenceNumberValidator(amazonKinesis, streamName,
- shardId, validateWithGetIterator);
+ public void versionMismatchTest() {
+ String sequenceNumber = "74107425965128755728308386687147091174006956590945533954";
+ String expectedShardId = "shardId-000000000000";
- String[] nonNumericStrings = {null,
- "bogus-sequence-number",
- SentinelCheckpoint.LATEST.toString(),
- SentinelCheckpoint.TRIM_HORIZON.toString(),
- SentinelCheckpoint.AT_TIMESTAMP.toString()};
+ Optional version = validator.versionFor(sequenceNumber);
+ assertThat(version, equalTo(Optional.empty()));
- Arrays.stream(nonNumericStrings).forEach(sequenceNumber -> {
- try {
- validator.validateSequenceNumber(sequenceNumber);
- fail("Validator should not consider " + sequenceNumber + " a valid sequence number");
- } catch (IllegalArgumentException e) {
- // Do nothing
- }
- });
+ Optional shardId = validator.shardIdFor(sequenceNumber);
+ assertThat(shardId, equalTo(Optional.empty()));
- verify(amazonKinesis, never()).getShardIterator(any(GetShardIteratorRequest.class));
+ assertThat(validator.validateSequenceNumberForShard(sequenceNumber, expectedShardId), equalTo(Optional.empty()));
}
@Test
- public final void testIsDigits() {
- // Check things that are all digits
- String[] stringsOfDigits = {"0", "12", "07897803434", "12324456576788"};
+ public void sequenceNumberToShortTest() {
+ String sequenceNumber = "4958538998331216244379665794487200811415489956897252969";
+ String expectedShardId = "shardId-000000000000";
+
+ assertThat(validator.versionFor(sequenceNumber), equalTo(Optional.empty()));
+ assertThat(validator.shardIdFor(sequenceNumber), equalTo(Optional.empty()));
+
+ assertThat(validator.validateSequenceNumberForShard(sequenceNumber, expectedShardId), equalTo(Optional.empty()));
+ }
+
+ @Test
+ public void sequenceNumberToLongTest() {
+ String sequenceNumber = "495874973112745339945748342527421442361071306360078991381";
+ String expectedShardId = "shardId-000000000000";
+
+ assertThat(validator.versionFor(sequenceNumber), equalTo(Optional.empty()));
+ assertThat(validator.shardIdFor(sequenceNumber), equalTo(Optional.empty()));
+
+ assertThat(validator.validateSequenceNumberForShard(sequenceNumber, expectedShardId), equalTo(Optional.empty()));
+ }
+
- for (String digits : stringsOfDigits) {
- assertTrue("Expected that " + digits + " would be considered a string of digits.",
- Checkpoint.SequenceNumberValidator.isDigits(digits));
- }
- // Check things that are not all digits
- String[] stringsWithNonDigits = {
- null,
- "",
- " ", // white spaces
- "6 4",
- "\t45",
- "5242354235234\n",
- "7\n6\n5\n",
- "12s", // last character
- "c07897803434", // first character
- "1232445wef6576788", // interior
- "no-digits",
- };
- for (String notAllDigits : stringsWithNonDigits) {
- assertFalse("Expected that " + notAllDigits + " would not be considered a string of digits.",
- Checkpoint.SequenceNumberValidator.isDigits(notAllDigits));
- }
- }*/
}
diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisherTest.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisherTest.java
index 33045149..6284c6d5 100644
--- a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisherTest.java
+++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisherTest.java
@@ -1,6 +1,8 @@
package software.amazon.kinesis.retrieval.fanout;
+import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.equalTo;
+import static org.hamcrest.CoreMatchers.instanceOf;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.fail;
import static org.mockito.Matchers.any;
@@ -11,6 +13,7 @@ import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
+import java.math.BigInteger;
import java.nio.ByteBuffer;
import java.time.Instant;
import java.util.ArrayList;
@@ -50,8 +53,9 @@ import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber;
@Slf4j
public class FanOutRecordsPublisherTest {
- private static final String SHARD_ID = "Shard-001";
+ private static final String SHARD_ID = "shardId-000000000001";
private static final String CONSUMER_ARN = "arn:consumer";
+ private static final boolean VALIDATE_RECORD_SHARD_MATCHING = true;
@Mock
private KinesisAsyncClient kinesisClient;
@@ -66,7 +70,8 @@ public class FanOutRecordsPublisherTest {
@Test
public void simpleTest() throws Exception {
- FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN);
+ FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN,
+ VALIDATE_RECORD_SHARD_MATCHING);
ArgumentCaptor captor = ArgumentCaptor.forClass(FanOutRecordsPublisher.RecordSubscription.class);
ArgumentCaptor flowCaptor = ArgumentCaptor
@@ -133,7 +138,8 @@ public class FanOutRecordsPublisherTest {
@Test
public void largeRequestTest() throws Exception {
- FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN);
+ FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN,
+ VALIDATE_RECORD_SHARD_MATCHING);
ArgumentCaptor captor = ArgumentCaptor.forClass(FanOutRecordsPublisher.RecordSubscription.class);
ArgumentCaptor flowCaptor = ArgumentCaptor
@@ -200,7 +206,8 @@ public class FanOutRecordsPublisherTest {
@Test
public void testResourceNotFoundForShard() {
- FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN);
+ FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN,
+ VALIDATE_RECORD_SHARD_MATCHING);
ArgumentCaptor flowCaptor = ArgumentCaptor
.forClass(FanOutRecordsPublisher.RecordFlow.class);
@@ -224,7 +231,8 @@ public class FanOutRecordsPublisherTest {
@Test
public void testContinuesAfterSequence() {
- FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN);
+ FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN,
+ VALIDATE_RECORD_SHARD_MATCHING);
ArgumentCaptor captor = ArgumentCaptor
.forClass(FanOutRecordsPublisher.RecordSubscription.class);
@@ -297,6 +305,66 @@ public class FanOutRecordsPublisherTest {
}
+ @Test
+ public void mismatchedShardIdTest() {
+ FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN, true);
+
+ ArgumentCaptor captor = ArgumentCaptor
+ .forClass(FanOutRecordsPublisher.RecordSubscription.class);
+ ArgumentCaptor flowCaptor = ArgumentCaptor
+ .forClass(FanOutRecordsPublisher.RecordFlow.class);
+
+ doNothing().when(publisher).subscribe(captor.capture());
+
+ source.start(ExtendedSequenceNumber.LATEST,
+ InitialPositionInStreamExtended.newInitialPosition(InitialPositionInStream.LATEST));
+
+ List errorsHandled = new ArrayList<>();
+ List inputsReceived = new ArrayList<>();
+
+ source.subscribe(new Subscriber() {
+ Subscription subscription;
+
+ @Override
+ public void onSubscribe(Subscription s) {
+ subscription = s;
+ subscription.request(1);
+ }
+
+ @Override
+ public void onNext(ProcessRecordsInput input) {
+ inputsReceived.add(input);
+ }
+
+ @Override
+ public void onError(Throwable t) {
+ errorsHandled.add(t);
+
+ }
+
+ @Override
+ public void onComplete() {
+ fail("OnComplete called when not expected");
+ }
+ });
+
+ verify(kinesisClient).subscribeToShard(any(SubscribeToShardRequest.class), flowCaptor.capture());
+ flowCaptor.getValue().onEventStream(publisher);
+ captor.getValue().onSubscribe(subscription);
+
+ List records = Stream.of(1, 2, 3).map(seq -> makeRecord(seq, seq)).collect(Collectors.toList());
+
+ batchEvent = SubscribeToShardEvent.builder().millisBehindLatest(100L).records(records).build();
+
+ captor.getValue().onNext(batchEvent);
+
+ verify(subscription, times(1)).request(1);
+ assertThat(inputsReceived.size(), equalTo(0));
+ assertThat(errorsHandled.size(), equalTo(1));
+ assertThat(errorsHandled.get(0), instanceOf(IllegalArgumentException.class));
+ assertThat(errorsHandled.get(0).getMessage(), containsString("Received records destined for different shards"));
+ }
+
private void verifyRecords(List clientRecordsList, List matchers) {
assertThat(clientRecordsList.size(), equalTo(matchers.size()));
for (int i = 0; i < clientRecordsList.size(); ++i) {
@@ -333,9 +401,17 @@ public class FanOutRecordsPublisherTest {
}
private Record makeRecord(int sequenceNumber) {
+ return makeRecord(sequenceNumber, 1);
+ }
+
+ private Record makeRecord(int sequenceNumber, int shardId) {
+ BigInteger version = BigInteger.valueOf(2).shiftLeft(184);
+ BigInteger shard = BigInteger.valueOf(shardId).shiftLeft(4);
+ BigInteger seq = version.add(shard).add(BigInteger.valueOf(sequenceNumber));
+
SdkBytes buffer = SdkBytes.fromByteArray(new byte[] { 1, 2, 3 });
return Record.builder().data(buffer).approximateArrivalTimestamp(Instant.now())
- .sequenceNumber(Integer.toString(sequenceNumber)).partitionKey("A").build();
+ .sequenceNumber(seq.toString()).partitionKey("A").build();
}
private static class KinesisClientRecordMatcher extends TypeSafeDiagnosingMatcher {