diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/checkpoint/SequenceNumberValidator.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/checkpoint/SequenceNumberValidator.java
new file mode 100644
index 00000000..e18da9ec
--- /dev/null
+++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/checkpoint/SequenceNumberValidator.java
@@ -0,0 +1,188 @@
+/*
+ * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Amazon Software License (the "License").
+ * You may not use this file except in compliance with the License.
+ * A copy of the License is located at
+ *
+ * http://aws.amazon.com/asl/
+ *
+ * or in the "license" file accompanying this file. This file is distributed
+ * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+ * express or implied. See the License for the specific language governing
+ * permissions and limitations under the License.
+ */
+
+package software.amazon.kinesis.checkpoint;
+
+import java.math.BigInteger;
+import java.util.Collections;
+import java.util.List;
+import java.util.Optional;
+
+import org.apache.commons.lang3.StringUtils;
+
+import lombok.Data;
+import lombok.experimental.Accessors;
+
+/**
+ * This supports extracting the shardId from a sequence number.
+ *
+ *
Warning
+ * Sequence numbers are an opaque value used by Kinesis, and maybe changed at any time. Should validation stop
+ * working you may need to update your version of the KCL
+ *
+ */
+public class SequenceNumberValidator {
+
+ @Data
+ @Accessors(fluent = true)
+ private static class SequenceNumberComponents {
+ final int version;
+ final int shardId;
+ }
+
+ private interface SequenceNumberReader {
+ Optional read(String sequenceNumber);
+ }
+
+ /**
+ * Reader for the v2 sequence number format. v1 sequence numbers are no longer used or available.
+ */
+ private static class V2SequenceNumberReader implements SequenceNumberReader {
+
+ private static final int VERSION = 2;
+
+ private static final int EXPECTED_BIT_LENGTH = 186;
+
+ private static final int VERSION_OFFSET = 184;
+ private static final long VERSION_MASK = (1 << 4) - 1;
+
+ private static final int SHARD_ID_OFFSET = 4;
+ private static final long SHARD_ID_MASK = (1L << 32) - 1;
+
+ @Override
+ public Optional read(String sequenceNumberString) {
+ BigInteger sequenceNumber = new BigInteger(sequenceNumberString, 10);
+
+ //
+ // If the bit length of the sequence number isn't 186 it's impossible for the version numbers
+ // to be where we expect them. We treat this the same as an unknown version of the sequence number
+ //
+ // If the sequence number length isn't what we expect it's due to a new version of the sequence number or
+ // an invalid sequence number. This
+ //
+ if (sequenceNumber.bitLength() != EXPECTED_BIT_LENGTH) {
+ return Optional.empty();
+ }
+
+ //
+ // Read the 4 most significant bits of the sequence number, the 2 most significant bits are implicitly 0
+ // (2 == 0b0011). If the version number doesn't match we give up and say we can't parse the sequence number
+ //
+ int version = readOffset(sequenceNumber, VERSION_OFFSET, VERSION_MASK);
+ if (version != VERSION) {
+ return Optional.empty();
+ }
+
+ //
+ // If we get here the sequence number is big enough, and the version matches so the shardId should be valid.
+ //
+ int shardId = readOffset(sequenceNumber, SHARD_ID_OFFSET, SHARD_ID_MASK);
+ return Optional.of(new SequenceNumberComponents(version, shardId));
+ }
+
+ private int readOffset(BigInteger sequenceNumber, int offset, long mask) {
+ long value = sequenceNumber.shiftRight(offset).longValue() & mask;
+ return (int) value;
+ }
+ }
+
+ private static final List SEQUENCE_NUMBER_READERS = Collections
+ .singletonList(new V2SequenceNumberReader());
+
+ private Optional retrieveComponentsFor(String sequenceNumber) {
+ return SEQUENCE_NUMBER_READERS.stream().map(r -> r.read(sequenceNumber)).filter(Optional::isPresent).map(Optional::get).findFirst();
+ }
+
+ /**
+ * Attempts to retrieve the version for a sequence number. If no reader can be found for the sequence number this
+ * will return an empty Optional.
+ *
+ *
+ * This will return an empty Optional if the it's unable to extract the version number. This can occur for
+ * multiple reasons including:
+ *
+ * - Kinesis has started using a new version of sequence numbers
+ * - The provided sequence number isn't a valid Kinesis sequence number.
+ *
+ *
+ *
+ *
+ * @param sequenceNumber
+ * the sequence number to extract the version from
+ * @return an Optional containing the version if a compatible sequence number reader can be found, an empty Optional
+ * otherwise.
+ */
+ public Optional versionFor(String sequenceNumber) {
+ return retrieveComponentsFor(sequenceNumber).map(SequenceNumberComponents::version);
+ }
+
+ /**
+ * Attempts to retrieve the shardId from a sequence number. If the version of the sequence number is unsupported
+ * this will return an empty optional.
+ *
+ * This will return an empty Optional if the sequence number isn't recognized. This can occur for multiple
+ * reasons including:
+ *
+ * - Kinesis has started using a new version of sequence numbers
+ * - The provided sequence number isn't a valid Kinesis sequence number.
+ *
+ *
+ *
+ * This should always return a value if {@link #versionFor(String)} returns a value
+ *
+ *
+ * @param sequenceNumber
+ * the sequence number to extract the shardId from
+ * @return an Optional containing the shardId if the version is supported, an empty Optional otherwise.
+ */
+ public Optional shardIdFor(String sequenceNumber) {
+ return retrieveComponentsFor(sequenceNumber).map(s -> String.format("shardId-%012d", s.shardId()));
+ }
+
+ /**
+ * Validates that the sequence number provided contains the given shardId. If the sequence number is unsupported
+ * this will return an empty Optional.
+ *
+ *
+ * Validation of a sequence number will only occur if the sequence number can be parsed. It's possible to use
+ * {@link #versionFor(String)} to verify that the given sequence number is supported by this class. There are 3
+ * possible validation states:
+ *
+ * - Some(True)
+ * - The sequence number can be parsed, and the shardId matches the one in the sequence number
+ * - Some(False)
+ * - THe sequence number can be parsed, and the shardId doesn't match the one in the sequence number
+ * - None
+ * - It wasn't possible to parse the sequence number so the validity of the sequence number is unknown
+ *
+ *
+ *
+ *
+ * Handling unknown validation causes is application specific, and not specific handling is
+ * provided.
+ *
+ *
+ * @param sequenceNumber
+ * the sequence number to verify the shardId
+ * @param shardId
+ * the shardId that the sequence is expected to contain
+ * @return true if the sequence number contains the shardId, false if it doesn't. If the sequence number version is
+ * unsupported this will return an empty Optional
+ */
+ public Optional validateSequenceNumberForShard(String sequenceNumber, String shardId) {
+ return shardIdFor(sequenceNumber).map(s -> StringUtils.equalsIgnoreCase(s, shardId));
+ }
+
+}
diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutConfig.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutConfig.java
index 33f519f9..173a871e 100644
--- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutConfig.java
+++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutConfig.java
@@ -80,9 +80,15 @@ public class FanOutConfig implements RetrievalSpecificConfig {
*/
private long retryBackoffMillis = 1000;
+ /**
+ * Controls whether the {@link FanOutRecordsPublisher} will validate that all the records are from the shard it's
+ * processing.
+ */
+ private boolean validateRecordsAreForShard = false;
+
@Override
public RetrievalFactory retrievalFactory() {
- return new FanOutRetrievalFactory(kinesisClient, getOrCreateConsumerArn());
+ return new FanOutRetrievalFactory(kinesisClient, getOrCreateConsumerArn()).validateRecordsAreForShard(validateRecordsAreForShard);
}
private String getOrCreateConsumerArn() {
diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisher.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisher.java
index 728aafc0..bb9c0c75 100644
--- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisher.java
+++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisher.java
@@ -18,15 +18,20 @@ package software.amazon.kinesis.retrieval.fanout;
import java.time.Instant;
import java.util.Collections;
import java.util.List;
+import java.util.Map;
+import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.Function;
import java.util.stream.Collectors;
+import org.apache.commons.lang3.StringUtils;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import software.amazon.awssdk.core.async.SdkPublisher;
+import software.amazon.awssdk.http.SdkHttpResponse;
import software.amazon.awssdk.services.kinesis.KinesisAsyncClient;
import software.amazon.awssdk.services.kinesis.model.ResourceNotFoundException;
import software.amazon.awssdk.services.kinesis.model.SubscribeToShardEvent;
@@ -35,6 +40,7 @@ import software.amazon.awssdk.services.kinesis.model.SubscribeToShardRequest;
import software.amazon.awssdk.services.kinesis.model.SubscribeToShardResponse;
import software.amazon.awssdk.services.kinesis.model.SubscribeToShardResponseHandler;
import software.amazon.kinesis.annotations.KinesisClientInternalApi;
+import software.amazon.kinesis.checkpoint.SequenceNumberValidator;
import software.amazon.kinesis.common.InitialPositionInStreamExtended;
import software.amazon.kinesis.common.KinesisRequestsBuilder;
import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput;
@@ -43,7 +49,6 @@ import software.amazon.kinesis.retrieval.KinesisClientRecord;
import software.amazon.kinesis.retrieval.RecordsPublisher;
import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber;
-@RequiredArgsConstructor
@Slf4j
@KinesisClientInternalApi
public class FanOutRecordsPublisher implements RecordsPublisher {
@@ -51,6 +56,34 @@ public class FanOutRecordsPublisher implements RecordsPublisher {
private final KinesisAsyncClient kinesis;
private final String shardId;
private final String consumerArn;
+ private final boolean validateRecordShardMatching;
+
+ /**
+ * Creates a new FanOutRecordsPublisher.
+ *
+ * This is deprecated and will be removed in a later release. Use
+ * {@link #FanOutRecordsPublisher(KinesisAsyncClient, String, String, boolean)} instead
+ *
+ *
+ * @param kinesis
+ * the kinesis client to use for requests
+ * @param shardId
+ * the shardId to retrieve records for
+ * @param consumerArn
+ * the consumer to use when retrieving records
+ */
+ @Deprecated
+ public FanOutRecordsPublisher(KinesisAsyncClient kinesis, String shardId, String consumerArn) {
+ this(kinesis, shardId, consumerArn, false);
+ }
+
+ public FanOutRecordsPublisher(KinesisAsyncClient kinesis, String shardId, String consumerArn,
+ boolean validateRecordShardMatching) {
+ this.kinesis = kinesis;
+ this.shardId = shardId;
+ this.consumerArn = consumerArn;
+ this.validateRecordShardMatching = validateRecordShardMatching;
+ }
private final Object lockObject = new Object();
@@ -451,8 +484,15 @@ public class FanOutRecordsPublisher implements RecordsPublisher {
@Override
public void responseReceived(SubscribeToShardResponse response) {
- log.debug("{}: [SubscriptionLifetime]: (RecordFlow#responseReceived) @ {} id: {} -- Response received",
- parent.shardId, connectionStartedAt, subscribeToShardId);
+ Optional sdkHttpResponse = Optional.ofNullable(response)
+ .flatMap(r -> Optional.ofNullable(r.sdkHttpResponse()));
+ Optional requestId = sdkHttpResponse.flatMap(s -> s.firstMatchingHeader("x-amz-requestid"));
+ Optional requestId2 = sdkHttpResponse.flatMap(s -> s.firstMatchingHeader("x-amz-id-2"));
+
+ log.debug(
+ "{}: [SubscriptionLifetime]: (RecordFlow#responseReceived) @ {} id: {} -- Response received -- rid: {} -- rid2: {}",
+ parent.shardId, connectionStartedAt, subscribeToShardId, requestId.orElse("None"),
+ requestId2.orElse("None"));
}
@Override
@@ -548,6 +588,7 @@ public class FanOutRecordsPublisher implements RecordsPublisher {
private final RecordFlow flow;
private final Instant connectionStartedAt;
private final String subscribeToShardId;
+ private final SequenceNumberValidator sequenceNumberValidator = new SequenceNumberValidator();
private Subscription subscription;
@@ -594,8 +635,9 @@ public class FanOutRecordsPublisher implements RecordsPublisher {
cancel();
}
log.debug(
- "{}: [SubscriptionLifetime]: (RecordSubscription#onSubscribe) @ {} id: {} -- Outstanding: {} items so requesting an item",
- parent.shardId, connectionStartedAt, subscribeToShardId, parent.availableQueueSpace);
+ "{}: [SubscriptionLifetime]: (RecordSubscription#onSubscribe) @ {} id: {} (Subscription ObjectId: {}) -- Outstanding: {} items so requesting an item",
+ parent.shardId, connectionStartedAt, subscribeToShardId, System.identityHashCode(subscription),
+ parent.availableQueueSpace);
if (parent.availableQueueSpace > 0) {
request(1);
}
@@ -615,12 +657,54 @@ public class FanOutRecordsPublisher implements RecordsPublisher {
recordBatchEvent.accept(new SubscribeToShardResponseHandler.Visitor() {
@Override
public void visit(SubscribeToShardEvent event) {
+ if (parent.validateRecordShardMatching && !areRecordsValid(event)) {
+ return;
+ }
flow.recordsReceived(event);
}
});
}
}
+ private boolean areRecordsValid(SubscribeToShardEvent event) {
+ try {
+ Map mismatchedRecords = recordsNotForShard(event);
+ if (mismatchedRecords.size() > 0) {
+ String mismatchReport = mismatchedRecords.entrySet().stream()
+ .map(e -> String.format("(%s -> %d)", e.getKey(), e.getValue()))
+ .collect(Collectors.joining(", "));
+ log.debug(
+ "{}: [SubscriptionLifetime]: (RecordSubscription#onNext#vistor) @ {} id: {} (Subscription ObjectId: {}) -- Failing subscription due to mismatches: [ {} ]",
+ parent.shardId, connectionStartedAt, subscribeToShardId,
+ System.identityHashCode(subscription), mismatchReport);
+ parent.errorOccurred(flow, new IllegalArgumentException(
+ "Received records destined for different shards: " + mismatchReport));
+ return false;
+ }
+ } catch (IllegalArgumentException iae) {
+ log.debug(
+ "{}: [SubscriptionLifetime]: (RecordSubscription#onNext#vistor) @ {} id: {} (Subscription ObjectId: {}) -- "
+ + "A problem occurred while validating sequence numbers: {} on subscription {}",
+ parent.shardId, connectionStartedAt, subscribeToShardId, System.identityHashCode(subscription),
+ iae.getMessage(), iae);
+ parent.errorOccurred(flow, iae);
+ return false;
+ }
+ return true;
+ }
+
+ private Map recordsNotForShard(SubscribeToShardEvent event) {
+ return event.records().stream().map(r -> {
+ Optional res = sequenceNumberValidator.shardIdFor(r.sequenceNumber());
+ if (!res.isPresent()) {
+ throw new IllegalArgumentException("Unable to validate sequence number of " + r.sequenceNumber());
+ }
+ return res.get();
+ }).filter(s -> !StringUtils.equalsIgnoreCase(s, parent.shardId))
+ .collect(Collectors.groupingBy(Function.identity())).entrySet().stream()
+ .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().size()));
+ }
+
@Override
public void onError(Throwable t) {
log.debug("{}: [SubscriptionLifetime]: (RecordSubscription#onError) @ {} id: {} -- {}: {}", parent.shardId,
diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutRetrievalFactory.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutRetrievalFactory.java
index eea61250..46be077f 100644
--- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutRetrievalFactory.java
+++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutRetrievalFactory.java
@@ -15,8 +15,11 @@
package software.amazon.kinesis.retrieval.fanout;
+import lombok.Getter;
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
+import lombok.Setter;
+import lombok.experimental.Accessors;
import software.amazon.awssdk.services.kinesis.KinesisAsyncClient;
import software.amazon.kinesis.annotations.KinesisClientInternalApi;
import software.amazon.kinesis.leases.ShardInfo;
@@ -27,10 +30,14 @@ import software.amazon.kinesis.retrieval.RetrievalFactory;
@RequiredArgsConstructor
@KinesisClientInternalApi
+@Accessors(fluent = true)
public class FanOutRetrievalFactory implements RetrievalFactory {
private final KinesisAsyncClient kinesisClient;
private final String consumerArn;
+ @Getter
+ @Setter
+ private boolean validateRecordsAreForShard = false;
@Override
public GetRecordsRetrievalStrategy createGetRecordsRetrievalStrategy(final ShardInfo shardInfo,
@@ -41,6 +48,6 @@ public class FanOutRetrievalFactory implements RetrievalFactory {
@Override
public RecordsPublisher createGetRecordsCache(@NonNull final ShardInfo shardInfo,
final MetricsFactory metricsFactory) {
- return new FanOutRecordsPublisher(kinesisClient, shardInfo.shardId(), consumerArn);
+ return new FanOutRecordsPublisher(kinesisClient, shardInfo.shardId(), consumerArn, validateRecordsAreForShard);
}
}
diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/checkpoint/SequenceNumberValidatorTest.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/checkpoint/SequenceNumberValidatorTest.java
index 4e8f69d1..a2ed3208 100644
--- a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/checkpoint/SequenceNumberValidatorTest.java
+++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/checkpoint/SequenceNumberValidatorTest.java
@@ -14,113 +14,88 @@
*/
package software.amazon.kinesis.checkpoint;
-//@RunWith(MockitoJUnitRunner.class)
+import org.junit.Before;
+import org.junit.Test;
+
+import java.util.Optional;
+
+import static org.hamcrest.CoreMatchers.equalTo;
+import static org.hamcrest.CoreMatchers.not;
+import static org.junit.Assert.assertThat;
+
public class SequenceNumberValidatorTest {
- /*private final String streamName = "testStream";
- private final boolean validateWithGetIterator = true;
- private final String shardId = "shardid-123";
- @Mock
- private AmazonKinesis amazonKinesis;
+ private SequenceNumberValidator validator;
- @Test (expected = IllegalArgumentException.class)
- public final void testSequenceNumberValidator() {
- Checkpoint.SequenceNumberValidator validator = new Checkpoint.SequenceNumberValidator(amazonKinesis, streamName,
- shardId, validateWithGetIterator);
+ @Before
+ public void begin() {
+ validator = new SequenceNumberValidator();
+ }
- String goodSequence = "456";
- String iterator = "happyiterator";
- String badSequence = "789";
- ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(GetShardIteratorRequest.class);
+ @Test
+ public void matchingSequenceNumberTest() {
+ String sequenceNumber = "49587497311274533994574834252742144236107130636007899138";
+ String expectedShardId = "shardId-000000000000";
- when(amazonKinesis.getShardIterator(requestCaptor.capture()))
- .thenReturn(new GetShardIteratorResult().withShardIterator(iterator))
- .thenThrow(new InvalidArgumentException(""));
+ Optional version = validator.versionFor(sequenceNumber);
+ assertThat(version, equalTo(Optional.of(2)));
- validator.validateSequenceNumber(goodSequence);
- try {
- validator.validateSequenceNumber(badSequence);
- } finally {
- final List requests = requestCaptor.getAllValues();
- assertEquals(2, requests.size());
+ Optional shardId = validator.shardIdFor(sequenceNumber);
+ assertThat(shardId, equalTo(Optional.of(expectedShardId)));
- final GetShardIteratorRequest goodRequest = requests.get(0);
- final GetShardIteratorRequest badRequest = requests.get(0);
-
- assertEquals(streamName, goodRequest.getStreamName());
- assertEquals(shardId, goodRequest.shardId());
- assertEquals(ShardIteratorType.AFTER_SEQUENCE_NUMBER.toString(), goodRequest.getShardIteratorType());
- assertEquals(goodSequence, goodRequest.getStartingSequenceNumber());
-
- assertEquals(streamName, badRequest.getStreamName());
- assertEquals(shardId, badRequest.shardId());
- assertEquals(ShardIteratorType.AFTER_SEQUENCE_NUMBER.toString(), badRequest.getShardIteratorType());
- assertEquals(goodSequence, badRequest.getStartingSequenceNumber());
- }
+ assertThat(validator.validateSequenceNumberForShard(sequenceNumber, expectedShardId), equalTo(Optional.of(true)));
}
@Test
- public final void testNoValidation() {
- Checkpoint.SequenceNumberValidator validator = new Checkpoint.SequenceNumberValidator(amazonKinesis, streamName,
- shardId, !validateWithGetIterator);
- String sequenceNumber = "456";
+ public void shardMismatchTest() {
+ String sequenceNumber = "49585389983312162443796657944872008114154899568972529698";
+ String invalidShardId = "shardId-000000000001";
- // Just checking that the false flag for validating against getIterator is honored
- validator.validateSequenceNumber(sequenceNumber);
+ Optional version = validator.versionFor(sequenceNumber);
+ assertThat(version, equalTo(Optional.of(2)));
- verify(amazonKinesis, never()).getShardIterator(any(GetShardIteratorRequest.class));
+ Optional shardId = validator.shardIdFor(sequenceNumber);
+ assertThat(shardId, not(equalTo(invalidShardId)));
+
+ assertThat(validator.validateSequenceNumberForShard(sequenceNumber, invalidShardId), equalTo(Optional.of(false)));
}
@Test
- public void nonNumericValueValidationTest() {
- Checkpoint.SequenceNumberValidator validator = new Checkpoint.SequenceNumberValidator(amazonKinesis, streamName,
- shardId, validateWithGetIterator);
+ public void versionMismatchTest() {
+ String sequenceNumber = "74107425965128755728308386687147091174006956590945533954";
+ String expectedShardId = "shardId-000000000000";
- String[] nonNumericStrings = {null,
- "bogus-sequence-number",
- SentinelCheckpoint.LATEST.toString(),
- SentinelCheckpoint.TRIM_HORIZON.toString(),
- SentinelCheckpoint.AT_TIMESTAMP.toString()};
+ Optional version = validator.versionFor(sequenceNumber);
+ assertThat(version, equalTo(Optional.empty()));
- Arrays.stream(nonNumericStrings).forEach(sequenceNumber -> {
- try {
- validator.validateSequenceNumber(sequenceNumber);
- fail("Validator should not consider " + sequenceNumber + " a valid sequence number");
- } catch (IllegalArgumentException e) {
- // Do nothing
- }
- });
+ Optional shardId = validator.shardIdFor(sequenceNumber);
+ assertThat(shardId, equalTo(Optional.empty()));
- verify(amazonKinesis, never()).getShardIterator(any(GetShardIteratorRequest.class));
+ assertThat(validator.validateSequenceNumberForShard(sequenceNumber, expectedShardId), equalTo(Optional.empty()));
}
@Test
- public final void testIsDigits() {
- // Check things that are all digits
- String[] stringsOfDigits = {"0", "12", "07897803434", "12324456576788"};
+ public void sequenceNumberToShortTest() {
+ String sequenceNumber = "4958538998331216244379665794487200811415489956897252969";
+ String expectedShardId = "shardId-000000000000";
+
+ assertThat(validator.versionFor(sequenceNumber), equalTo(Optional.empty()));
+ assertThat(validator.shardIdFor(sequenceNumber), equalTo(Optional.empty()));
+
+ assertThat(validator.validateSequenceNumberForShard(sequenceNumber, expectedShardId), equalTo(Optional.empty()));
+ }
+
+ @Test
+ public void sequenceNumberToLongTest() {
+ String sequenceNumber = "495874973112745339945748342527421442361071306360078991381";
+ String expectedShardId = "shardId-000000000000";
+
+ assertThat(validator.versionFor(sequenceNumber), equalTo(Optional.empty()));
+ assertThat(validator.shardIdFor(sequenceNumber), equalTo(Optional.empty()));
+
+ assertThat(validator.validateSequenceNumberForShard(sequenceNumber, expectedShardId), equalTo(Optional.empty()));
+ }
+
- for (String digits : stringsOfDigits) {
- assertTrue("Expected that " + digits + " would be considered a string of digits.",
- Checkpoint.SequenceNumberValidator.isDigits(digits));
- }
- // Check things that are not all digits
- String[] stringsWithNonDigits = {
- null,
- "",
- " ", // white spaces
- "6 4",
- "\t45",
- "5242354235234\n",
- "7\n6\n5\n",
- "12s", // last character
- "c07897803434", // first character
- "1232445wef6576788", // interior
- "no-digits",
- };
- for (String notAllDigits : stringsWithNonDigits) {
- assertFalse("Expected that " + notAllDigits + " would not be considered a string of digits.",
- Checkpoint.SequenceNumberValidator.isDigits(notAllDigits));
- }
- }*/
}
diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisherTest.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisherTest.java
index 33045149..6284c6d5 100644
--- a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisherTest.java
+++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisherTest.java
@@ -1,6 +1,8 @@
package software.amazon.kinesis.retrieval.fanout;
+import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.equalTo;
+import static org.hamcrest.CoreMatchers.instanceOf;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.fail;
import static org.mockito.Matchers.any;
@@ -11,6 +13,7 @@ import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
+import java.math.BigInteger;
import java.nio.ByteBuffer;
import java.time.Instant;
import java.util.ArrayList;
@@ -50,8 +53,9 @@ import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber;
@Slf4j
public class FanOutRecordsPublisherTest {
- private static final String SHARD_ID = "Shard-001";
+ private static final String SHARD_ID = "shardId-000000000001";
private static final String CONSUMER_ARN = "arn:consumer";
+ private static final boolean VALIDATE_RECORD_SHARD_MATCHING = true;
@Mock
private KinesisAsyncClient kinesisClient;
@@ -66,7 +70,8 @@ public class FanOutRecordsPublisherTest {
@Test
public void simpleTest() throws Exception {
- FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN);
+ FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN,
+ VALIDATE_RECORD_SHARD_MATCHING);
ArgumentCaptor captor = ArgumentCaptor.forClass(FanOutRecordsPublisher.RecordSubscription.class);
ArgumentCaptor flowCaptor = ArgumentCaptor
@@ -133,7 +138,8 @@ public class FanOutRecordsPublisherTest {
@Test
public void largeRequestTest() throws Exception {
- FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN);
+ FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN,
+ VALIDATE_RECORD_SHARD_MATCHING);
ArgumentCaptor captor = ArgumentCaptor.forClass(FanOutRecordsPublisher.RecordSubscription.class);
ArgumentCaptor flowCaptor = ArgumentCaptor
@@ -200,7 +206,8 @@ public class FanOutRecordsPublisherTest {
@Test
public void testResourceNotFoundForShard() {
- FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN);
+ FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN,
+ VALIDATE_RECORD_SHARD_MATCHING);
ArgumentCaptor flowCaptor = ArgumentCaptor
.forClass(FanOutRecordsPublisher.RecordFlow.class);
@@ -224,7 +231,8 @@ public class FanOutRecordsPublisherTest {
@Test
public void testContinuesAfterSequence() {
- FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN);
+ FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN,
+ VALIDATE_RECORD_SHARD_MATCHING);
ArgumentCaptor captor = ArgumentCaptor
.forClass(FanOutRecordsPublisher.RecordSubscription.class);
@@ -297,6 +305,66 @@ public class FanOutRecordsPublisherTest {
}
+ @Test
+ public void mismatchedShardIdTest() {
+ FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN, true);
+
+ ArgumentCaptor captor = ArgumentCaptor
+ .forClass(FanOutRecordsPublisher.RecordSubscription.class);
+ ArgumentCaptor flowCaptor = ArgumentCaptor
+ .forClass(FanOutRecordsPublisher.RecordFlow.class);
+
+ doNothing().when(publisher).subscribe(captor.capture());
+
+ source.start(ExtendedSequenceNumber.LATEST,
+ InitialPositionInStreamExtended.newInitialPosition(InitialPositionInStream.LATEST));
+
+ List errorsHandled = new ArrayList<>();
+ List inputsReceived = new ArrayList<>();
+
+ source.subscribe(new Subscriber() {
+ Subscription subscription;
+
+ @Override
+ public void onSubscribe(Subscription s) {
+ subscription = s;
+ subscription.request(1);
+ }
+
+ @Override
+ public void onNext(ProcessRecordsInput input) {
+ inputsReceived.add(input);
+ }
+
+ @Override
+ public void onError(Throwable t) {
+ errorsHandled.add(t);
+
+ }
+
+ @Override
+ public void onComplete() {
+ fail("OnComplete called when not expected");
+ }
+ });
+
+ verify(kinesisClient).subscribeToShard(any(SubscribeToShardRequest.class), flowCaptor.capture());
+ flowCaptor.getValue().onEventStream(publisher);
+ captor.getValue().onSubscribe(subscription);
+
+ List records = Stream.of(1, 2, 3).map(seq -> makeRecord(seq, seq)).collect(Collectors.toList());
+
+ batchEvent = SubscribeToShardEvent.builder().millisBehindLatest(100L).records(records).build();
+
+ captor.getValue().onNext(batchEvent);
+
+ verify(subscription, times(1)).request(1);
+ assertThat(inputsReceived.size(), equalTo(0));
+ assertThat(errorsHandled.size(), equalTo(1));
+ assertThat(errorsHandled.get(0), instanceOf(IllegalArgumentException.class));
+ assertThat(errorsHandled.get(0).getMessage(), containsString("Received records destined for different shards"));
+ }
+
private void verifyRecords(List clientRecordsList, List matchers) {
assertThat(clientRecordsList.size(), equalTo(matchers.size()));
for (int i = 0; i < clientRecordsList.size(); ++i) {
@@ -333,9 +401,17 @@ public class FanOutRecordsPublisherTest {
}
private Record makeRecord(int sequenceNumber) {
+ return makeRecord(sequenceNumber, 1);
+ }
+
+ private Record makeRecord(int sequenceNumber, int shardId) {
+ BigInteger version = BigInteger.valueOf(2).shiftLeft(184);
+ BigInteger shard = BigInteger.valueOf(shardId).shiftLeft(4);
+ BigInteger seq = version.add(shard).add(BigInteger.valueOf(sequenceNumber));
+
SdkBytes buffer = SdkBytes.fromByteArray(new byte[] { 1, 2, 3 });
return Record.builder().data(buffer).approximateArrivalTimestamp(Instant.now())
- .sequenceNumber(Integer.toString(sequenceNumber)).partitionKey("A").build();
+ .sequenceNumber(seq.toString()).partitionKey("A").build();
}
private static class KinesisClientRecordMatcher extends TypeSafeDiagnosingMatcher {