Support Validating Records are From to the Expected Shard (#400)

* SequenceNumberValidator for verifying shardId's

Added a SequenceNumberValidator that will can extract, and verify
shardId's from a v2 sequence number.

* Added documentation and bit length test

Added documentation for the public methods.
Added a bit length test for the reader that will reject sequence
numbers that don't fit the expectations.

* Added more comments and further document public operations

Added comments in the only SequenceNumberReader explaining how things
are expected to work.

Further documented the class and operations with expectations and outcomes.

* Added configuration to allow failing on mismatched records

Allow configuration which will cause the FanOutRecordsPublisher to
throw an exception when it detects records that aren't for the shard
it's processing.
This commit is contained in:
Justin Pfifer 2018-09-17 14:33:46 -07:00 committed by Sahil Palvia
parent e8735a4742
commit 01f5db8049
6 changed files with 435 additions and 99 deletions

View file

@ -0,0 +1,188 @@
/*
* Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Amazon Software License (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/asl/
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/
package software.amazon.kinesis.checkpoint;
import java.math.BigInteger;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import org.apache.commons.lang3.StringUtils;
import lombok.Data;
import lombok.experimental.Accessors;
/**
* This supports extracting the shardId from a sequence number.
*
* <h2>Warning</h2>
* <strong>Sequence numbers are an opaque value used by Kinesis, and maybe changed at any time. Should validation stop
* working you may need to update your version of the KCL</strong>
*
*/
public class SequenceNumberValidator {
@Data
@Accessors(fluent = true)
private static class SequenceNumberComponents {
final int version;
final int shardId;
}
private interface SequenceNumberReader {
Optional<SequenceNumberComponents> read(String sequenceNumber);
}
/**
* Reader for the v2 sequence number format. v1 sequence numbers are no longer used or available.
*/
private static class V2SequenceNumberReader implements SequenceNumberReader {
private static final int VERSION = 2;
private static final int EXPECTED_BIT_LENGTH = 186;
private static final int VERSION_OFFSET = 184;
private static final long VERSION_MASK = (1 << 4) - 1;
private static final int SHARD_ID_OFFSET = 4;
private static final long SHARD_ID_MASK = (1L << 32) - 1;
@Override
public Optional<SequenceNumberComponents> read(String sequenceNumberString) {
BigInteger sequenceNumber = new BigInteger(sequenceNumberString, 10);
//
// If the bit length of the sequence number isn't 186 it's impossible for the version numbers
// to be where we expect them. We treat this the same as an unknown version of the sequence number
//
// If the sequence number length isn't what we expect it's due to a new version of the sequence number or
// an invalid sequence number. This
//
if (sequenceNumber.bitLength() != EXPECTED_BIT_LENGTH) {
return Optional.empty();
}
//
// Read the 4 most significant bits of the sequence number, the 2 most significant bits are implicitly 0
// (2 == 0b0011). If the version number doesn't match we give up and say we can't parse the sequence number
//
int version = readOffset(sequenceNumber, VERSION_OFFSET, VERSION_MASK);
if (version != VERSION) {
return Optional.empty();
}
//
// If we get here the sequence number is big enough, and the version matches so the shardId should be valid.
//
int shardId = readOffset(sequenceNumber, SHARD_ID_OFFSET, SHARD_ID_MASK);
return Optional.of(new SequenceNumberComponents(version, shardId));
}
private int readOffset(BigInteger sequenceNumber, int offset, long mask) {
long value = sequenceNumber.shiftRight(offset).longValue() & mask;
return (int) value;
}
}
private static final List<SequenceNumberReader> SEQUENCE_NUMBER_READERS = Collections
.singletonList(new V2SequenceNumberReader());
private Optional<SequenceNumberComponents> retrieveComponentsFor(String sequenceNumber) {
return SEQUENCE_NUMBER_READERS.stream().map(r -> r.read(sequenceNumber)).filter(Optional::isPresent).map(Optional::get).findFirst();
}
/**
* Attempts to retrieve the version for a sequence number. If no reader can be found for the sequence number this
* will return an empty Optional.
*
* <p>
* <strong>This will return an empty Optional if the it's unable to extract the version number. This can occur for
* multiple reasons including:
* <ul>
* <li>Kinesis has started using a new version of sequence numbers</li>
* <li>The provided sequence number isn't a valid Kinesis sequence number.</li>
* </ul>
* </strong>
* </p>
*
* @param sequenceNumber
* the sequence number to extract the version from
* @return an Optional containing the version if a compatible sequence number reader can be found, an empty Optional
* otherwise.
*/
public Optional<Integer> versionFor(String sequenceNumber) {
return retrieveComponentsFor(sequenceNumber).map(SequenceNumberComponents::version);
}
/**
* Attempts to retrieve the shardId from a sequence number. If the version of the sequence number is unsupported
* this will return an empty optional.
*
* <strong>This will return an empty Optional if the sequence number isn't recognized. This can occur for multiple
* reasons including:
* <ul>
* <li>Kinesis has started using a new version of sequence numbers</li>
* <li>The provided sequence number isn't a valid Kinesis sequence number.</li>
* </ul>
* </strong>
* <p>
* This should always return a value if {@link #versionFor(String)} returns a value
* </p>
*
* @param sequenceNumber
* the sequence number to extract the shardId from
* @return an Optional containing the shardId if the version is supported, an empty Optional otherwise.
*/
public Optional<String> shardIdFor(String sequenceNumber) {
return retrieveComponentsFor(sequenceNumber).map(s -> String.format("shardId-%012d", s.shardId()));
}
/**
* Validates that the sequence number provided contains the given shardId. If the sequence number is unsupported
* this will return an empty Optional.
*
* <p>
* Validation of a sequence number will only occur if the sequence number can be parsed. It's possible to use
* {@link #versionFor(String)} to verify that the given sequence number is supported by this class. There are 3
* possible validation states:
* <dl>
* <dt>Some(True)</dt>
* <dd>The sequence number can be parsed, and the shardId matches the one in the sequence number</dd>
* <dt>Some(False)</dt>
* <dd>THe sequence number can be parsed, and the shardId doesn't match the one in the sequence number</dd>
* <dt>None</dt>
* <dd>It wasn't possible to parse the sequence number so the validity of the sequence number is unknown</dd>
* </dl>
* </p>
*
* <p>
* <strong>Handling unknown validation causes is application specific, and not specific handling is
* provided.</strong>
* </p>
*
* @param sequenceNumber
* the sequence number to verify the shardId
* @param shardId
* the shardId that the sequence is expected to contain
* @return true if the sequence number contains the shardId, false if it doesn't. If the sequence number version is
* unsupported this will return an empty Optional
*/
public Optional<Boolean> validateSequenceNumberForShard(String sequenceNumber, String shardId) {
return shardIdFor(sequenceNumber).map(s -> StringUtils.equalsIgnoreCase(s, shardId));
}
}

View file

@ -80,9 +80,15 @@ public class FanOutConfig implements RetrievalSpecificConfig {
*/ */
private long retryBackoffMillis = 1000; private long retryBackoffMillis = 1000;
/**
* Controls whether the {@link FanOutRecordsPublisher} will validate that all the records are from the shard it's
* processing.
*/
private boolean validateRecordsAreForShard = false;
@Override @Override
public RetrievalFactory retrievalFactory() { public RetrievalFactory retrievalFactory() {
return new FanOutRetrievalFactory(kinesisClient, getOrCreateConsumerArn()); return new FanOutRetrievalFactory(kinesisClient, getOrCreateConsumerArn()).validateRecordsAreForShard(validateRecordsAreForShard);
} }
private String getOrCreateConsumerArn() { private String getOrCreateConsumerArn() {

View file

@ -18,15 +18,20 @@ package software.amazon.kinesis.retrieval.fanout;
import java.time.Instant; import java.time.Instant;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;
import org.reactivestreams.Subscriber; import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription; import org.reactivestreams.Subscription;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import software.amazon.awssdk.core.async.SdkPublisher; import software.amazon.awssdk.core.async.SdkPublisher;
import software.amazon.awssdk.http.SdkHttpResponse;
import software.amazon.awssdk.services.kinesis.KinesisAsyncClient; import software.amazon.awssdk.services.kinesis.KinesisAsyncClient;
import software.amazon.awssdk.services.kinesis.model.ResourceNotFoundException; import software.amazon.awssdk.services.kinesis.model.ResourceNotFoundException;
import software.amazon.awssdk.services.kinesis.model.SubscribeToShardEvent; import software.amazon.awssdk.services.kinesis.model.SubscribeToShardEvent;
@ -35,6 +40,7 @@ import software.amazon.awssdk.services.kinesis.model.SubscribeToShardRequest;
import software.amazon.awssdk.services.kinesis.model.SubscribeToShardResponse; import software.amazon.awssdk.services.kinesis.model.SubscribeToShardResponse;
import software.amazon.awssdk.services.kinesis.model.SubscribeToShardResponseHandler; import software.amazon.awssdk.services.kinesis.model.SubscribeToShardResponseHandler;
import software.amazon.kinesis.annotations.KinesisClientInternalApi; import software.amazon.kinesis.annotations.KinesisClientInternalApi;
import software.amazon.kinesis.checkpoint.SequenceNumberValidator;
import software.amazon.kinesis.common.InitialPositionInStreamExtended; import software.amazon.kinesis.common.InitialPositionInStreamExtended;
import software.amazon.kinesis.common.KinesisRequestsBuilder; import software.amazon.kinesis.common.KinesisRequestsBuilder;
import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput; import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput;
@ -43,7 +49,6 @@ import software.amazon.kinesis.retrieval.KinesisClientRecord;
import software.amazon.kinesis.retrieval.RecordsPublisher; import software.amazon.kinesis.retrieval.RecordsPublisher;
import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber; import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber;
@RequiredArgsConstructor
@Slf4j @Slf4j
@KinesisClientInternalApi @KinesisClientInternalApi
public class FanOutRecordsPublisher implements RecordsPublisher { public class FanOutRecordsPublisher implements RecordsPublisher {
@ -51,6 +56,34 @@ public class FanOutRecordsPublisher implements RecordsPublisher {
private final KinesisAsyncClient kinesis; private final KinesisAsyncClient kinesis;
private final String shardId; private final String shardId;
private final String consumerArn; private final String consumerArn;
private final boolean validateRecordShardMatching;
/**
* Creates a new FanOutRecordsPublisher.
* <p>
* This is deprecated and will be removed in a later release. Use
* {@link #FanOutRecordsPublisher(KinesisAsyncClient, String, String, boolean)} instead
* </p>
*
* @param kinesis
* the kinesis client to use for requests
* @param shardId
* the shardId to retrieve records for
* @param consumerArn
* the consumer to use when retrieving records
*/
@Deprecated
public FanOutRecordsPublisher(KinesisAsyncClient kinesis, String shardId, String consumerArn) {
this(kinesis, shardId, consumerArn, false);
}
public FanOutRecordsPublisher(KinesisAsyncClient kinesis, String shardId, String consumerArn,
boolean validateRecordShardMatching) {
this.kinesis = kinesis;
this.shardId = shardId;
this.consumerArn = consumerArn;
this.validateRecordShardMatching = validateRecordShardMatching;
}
private final Object lockObject = new Object(); private final Object lockObject = new Object();
@ -451,8 +484,15 @@ public class FanOutRecordsPublisher implements RecordsPublisher {
@Override @Override
public void responseReceived(SubscribeToShardResponse response) { public void responseReceived(SubscribeToShardResponse response) {
log.debug("{}: [SubscriptionLifetime]: (RecordFlow#responseReceived) @ {} id: {} -- Response received", Optional<SdkHttpResponse> sdkHttpResponse = Optional.ofNullable(response)
parent.shardId, connectionStartedAt, subscribeToShardId); .flatMap(r -> Optional.ofNullable(r.sdkHttpResponse()));
Optional<String> requestId = sdkHttpResponse.flatMap(s -> s.firstMatchingHeader("x-amz-requestid"));
Optional<String> requestId2 = sdkHttpResponse.flatMap(s -> s.firstMatchingHeader("x-amz-id-2"));
log.debug(
"{}: [SubscriptionLifetime]: (RecordFlow#responseReceived) @ {} id: {} -- Response received -- rid: {} -- rid2: {}",
parent.shardId, connectionStartedAt, subscribeToShardId, requestId.orElse("None"),
requestId2.orElse("None"));
} }
@Override @Override
@ -548,6 +588,7 @@ public class FanOutRecordsPublisher implements RecordsPublisher {
private final RecordFlow flow; private final RecordFlow flow;
private final Instant connectionStartedAt; private final Instant connectionStartedAt;
private final String subscribeToShardId; private final String subscribeToShardId;
private final SequenceNumberValidator sequenceNumberValidator = new SequenceNumberValidator();
private Subscription subscription; private Subscription subscription;
@ -594,8 +635,9 @@ public class FanOutRecordsPublisher implements RecordsPublisher {
cancel(); cancel();
} }
log.debug( log.debug(
"{}: [SubscriptionLifetime]: (RecordSubscription#onSubscribe) @ {} id: {} -- Outstanding: {} items so requesting an item", "{}: [SubscriptionLifetime]: (RecordSubscription#onSubscribe) @ {} id: {} (Subscription ObjectId: {}) -- Outstanding: {} items so requesting an item",
parent.shardId, connectionStartedAt, subscribeToShardId, parent.availableQueueSpace); parent.shardId, connectionStartedAt, subscribeToShardId, System.identityHashCode(subscription),
parent.availableQueueSpace);
if (parent.availableQueueSpace > 0) { if (parent.availableQueueSpace > 0) {
request(1); request(1);
} }
@ -615,12 +657,54 @@ public class FanOutRecordsPublisher implements RecordsPublisher {
recordBatchEvent.accept(new SubscribeToShardResponseHandler.Visitor() { recordBatchEvent.accept(new SubscribeToShardResponseHandler.Visitor() {
@Override @Override
public void visit(SubscribeToShardEvent event) { public void visit(SubscribeToShardEvent event) {
if (parent.validateRecordShardMatching && !areRecordsValid(event)) {
return;
}
flow.recordsReceived(event); flow.recordsReceived(event);
} }
}); });
} }
} }
private boolean areRecordsValid(SubscribeToShardEvent event) {
try {
Map<String, Integer> mismatchedRecords = recordsNotForShard(event);
if (mismatchedRecords.size() > 0) {
String mismatchReport = mismatchedRecords.entrySet().stream()
.map(e -> String.format("(%s -> %d)", e.getKey(), e.getValue()))
.collect(Collectors.joining(", "));
log.debug(
"{}: [SubscriptionLifetime]: (RecordSubscription#onNext#vistor) @ {} id: {} (Subscription ObjectId: {}) -- Failing subscription due to mismatches: [ {} ]",
parent.shardId, connectionStartedAt, subscribeToShardId,
System.identityHashCode(subscription), mismatchReport);
parent.errorOccurred(flow, new IllegalArgumentException(
"Received records destined for different shards: " + mismatchReport));
return false;
}
} catch (IllegalArgumentException iae) {
log.debug(
"{}: [SubscriptionLifetime]: (RecordSubscription#onNext#vistor) @ {} id: {} (Subscription ObjectId: {}) -- "
+ "A problem occurred while validating sequence numbers: {} on subscription {}",
parent.shardId, connectionStartedAt, subscribeToShardId, System.identityHashCode(subscription),
iae.getMessage(), iae);
parent.errorOccurred(flow, iae);
return false;
}
return true;
}
private Map<String, Integer> recordsNotForShard(SubscribeToShardEvent event) {
return event.records().stream().map(r -> {
Optional<String> res = sequenceNumberValidator.shardIdFor(r.sequenceNumber());
if (!res.isPresent()) {
throw new IllegalArgumentException("Unable to validate sequence number of " + r.sequenceNumber());
}
return res.get();
}).filter(s -> !StringUtils.equalsIgnoreCase(s, parent.shardId))
.collect(Collectors.groupingBy(Function.identity())).entrySet().stream()
.collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().size()));
}
@Override @Override
public void onError(Throwable t) { public void onError(Throwable t) {
log.debug("{}: [SubscriptionLifetime]: (RecordSubscription#onError) @ {} id: {} -- {}: {}", parent.shardId, log.debug("{}: [SubscriptionLifetime]: (RecordSubscription#onError) @ {} id: {} -- {}: {}", parent.shardId,

View file

@ -15,8 +15,11 @@
package software.amazon.kinesis.retrieval.fanout; package software.amazon.kinesis.retrieval.fanout;
import lombok.Getter;
import lombok.NonNull; import lombok.NonNull;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import lombok.Setter;
import lombok.experimental.Accessors;
import software.amazon.awssdk.services.kinesis.KinesisAsyncClient; import software.amazon.awssdk.services.kinesis.KinesisAsyncClient;
import software.amazon.kinesis.annotations.KinesisClientInternalApi; import software.amazon.kinesis.annotations.KinesisClientInternalApi;
import software.amazon.kinesis.leases.ShardInfo; import software.amazon.kinesis.leases.ShardInfo;
@ -27,10 +30,14 @@ import software.amazon.kinesis.retrieval.RetrievalFactory;
@RequiredArgsConstructor @RequiredArgsConstructor
@KinesisClientInternalApi @KinesisClientInternalApi
@Accessors(fluent = true)
public class FanOutRetrievalFactory implements RetrievalFactory { public class FanOutRetrievalFactory implements RetrievalFactory {
private final KinesisAsyncClient kinesisClient; private final KinesisAsyncClient kinesisClient;
private final String consumerArn; private final String consumerArn;
@Getter
@Setter
private boolean validateRecordsAreForShard = false;
@Override @Override
public GetRecordsRetrievalStrategy createGetRecordsRetrievalStrategy(final ShardInfo shardInfo, public GetRecordsRetrievalStrategy createGetRecordsRetrievalStrategy(final ShardInfo shardInfo,
@ -41,6 +48,6 @@ public class FanOutRetrievalFactory implements RetrievalFactory {
@Override @Override
public RecordsPublisher createGetRecordsCache(@NonNull final ShardInfo shardInfo, public RecordsPublisher createGetRecordsCache(@NonNull final ShardInfo shardInfo,
final MetricsFactory metricsFactory) { final MetricsFactory metricsFactory) {
return new FanOutRecordsPublisher(kinesisClient, shardInfo.shardId(), consumerArn); return new FanOutRecordsPublisher(kinesisClient, shardInfo.shardId(), consumerArn, validateRecordsAreForShard);
} }
} }

View file

@ -14,113 +14,88 @@
*/ */
package software.amazon.kinesis.checkpoint; package software.amazon.kinesis.checkpoint;
//@RunWith(MockitoJUnitRunner.class) import org.junit.Before;
import org.junit.Test;
import java.util.Optional;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.not;
import static org.junit.Assert.assertThat;
public class SequenceNumberValidatorTest { public class SequenceNumberValidatorTest {
/*private final String streamName = "testStream";
private final boolean validateWithGetIterator = true;
private final String shardId = "shardid-123";
@Mock private SequenceNumberValidator validator;
private AmazonKinesis amazonKinesis;
@Test (expected = IllegalArgumentException.class) @Before
public final void testSequenceNumberValidator() { public void begin() {
Checkpoint.SequenceNumberValidator validator = new Checkpoint.SequenceNumberValidator(amazonKinesis, streamName, validator = new SequenceNumberValidator();
shardId, validateWithGetIterator);
String goodSequence = "456";
String iterator = "happyiterator";
String badSequence = "789";
ArgumentCaptor<GetShardIteratorRequest> requestCaptor = ArgumentCaptor.forClass(GetShardIteratorRequest.class);
when(amazonKinesis.getShardIterator(requestCaptor.capture()))
.thenReturn(new GetShardIteratorResult().withShardIterator(iterator))
.thenThrow(new InvalidArgumentException(""));
validator.validateSequenceNumber(goodSequence);
try {
validator.validateSequenceNumber(badSequence);
} finally {
final List<GetShardIteratorRequest> requests = requestCaptor.getAllValues();
assertEquals(2, requests.size());
final GetShardIteratorRequest goodRequest = requests.get(0);
final GetShardIteratorRequest badRequest = requests.get(0);
assertEquals(streamName, goodRequest.getStreamName());
assertEquals(shardId, goodRequest.shardId());
assertEquals(ShardIteratorType.AFTER_SEQUENCE_NUMBER.toString(), goodRequest.getShardIteratorType());
assertEquals(goodSequence, goodRequest.getStartingSequenceNumber());
assertEquals(streamName, badRequest.getStreamName());
assertEquals(shardId, badRequest.shardId());
assertEquals(ShardIteratorType.AFTER_SEQUENCE_NUMBER.toString(), badRequest.getShardIteratorType());
assertEquals(goodSequence, badRequest.getStartingSequenceNumber());
} }
@Test
public void matchingSequenceNumberTest() {
String sequenceNumber = "49587497311274533994574834252742144236107130636007899138";
String expectedShardId = "shardId-000000000000";
Optional<Integer> version = validator.versionFor(sequenceNumber);
assertThat(version, equalTo(Optional.of(2)));
Optional<String> shardId = validator.shardIdFor(sequenceNumber);
assertThat(shardId, equalTo(Optional.of(expectedShardId)));
assertThat(validator.validateSequenceNumberForShard(sequenceNumber, expectedShardId), equalTo(Optional.of(true)));
} }
@Test @Test
public final void testNoValidation() { public void shardMismatchTest() {
Checkpoint.SequenceNumberValidator validator = new Checkpoint.SequenceNumberValidator(amazonKinesis, streamName, String sequenceNumber = "49585389983312162443796657944872008114154899568972529698";
shardId, !validateWithGetIterator); String invalidShardId = "shardId-000000000001";
String sequenceNumber = "456";
// Just checking that the false flag for validating against getIterator is honored Optional<Integer> version = validator.versionFor(sequenceNumber);
validator.validateSequenceNumber(sequenceNumber); assertThat(version, equalTo(Optional.of(2)));
verify(amazonKinesis, never()).getShardIterator(any(GetShardIteratorRequest.class)); Optional<String> shardId = validator.shardIdFor(sequenceNumber);
assertThat(shardId, not(equalTo(invalidShardId)));
assertThat(validator.validateSequenceNumberForShard(sequenceNumber, invalidShardId), equalTo(Optional.of(false)));
} }
@Test @Test
public void nonNumericValueValidationTest() { public void versionMismatchTest() {
Checkpoint.SequenceNumberValidator validator = new Checkpoint.SequenceNumberValidator(amazonKinesis, streamName, String sequenceNumber = "74107425965128755728308386687147091174006956590945533954";
shardId, validateWithGetIterator); String expectedShardId = "shardId-000000000000";
String[] nonNumericStrings = {null, Optional<Integer> version = validator.versionFor(sequenceNumber);
"bogus-sequence-number", assertThat(version, equalTo(Optional.empty()));
SentinelCheckpoint.LATEST.toString(),
SentinelCheckpoint.TRIM_HORIZON.toString(),
SentinelCheckpoint.AT_TIMESTAMP.toString()};
Arrays.stream(nonNumericStrings).forEach(sequenceNumber -> { Optional<String> shardId = validator.shardIdFor(sequenceNumber);
try { assertThat(shardId, equalTo(Optional.empty()));
validator.validateSequenceNumber(sequenceNumber);
fail("Validator should not consider " + sequenceNumber + " a valid sequence number");
} catch (IllegalArgumentException e) {
// Do nothing
}
});
verify(amazonKinesis, never()).getShardIterator(any(GetShardIteratorRequest.class)); assertThat(validator.validateSequenceNumberForShard(sequenceNumber, expectedShardId), equalTo(Optional.empty()));
} }
@Test @Test
public final void testIsDigits() { public void sequenceNumberToShortTest() {
// Check things that are all digits String sequenceNumber = "4958538998331216244379665794487200811415489956897252969";
String[] stringsOfDigits = {"0", "12", "07897803434", "12324456576788"}; String expectedShardId = "shardId-000000000000";
for (String digits : stringsOfDigits) { assertThat(validator.versionFor(sequenceNumber), equalTo(Optional.empty()));
assertTrue("Expected that " + digits + " would be considered a string of digits.", assertThat(validator.shardIdFor(sequenceNumber), equalTo(Optional.empty()));
Checkpoint.SequenceNumberValidator.isDigits(digits));
assertThat(validator.validateSequenceNumberForShard(sequenceNumber, expectedShardId), equalTo(Optional.empty()));
} }
// Check things that are not all digits
String[] stringsWithNonDigits = { @Test
null, public void sequenceNumberToLongTest() {
"", String sequenceNumber = "495874973112745339945748342527421442361071306360078991381";
" ", // white spaces String expectedShardId = "shardId-000000000000";
"6 4",
"\t45", assertThat(validator.versionFor(sequenceNumber), equalTo(Optional.empty()));
"5242354235234\n", assertThat(validator.shardIdFor(sequenceNumber), equalTo(Optional.empty()));
"7\n6\n5\n",
"12s", // last character assertThat(validator.validateSequenceNumberForShard(sequenceNumber, expectedShardId), equalTo(Optional.empty()));
"c07897803434", // first character
"1232445wef6576788", // interior
"no-digits",
};
for (String notAllDigits : stringsWithNonDigits) {
assertFalse("Expected that " + notAllDigits + " would not be considered a string of digits.",
Checkpoint.SequenceNumberValidator.isDigits(notAllDigits));
} }
}*/
} }

View file

@ -1,6 +1,8 @@
package software.amazon.kinesis.retrieval.fanout; package software.amazon.kinesis.retrieval.fanout;
import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.instanceOf;
import static org.junit.Assert.assertThat; import static org.junit.Assert.assertThat;
import static org.junit.Assert.fail; import static org.junit.Assert.fail;
import static org.mockito.Matchers.any; import static org.mockito.Matchers.any;
@ -11,6 +13,7 @@ import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import java.math.BigInteger;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.time.Instant; import java.time.Instant;
import java.util.ArrayList; import java.util.ArrayList;
@ -50,8 +53,9 @@ import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber;
@Slf4j @Slf4j
public class FanOutRecordsPublisherTest { public class FanOutRecordsPublisherTest {
private static final String SHARD_ID = "Shard-001"; private static final String SHARD_ID = "shardId-000000000001";
private static final String CONSUMER_ARN = "arn:consumer"; private static final String CONSUMER_ARN = "arn:consumer";
private static final boolean VALIDATE_RECORD_SHARD_MATCHING = true;
@Mock @Mock
private KinesisAsyncClient kinesisClient; private KinesisAsyncClient kinesisClient;
@ -66,7 +70,8 @@ public class FanOutRecordsPublisherTest {
@Test @Test
public void simpleTest() throws Exception { public void simpleTest() throws Exception {
FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN); FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN,
VALIDATE_RECORD_SHARD_MATCHING);
ArgumentCaptor<FanOutRecordsPublisher.RecordSubscription> captor = ArgumentCaptor.forClass(FanOutRecordsPublisher.RecordSubscription.class); ArgumentCaptor<FanOutRecordsPublisher.RecordSubscription> captor = ArgumentCaptor.forClass(FanOutRecordsPublisher.RecordSubscription.class);
ArgumentCaptor<FanOutRecordsPublisher.RecordFlow> flowCaptor = ArgumentCaptor ArgumentCaptor<FanOutRecordsPublisher.RecordFlow> flowCaptor = ArgumentCaptor
@ -133,7 +138,8 @@ public class FanOutRecordsPublisherTest {
@Test @Test
public void largeRequestTest() throws Exception { public void largeRequestTest() throws Exception {
FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN); FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN,
VALIDATE_RECORD_SHARD_MATCHING);
ArgumentCaptor<FanOutRecordsPublisher.RecordSubscription> captor = ArgumentCaptor.forClass(FanOutRecordsPublisher.RecordSubscription.class); ArgumentCaptor<FanOutRecordsPublisher.RecordSubscription> captor = ArgumentCaptor.forClass(FanOutRecordsPublisher.RecordSubscription.class);
ArgumentCaptor<FanOutRecordsPublisher.RecordFlow> flowCaptor = ArgumentCaptor ArgumentCaptor<FanOutRecordsPublisher.RecordFlow> flowCaptor = ArgumentCaptor
@ -200,7 +206,8 @@ public class FanOutRecordsPublisherTest {
@Test @Test
public void testResourceNotFoundForShard() { public void testResourceNotFoundForShard() {
FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN); FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN,
VALIDATE_RECORD_SHARD_MATCHING);
ArgumentCaptor<FanOutRecordsPublisher.RecordFlow> flowCaptor = ArgumentCaptor ArgumentCaptor<FanOutRecordsPublisher.RecordFlow> flowCaptor = ArgumentCaptor
.forClass(FanOutRecordsPublisher.RecordFlow.class); .forClass(FanOutRecordsPublisher.RecordFlow.class);
@ -224,7 +231,8 @@ public class FanOutRecordsPublisherTest {
@Test @Test
public void testContinuesAfterSequence() { public void testContinuesAfterSequence() {
FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN); FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN,
VALIDATE_RECORD_SHARD_MATCHING);
ArgumentCaptor<FanOutRecordsPublisher.RecordSubscription> captor = ArgumentCaptor ArgumentCaptor<FanOutRecordsPublisher.RecordSubscription> captor = ArgumentCaptor
.forClass(FanOutRecordsPublisher.RecordSubscription.class); .forClass(FanOutRecordsPublisher.RecordSubscription.class);
@ -297,6 +305,66 @@ public class FanOutRecordsPublisherTest {
} }
@Test
public void mismatchedShardIdTest() {
FanOutRecordsPublisher source = new FanOutRecordsPublisher(kinesisClient, SHARD_ID, CONSUMER_ARN, true);
ArgumentCaptor<FanOutRecordsPublisher.RecordSubscription> captor = ArgumentCaptor
.forClass(FanOutRecordsPublisher.RecordSubscription.class);
ArgumentCaptor<FanOutRecordsPublisher.RecordFlow> flowCaptor = ArgumentCaptor
.forClass(FanOutRecordsPublisher.RecordFlow.class);
doNothing().when(publisher).subscribe(captor.capture());
source.start(ExtendedSequenceNumber.LATEST,
InitialPositionInStreamExtended.newInitialPosition(InitialPositionInStream.LATEST));
List<Throwable> errorsHandled = new ArrayList<>();
List<ProcessRecordsInput> inputsReceived = new ArrayList<>();
source.subscribe(new Subscriber<ProcessRecordsInput>() {
Subscription subscription;
@Override
public void onSubscribe(Subscription s) {
subscription = s;
subscription.request(1);
}
@Override
public void onNext(ProcessRecordsInput input) {
inputsReceived.add(input);
}
@Override
public void onError(Throwable t) {
errorsHandled.add(t);
}
@Override
public void onComplete() {
fail("OnComplete called when not expected");
}
});
verify(kinesisClient).subscribeToShard(any(SubscribeToShardRequest.class), flowCaptor.capture());
flowCaptor.getValue().onEventStream(publisher);
captor.getValue().onSubscribe(subscription);
List<Record> records = Stream.of(1, 2, 3).map(seq -> makeRecord(seq, seq)).collect(Collectors.toList());
batchEvent = SubscribeToShardEvent.builder().millisBehindLatest(100L).records(records).build();
captor.getValue().onNext(batchEvent);
verify(subscription, times(1)).request(1);
assertThat(inputsReceived.size(), equalTo(0));
assertThat(errorsHandled.size(), equalTo(1));
assertThat(errorsHandled.get(0), instanceOf(IllegalArgumentException.class));
assertThat(errorsHandled.get(0).getMessage(), containsString("Received records destined for different shards"));
}
private void verifyRecords(List<KinesisClientRecord> clientRecordsList, List<KinesisClientRecordMatcher> matchers) { private void verifyRecords(List<KinesisClientRecord> clientRecordsList, List<KinesisClientRecordMatcher> matchers) {
assertThat(clientRecordsList.size(), equalTo(matchers.size())); assertThat(clientRecordsList.size(), equalTo(matchers.size()));
for (int i = 0; i < clientRecordsList.size(); ++i) { for (int i = 0; i < clientRecordsList.size(); ++i) {
@ -333,9 +401,17 @@ public class FanOutRecordsPublisherTest {
} }
private Record makeRecord(int sequenceNumber) { private Record makeRecord(int sequenceNumber) {
return makeRecord(sequenceNumber, 1);
}
private Record makeRecord(int sequenceNumber, int shardId) {
BigInteger version = BigInteger.valueOf(2).shiftLeft(184);
BigInteger shard = BigInteger.valueOf(shardId).shiftLeft(4);
BigInteger seq = version.add(shard).add(BigInteger.valueOf(sequenceNumber));
SdkBytes buffer = SdkBytes.fromByteArray(new byte[] { 1, 2, 3 }); SdkBytes buffer = SdkBytes.fromByteArray(new byte[] { 1, 2, 3 });
return Record.builder().data(buffer).approximateArrivalTimestamp(Instant.now()) return Record.builder().data(buffer).approximateArrivalTimestamp(Instant.now())
.sequenceNumber(Integer.toString(sequenceNumber)).partitionKey("A").build(); .sequenceNumber(seq.toString()).partitionKey("A").build();
} }
private static class KinesisClientRecordMatcher extends TypeSafeDiagnosingMatcher<KinesisClientRecord> { private static class KinesisClientRecordMatcher extends TypeSafeDiagnosingMatcher<KinesisClientRecord> {