Internally construct StreamARN using STS (#1087)

Co-authored-by: Yu Zeng <yuzen@amazon.com>
Co-authored-by: stair <123031771+stair-aws@users.noreply.github.com>
This commit is contained in:
Yu Zeng 2023-04-18 12:26:31 -07:00 committed by GitHub
parent fc52976c3d
commit 52e34dbe8f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 435 additions and 40 deletions

View file

@ -21,7 +21,7 @@
<parent>
<artifactId>amazon-kinesis-client-pom</artifactId>
<groupId>software.amazon.kinesis</groupId>
<version>2.4.9-SNAPSHOT</version>
<version>2.5.0-SNAPSHOT</version>
</parent>
<modelVersion>4.0.0</modelVersion>

View file

@ -22,7 +22,7 @@
<parent>
<groupId>software.amazon.kinesis</groupId>
<artifactId>amazon-kinesis-client-pom</artifactId>
<version>2.4.9-SNAPSHOT</version>
<version>2.5.0-SNAPSHOT</version>
</parent>
<artifactId>amazon-kinesis-client</artifactId>
@ -75,6 +75,11 @@
<artifactId>netty-nio-client</artifactId>
<version>${awssdk.version}</version>
</dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>sts</artifactId>
<version>${awssdk.version}</version>
</dependency>
<dependency>
<groupId>software.amazon.glue</groupId>
<artifactId>schema-registry-serde</artifactId>
@ -134,6 +139,20 @@
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.powermock</groupId>
<artifactId>powermock-module-junit4</artifactId>
<version>1.7.4</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.powermock</groupId>
<artifactId>powermock-api-mockito</artifactId>
<version>1.7.4</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.hamcrest</groupId>
<artifactId>hamcrest-all</artifactId>

View file

@ -142,7 +142,7 @@ public class ConfigsBuilder {
@NonNull KinesisAsyncClient kinesisClient, @NonNull DynamoDbAsyncClient dynamoDBClient,
@NonNull CloudWatchAsyncClient cloudWatchClient, @NonNull String workerIdentifier,
@NonNull ShardRecordProcessorFactory shardRecordProcessorFactory) {
this(new SingleStreamTracker(streamName),
this(new SingleStreamTracker(streamName, kinesisClient.serviceClientConfiguration().region()),
applicationName,
kinesisClient,
dynamoDBClient,

View file

@ -0,0 +1,86 @@
package software.amazon.kinesis.common;
import com.google.common.base.Joiner;
import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import software.amazon.awssdk.arns.Arn;
import software.amazon.awssdk.awscore.exception.AwsServiceException;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.http.urlconnection.UrlConnectionHttpClient;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.sts.StsClient;
import software.amazon.awssdk.services.sts.model.GetCallerIdentityResponse;
import java.util.HashMap;
import java.util.Optional;
@Slf4j
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public final class StreamARNUtil {
private static final HashMap<String, Arn> streamARNCache = new HashMap<>();
/**
* This static method attempts to retrieve the stream ARN using the stream name, region, and accountId returned by STS
* It is designed to fail gracefully, returning Optional.empty() if any errors occur.
* @param streamName: stream name
* @param kinesisRegion: kinesisRegion is a nullable parameter used to construct the stream arn
* @return
*/
public static Optional<Arn> getStreamARN(String streamName, Region kinesisRegion) {
return getStreamARN(streamName, kinesisRegion, Optional.empty());
}
public static Optional<Arn> getStreamARN(String streamName, Region kinesisRegion, @NonNull Optional<String> accountId) {
if (kinesisRegion == null || StringUtils.isEmpty(kinesisRegion.toString())) {
return Optional.empty();
}
// Consult the cache before contacting STS
String key = getCacheKey(streamName, kinesisRegion, accountId);
if (streamARNCache.containsKey(key)) {
return Optional.of(streamARNCache.get(key));
}
Optional<Arn> stsCallerArn = getStsCallerArn();
if (!stsCallerArn.isPresent() || !stsCallerArn.get().accountId().isPresent()) {
return Optional.empty();
}
accountId = accountId.isPresent() ? accountId : stsCallerArn.get().accountId();
Arn kinesisStreamArn = Arn.builder()
.partition(stsCallerArn.get().partition())
.service("kinesis")
.region(kinesisRegion.toString())
.accountId(accountId.get())
.resource("stream/" + streamName)
.build();
// Update the cache
streamARNCache.put(key, kinesisStreamArn);
return Optional.of(kinesisStreamArn);
}
private static Optional<Arn> getStsCallerArn() {
GetCallerIdentityResponse response;
try {
response = getStsClient().getCallerIdentity();
} catch (AwsServiceException | SdkClientException e) {
log.warn("Unable to get sts caller identity to build stream arn", e);
return Optional.empty();
}
return Optional.of(Arn.fromString(response.arn()));
}
private static StsClient getStsClient() {
return StsClient.builder()
.httpClient(UrlConnectionHttpClient.builder().build())
.build();
}
private static String getCacheKey(
String streamName, @NonNull Region kinesisRegion, @NonNull Optional<String> accountId) {
return Joiner.on(":").join(streamName, kinesisRegion.toString(), accountId.orElse(""));
}
}

View file

@ -16,38 +16,39 @@
package software.amazon.kinesis.common;
import com.google.common.base.Joiner;
import lombok.AccessLevel;
import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.NonNull;
import lombok.ToString;
import lombok.experimental.Accessors;
import software.amazon.awssdk.arns.Arn;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.utils.Validate;
import java.util.Optional;
import java.util.regex.Pattern;
@EqualsAndHashCode @Getter @Accessors(fluent = true)
@Builder(access = AccessLevel.PRIVATE)
@EqualsAndHashCode
@Getter
@ToString
@Accessors(fluent = true)
public class StreamIdentifier {
private final Optional<String> accountIdOptional;
@Builder.Default
private final Optional<String> accountIdOptional = Optional.empty();
private final String streamName;
private final Optional<Long> streamCreationEpochOptional;
@Builder.Default
private final Optional<Long> streamCreationEpochOptional = Optional.empty();
@Builder.Default
private final Optional<Arn> streamARNOptional = Optional.empty();
private static final String DELIMITER = ":";
private static final Pattern PATTERN = Pattern.compile(".*" + ":" + ".*" + ":" + "[0-9]*");
private StreamIdentifier(@NonNull String accountId, @NonNull String streamName, @NonNull Long streamCreationEpoch) {
this.accountIdOptional = Optional.of(accountId);
this.streamName = streamName;
this.streamCreationEpochOptional = Optional.of(streamCreationEpoch);
}
private StreamIdentifier(@NonNull String streamName) {
this.accountIdOptional = Optional.empty();
this.streamName = streamName;
this.streamCreationEpochOptional = Optional.empty();
}
private static final Pattern PATTERN = Pattern.compile(".*" + ":" + ".*" + ":" + "[0-9]*" + ":?([a-z]{2}(-gov)?-[a-z]+-\\d{1})?");
/**
* Serialize the current StreamIdentifier instance.
* TODO: Consider appending region info for cross-account consumer support
* @return
*/
public String serialize() {
@ -63,14 +64,35 @@ public class StreamIdentifier {
/**
* Create a multi stream instance for StreamIdentifier from serialized stream identifier.
* The serialized stream identifier should be of the format account:stream:creationepoch
* See the format of a serialized stream identifier at {@link StreamIdentifier#multiStreamInstance(String, Region)}
* @param streamIdentifierSer
* @return StreamIdentifier
*/
public static StreamIdentifier multiStreamInstance(String streamIdentifierSer) {
return multiStreamInstance(streamIdentifierSer, null);
}
/**
* Create a multi stream instance for StreamIdentifier from serialized stream identifier.
* @param streamIdentifierSer The serialized stream identifier should be of the format
* account:stream:creationepoch[:region]
* @param kinesisRegion This nullable region is used to construct the optional StreamARN
* @return StreamIdentifier
*/
public static StreamIdentifier multiStreamInstance(String streamIdentifierSer, Region kinesisRegion) {
if (PATTERN.matcher(streamIdentifierSer).matches()) {
final String[] split = streamIdentifierSer.split(DELIMITER);
return new StreamIdentifier(split[0], split[1], Long.parseLong(split[2]));
final String streamName = split[1];
final Optional<String> accountId = Optional.ofNullable(split[0]);
StreamIdentifierBuilder builder = StreamIdentifier.builder()
.accountIdOptional(accountId)
.streamName(streamName)
.streamCreationEpochOptional(Optional.of(Long.parseLong(split[2])));
final Region region = (split.length == 4) ?
Region.of(split[3]) : // Use the region extracted from the serialized string, which matches the regex pattern
kinesisRegion; // Otherwise just use the provided region
final Optional<Arn> streamARN = StreamARNUtil.getStreamARN(streamName, region, accountId);
return builder.streamARNOptional(streamARN).build();
} else {
throw new IllegalArgumentException("Unable to deserialize StreamIdentifier from " + streamIdentifierSer);
}
@ -82,7 +104,21 @@ public class StreamIdentifier {
* @return StreamIdentifier
*/
public static StreamIdentifier singleStreamInstance(String streamName) {
return singleStreamInstance(streamName, null);
}
/**
* Create a single stream instance for StreamIdentifier from the provided stream name and kinesisRegion.
* This method also constructs the optional StreamARN based on the region info.
* @param streamName
* @param kinesisRegion
* @return StreamIdentifier
*/
public static StreamIdentifier singleStreamInstance(String streamName, Region kinesisRegion) {
Validate.notEmpty(streamName, "StreamName should not be empty");
return new StreamIdentifier(streamName);
return StreamIdentifier.builder()
.streamName(streamName)
.streamARNOptional(StreamARNUtil.getStreamARN(streamName, kinesisRegion))
.build();
}
}

View file

@ -222,11 +222,12 @@ public class KinesisShardDetector implements ShardDetector {
final boolean shouldPropagateResourceNotFoundException) {
ListShardsRequest.Builder builder = KinesisRequestsBuilder.listShardsRequestBuilder();
if (StringUtils.isEmpty(nextToken)) {
builder = builder.streamName(streamIdentifier.streamName()).shardFilter(shardFilter);
} else {
builder = builder.nextToken(nextToken);
}
builder.streamName(streamIdentifier.streamName()).shardFilter(shardFilter);
streamIdentifier.streamARNOptional().ifPresent(arn -> builder.streamARN(arn.toString()));
} else {
builder.nextToken(nextToken);
}
final ListShardsRequest request = builder.build();
log.info("Stream {}: listing shards with list shards request {}", streamIdentifier, request);
@ -308,11 +309,12 @@ public class KinesisShardDetector implements ShardDetector {
@Override
public List<ChildShard> getChildShards(final String shardId) throws InterruptedException, ExecutionException, TimeoutException {
final GetShardIteratorRequest getShardIteratorRequest = KinesisRequestsBuilder.getShardIteratorRequestBuilder()
final GetShardIteratorRequest.Builder requestBuilder = KinesisRequestsBuilder.getShardIteratorRequestBuilder()
.streamName(streamIdentifier.streamName())
.shardIteratorType(ShardIteratorType.LATEST)
.shardId(shardId)
.build();
.shardId(shardId);
streamIdentifier.streamARNOptional().ifPresent(arn -> requestBuilder.streamARN(arn.toString()));
final GetShardIteratorRequest getShardIteratorRequest = requestBuilder.build();
final GetShardIteratorResponse getShardIteratorResponse =
FutureUtils.resolveOrCancelFuture(kinesisClient.getShardIterator(getShardIteratorRequest), kinesisRequestTimeout);

View file

@ -21,6 +21,7 @@ import java.util.List;
import lombok.EqualsAndHashCode;
import lombok.NonNull;
import lombok.ToString;
import software.amazon.awssdk.regions.Region;
import software.amazon.kinesis.common.InitialPositionInStreamExtended;
import software.amazon.kinesis.common.StreamConfig;
import software.amazon.kinesis.common.StreamIdentifier;
@ -48,6 +49,10 @@ public class SingleStreamTracker implements StreamTracker {
this(StreamIdentifier.singleStreamInstance(streamName));
}
public SingleStreamTracker(String streamName, Region region) {
this(StreamIdentifier.singleStreamInstance(streamName, region));
}
public SingleStreamTracker(StreamIdentifier streamIdentifier) {
this(streamIdentifier, DEFAULT_POSITION_IN_STREAM);
}

View file

@ -49,7 +49,7 @@ public class RetrievalConfig {
*/
public static final String KINESIS_CLIENT_LIB_USER_AGENT = "amazon-kinesis-client-library-java";
public static final String KINESIS_CLIENT_LIB_USER_AGENT_VERSION = "2.4.9-SNAPSHOT";
public static final String KINESIS_CLIENT_LIB_USER_AGENT_VERSION = "2.5.0-SNAPSHOT";
/**
* Client used to make calls to Kinesis for records retrieval
@ -120,7 +120,9 @@ public class RetrievalConfig {
public RetrievalConfig(@NonNull KinesisAsyncClient kinesisAsyncClient, @NonNull String streamName,
@NonNull String applicationName) {
this(kinesisAsyncClient, new SingleStreamTracker(streamName), applicationName);
this(kinesisAsyncClient,
new SingleStreamTracker(streamName, kinesisAsyncClient.serviceClientConfiguration().region()),
applicationName);
}
public RetrievalConfig(@NonNull KinesisAsyncClient kinesisAsyncClient, @NonNull StreamTracker streamTracker,

View file

@ -238,12 +238,15 @@ public class KinesisDataFetcher implements DataFetcher {
GetShardIteratorRequest.Builder builder = KinesisRequestsBuilder.getShardIteratorRequestBuilder()
.streamName(streamIdentifier.streamName()).shardId(shardId);
streamIdentifier.streamARNOptional().ifPresent(arn -> builder.streamARN(arn.toString()));
GetShardIteratorRequest request;
if (isIteratorRestart) {
request = IteratorBuilder.reconnectRequest(builder, sequenceNumber, initialPositionInStream).build();
} else {
request = IteratorBuilder.request(builder, sequenceNumber, initialPositionInStream).build();
}
log.debug("[GetShardIterator] Request has parameters {}", request);
// TODO: Check if this metric is fine to be added
final MetricsScope metricsScope = MetricsUtil.createMetricsWithOperation(metricsFactory, OPERATION);
@ -315,9 +318,11 @@ public class KinesisDataFetcher implements DataFetcher {
}
@Override
public GetRecordsRequest getGetRecordsRequest(String nextIterator) {
return KinesisRequestsBuilder.getRecordsRequestBuilder().shardIterator(nextIterator)
.limit(maxRecords).build();
public GetRecordsRequest getGetRecordsRequest(String nextIterator) {
GetRecordsRequest.Builder builder = KinesisRequestsBuilder.getRecordsRequestBuilder()
.shardIterator(nextIterator).limit(maxRecords);
streamIdentifier.streamARNOptional().ifPresent(arn -> builder.streamARN(arn.toString()));
return builder.build();
}
@Override

View file

@ -33,12 +33,12 @@ import software.amazon.kinesis.processor.MultiStreamTracker;
import software.amazon.kinesis.processor.ShardRecordProcessorFactory;
import software.amazon.kinesis.processor.SingleStreamTracker;
import software.amazon.kinesis.processor.StreamTracker;
import software.amazon.kinesis.utils.MockObjectHelper;
@RunWith(MockitoJUnitRunner.class)
public class ConfigsBuilderTest {
@Mock
private KinesisAsyncClient mockKinesisClient;
private final KinesisAsyncClient mockKinesisClient = MockObjectHelper.createKinesisClient();
@Mock
private DynamoDbAsyncClient mockDynamoClient;

View file

@ -0,0 +1,113 @@
package software.amazon.kinesis.common;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.powermock.api.mockito.PowerMockito;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import software.amazon.awssdk.arns.Arn;
import software.amazon.awssdk.awscore.exception.AwsServiceException;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.sts.StsClient;
import software.amazon.awssdk.services.sts.model.GetCallerIdentityResponse;
import java.util.Optional;
import java.util.function.Supplier;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@RunWith(PowerMockRunner.class)
@PrepareForTest({ StreamARNUtil.class })
public class StreamARNUtilTest {
private static final String STS_RESPONSE_ARN_FORMAT = "arn:aws:sts::%s:assumed-role/Admin/alias";
private static final String KINESIS_STREAM_ARN_FORMAT = "arn:aws:kinesis:us-east-1:%s:stream/%s";
// To prevent clashes in the stream arn cache with identical names,
// we're using the test name as the stream name (key)
private static final Supplier<String> streamNameProvider = () -> Thread.currentThread().getStackTrace()[2].getMethodName();
@Mock
private StsClient mockStsClient;
@Before
public void setUp() throws Exception {
MockitoAnnotations.initMocks(this);
PowerMockito.spy(StreamARNUtil.class);
PowerMockito.doReturn(mockStsClient).when(StreamARNUtil.class, "getStsClient");
}
@Test
public void testGetStreamARNHappyCase() {
String streamName = streamNameProvider.get();
String accountId = "123456789012";
when(mockStsClient.getCallerIdentity())
.thenReturn(GetCallerIdentityResponse.builder().arn(String.format(STS_RESPONSE_ARN_FORMAT, accountId)).build());
Optional<Arn> actualStreamARNOptional = StreamARNUtil.getStreamARN(streamName, Region.US_EAST_1);
String expectedStreamARN = String.format(KINESIS_STREAM_ARN_FORMAT, accountId, streamName);
verify(mockStsClient, times(1)).getCallerIdentity();
assertTrue(actualStreamARNOptional.isPresent());
assertEquals(expectedStreamARN, actualStreamARNOptional.get().toString());
}
@Test
public void testGetStreamARNFromCache() {
String streamName = streamNameProvider.get();
String accountId = "123456789012";
when(mockStsClient.getCallerIdentity())
.thenReturn(GetCallerIdentityResponse.builder().arn(String.format(STS_RESPONSE_ARN_FORMAT, accountId)).build());
Optional<Arn> actualStreamARNOptional1 = StreamARNUtil.getStreamARN(streamName, Region.US_EAST_1);
Optional<Arn> actualStreamARNOptional2 = StreamARNUtil.getStreamARN(streamName, Region.US_EAST_1);
String expectedStreamARN = String.format(KINESIS_STREAM_ARN_FORMAT, accountId, streamName);
// Since the second ARN is obtained from the cache, hence there's only one sts call
verify(mockStsClient, times(1)).getCallerIdentity();
assertEquals(expectedStreamARN, actualStreamARNOptional1.get().toString());
assertEquals(actualStreamARNOptional1, actualStreamARNOptional2);
}
@Test
public void testGetStreamARNReturnsEmptyOnSTSError() {
// Optional.empty() is expected when there is an error with the STS call and STS returns empty Arn
String streamName = streamNameProvider.get();
when(mockStsClient.getCallerIdentity())
.thenThrow(AwsServiceException.builder().message("testAwsServiceException").build())
.thenThrow(SdkClientException.builder().message("testSdkClientException").build());
assertFalse(StreamARNUtil.getStreamARN(streamName, Region.US_EAST_1).isPresent());
assertFalse(StreamARNUtil.getStreamARN(streamName, Region.US_EAST_1).isPresent());
}
@Test
public void testGetStreamARNReturnsEmptyOnInvalidKinesisRegion() {
// Optional.empty() is expected when kinesis region is not set correctly
String streamName = streamNameProvider.get();
Optional<Arn> actualStreamARNOptional = StreamARNUtil.getStreamARN(streamName, null);
verify(mockStsClient, times(0)).getCallerIdentity();
assertFalse(actualStreamARNOptional.isPresent());
}
@Test
public void testGetStreamARNWithProvidedAccountIDAndIgnoredSTSResult() throws Exception {
// If the account id is provided in the StreamIdentifier, it will override the result (account id) returned by sts
String streamName = streamNameProvider.get();
String stsAccountId = "111111111111";
String providedAccountId = "222222222222";
when(mockStsClient.getCallerIdentity())
.thenReturn(GetCallerIdentityResponse.builder().arn(String.format(STS_RESPONSE_ARN_FORMAT, stsAccountId)).build());
Optional<Arn> actualStreamARNOptional = StreamARNUtil.getStreamARN(streamName, Region.US_EAST_1, Optional.of(providedAccountId));
String expectedStreamARN = String.format(KINESIS_STREAM_ARN_FORMAT, providedAccountId, streamName);
verify(mockStsClient, times(1)).getCallerIdentity();
assertTrue(actualStreamARNOptional.isPresent());
assertEquals(expectedStreamARN, actualStreamARNOptional.get().toString());
}
}

View file

@ -0,0 +1,94 @@
package software.amazon.kinesis.common;
import com.google.common.base.Joiner;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import software.amazon.awssdk.arns.Arn;
import software.amazon.awssdk.regions.Region;
import java.util.Optional;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.when;
import static org.powermock.api.mockito.PowerMockito.mockStatic;
@RunWith(PowerMockRunner.class)
@PrepareForTest(StreamARNUtil.class)
public class StreamIdentifierTest {
private static final String streamName = "streamName";
private static final Region kinesisRegion = Region.US_WEST_1;
private static final String accountId = "111111111111";
private static final String epoch = "1680616058";
@Test
public void testSingleStreamInstanceWithName() {
StreamIdentifier actualStreamIdentifier = StreamIdentifier.singleStreamInstance(streamName);
Assert.assertFalse(actualStreamIdentifier.streamCreationEpochOptional().isPresent());
Assert.assertFalse(actualStreamIdentifier.accountIdOptional().isPresent());
Assert.assertFalse(actualStreamIdentifier.streamARNOptional().isPresent());
Assert.assertEquals(streamName, actualStreamIdentifier.streamName());
}
@Test
public void testSingleStreamInstanceWithNameAndRegion() {
Optional<Arn> arn = Optional.of(Arn.builder().partition("aws").service("kinesis")
.region(kinesisRegion.toString()).accountId("123").resource("stream/" + streamName).build());
mockStatic(StreamARNUtil.class);
when(StreamARNUtil.getStreamARN(eq(streamName), eq(kinesisRegion))).thenReturn(arn);
StreamIdentifier actualStreamIdentifier = StreamIdentifier.singleStreamInstance(streamName, kinesisRegion);
Assert.assertFalse(actualStreamIdentifier.streamCreationEpochOptional().isPresent());
Assert.assertFalse(actualStreamIdentifier.accountIdOptional().isPresent());
Assert.assertTrue(actualStreamIdentifier.streamARNOptional().isPresent());
Assert.assertEquals(arn, actualStreamIdentifier.streamARNOptional());
}
@Test
public void testMultiStreamInstanceWithIdentifierSerialization() {
String epoch = "1680616058";
Optional<Arn> arn = Optional.ofNullable(Arn.builder().partition("aws").service("kinesis")
.accountId(accountId).region(kinesisRegion.toString()).resource("stream/" + streamName).build());
mockStatic(StreamARNUtil.class);
when(StreamARNUtil.getStreamARN(eq(streamName), any(), any())).thenReturn(arn);
StreamIdentifier actualStreamIdentifier = StreamIdentifier.multiStreamInstance(
Joiner.on(":").join(accountId, streamName, epoch, kinesisRegion));
assertActualStreamIdentifierExpected(arn, actualStreamIdentifier);
}
@Test
public void testMultiStreamInstanceWithRegionSerialized() {
Region serializedRegion = Region.US_GOV_EAST_1;
Optional<Arn> arn = Optional.ofNullable(Arn.builder().partition("aws").service("kinesis")
.accountId(accountId).region(serializedRegion.toString()).resource("stream/" + streamName).build());
mockStatic(StreamARNUtil.class);
when(StreamARNUtil.getStreamARN(eq(streamName), eq(serializedRegion), any())).thenReturn(arn);
StreamIdentifier actualStreamIdentifier = StreamIdentifier.multiStreamInstance(
Joiner.on(":").join(accountId, streamName, epoch, serializedRegion), kinesisRegion);
assertActualStreamIdentifierExpected(arn, actualStreamIdentifier);
}
@Test
public void testMultiStreamInstanceWithoutRegionSerialized() {
Optional<Arn> arn = Optional.ofNullable(Arn.builder().partition("aws").service("kinesis")
.accountId(accountId).region(kinesisRegion.toString()).resource("stream/" + streamName).build());
mockStatic(StreamARNUtil.class);
when(StreamARNUtil.getStreamARN(eq(streamName), eq(kinesisRegion), any())).thenReturn(arn);
StreamIdentifier actualStreamIdentifier = StreamIdentifier.multiStreamInstance(
Joiner.on(":").join(accountId, streamName, epoch), kinesisRegion);
assertActualStreamIdentifierExpected(arn, actualStreamIdentifier);
}
private void assertActualStreamIdentifierExpected(Optional<Arn> expectedARN, StreamIdentifier actual) {
Assert.assertTrue(actual.streamCreationEpochOptional().isPresent());
Assert.assertTrue(actual.accountIdOptional().isPresent());
Assert.assertTrue(actual.streamARNOptional().isPresent());
Assert.assertEquals(expectedARN, actual.streamARNOptional());
}
}

View file

@ -112,6 +112,7 @@ import software.amazon.kinesis.retrieval.RecordsPublisher;
import software.amazon.kinesis.retrieval.RetrievalConfig;
import software.amazon.kinesis.retrieval.RetrievalFactory;
import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber;
import software.amazon.kinesis.utils.MockObjectHelper;
/**
*
@ -137,7 +138,6 @@ public class SchedulerTest {
private ProcessorConfig processorConfig;
private RetrievalConfig retrievalConfig;
@Mock
private KinesisAsyncClient kinesisClient;
@Mock
private DynamoDbAsyncClient dynamoDBClient;
@ -180,6 +180,7 @@ public class SchedulerTest {
lifecycleConfig = new LifecycleConfig();
metricsConfig = new MetricsConfig(cloudWatchClient, namespace);
processorConfig = new ProcessorConfig(shardRecordProcessorFactory);
kinesisClient = MockObjectHelper.createKinesisClient();
retrievalConfig = new RetrievalConfig(kinesisClient, streamName, applicationName)
.retrievalFactory(retrievalFactory);
when(leaseCoordinator.leaseRefresher()).thenReturn(dynamoDBLeaseRefresher);

View file

@ -24,13 +24,13 @@ import software.amazon.kinesis.common.StreamConfig;
import software.amazon.kinesis.processor.MultiStreamTracker;
import software.amazon.kinesis.processor.SingleStreamTracker;
import software.amazon.kinesis.processor.StreamTracker;
import software.amazon.kinesis.utils.MockObjectHelper;
@RunWith(MockitoJUnitRunner.class)
public class RetrievalConfigTest {
private static final String APPLICATION_NAME = RetrievalConfigTest.class.getSimpleName();
@Mock
private KinesisAsyncClient mockKinesisClient;
@Mock
@ -38,6 +38,7 @@ public class RetrievalConfigTest {
@Before
public void setUp() {
mockKinesisClient = MockObjectHelper.createKinesisClient(true);
when(mockMultiStreamTracker.isMultiStream()).thenReturn(true);
}

View file

@ -0,0 +1,31 @@
package software.amazon.kinesis.utils;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.kinesis.KinesisAsyncClient;
import software.amazon.awssdk.services.kinesis.KinesisServiceClientConfiguration;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public final class MockObjectHelper {
public static KinesisAsyncClient createKinesisClient() {
return createKinesisClient(Region.US_EAST_1);
}
/**
* @param isRegionDummy a boolean to determine whether to use a null value for the Kinesis client's region.
* @return
*/
public static KinesisAsyncClient createKinesisClient(boolean isRegionDummy) {
return isRegionDummy ? createKinesisClient(null) : createKinesisClient();
}
public static KinesisAsyncClient createKinesisClient(Region region) {
KinesisAsyncClient kinesisClient = mock(KinesisAsyncClient.class);
when(kinesisClient.serviceClientConfiguration()).
thenReturn(KinesisServiceClientConfiguration.builder().region(region).build());
return kinesisClient;
}
}

View file

@ -22,7 +22,7 @@
<artifactId>amazon-kinesis-client-pom</artifactId>
<packaging>pom</packaging>
<name>Amazon Kinesis Client Library</name>
<version>2.4.9-SNAPSHOT</version>
<version>2.5.0-SNAPSHOT</version>
<description>The Amazon Kinesis Client Library for Java enables Java developers to easily consume and process data
from Amazon Kinesis.
</description>