diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/common/StreamARNUtil.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/common/StreamARNUtil.java index 03652a5b..667bf820 100644 --- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/common/StreamARNUtil.java +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/common/StreamARNUtil.java @@ -2,7 +2,6 @@ package software.amazon.kinesis.common; import lombok.AccessLevel; import lombok.NoArgsConstructor; -import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import software.amazon.awssdk.arns.Arn; import software.amazon.awssdk.awscore.exception.AwsServiceException; @@ -26,7 +25,11 @@ public final class StreamARNUtil { try (final SdkHttpClient httpClient = UrlConnectionHttpClient.builder().build(); final StsClient stsClient = StsClient.builder().httpClient(httpClient).build()) { final GetCallerIdentityResponse response = stsClient.getCallerIdentity(); - return Arn.fromString(response.arn()); + final Arn arn = Arn.fromString(response.arn()); + + // guarantee the cached ARN will never have an empty accountId + arn.accountId().orElseThrow(() -> new IllegalStateException("AccountId is not present on " + arn)); + return arn; } catch (AwsServiceException | SdkClientException e) { log.warn("Unable to get sts caller identity to build stream arn", e); return null; @@ -34,17 +37,18 @@ public final class StreamARNUtil { }); /** - * This static method attempts to retrieve the stream ARN using the stream name, region, and accountId returned by STS + * Retrieves the stream ARN using the stream name, region, and accountId returned by STS. * It is designed to fail gracefully, returning Optional.empty() if any errors occur. * * @param streamName stream name - * @param kinesisRegion kinesisRegion is a nullable parameter used to construct the stream arn + * @param kinesisRegion Kinesis client endpoint, and also where the stream(s) to be + * processed are located. A null guarantees an empty ARN. */ public static Optional getStreamARN(String streamName, Region kinesisRegion) { - return getStreamARN(streamName, kinesisRegion, Optional.empty()); + return getStreamARN(streamName, kinesisRegion, null); } - public static Optional getStreamARN(String streamName, Region kinesisRegion, @NonNull Optional accountId) { + public static Optional getStreamARN(String streamName, Region kinesisRegion, String accountId) { if (kinesisRegion == null) { return Optional.empty(); } @@ -55,7 +59,7 @@ public final class StreamARNUtil { } // the provided accountId takes precedence - final String chosenAccountId = accountId.orElse(identityArn.accountId().orElse("")); + final String chosenAccountId = (accountId != null) ? accountId : identityArn.accountId().get(); return Optional.of(Arn.builder() .partition(identityArn.partition()) .service("kinesis") diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/common/StreamIdentifier.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/common/StreamIdentifier.java index 1a81f606..60195d4f 100644 --- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/common/StreamIdentifier.java +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/common/StreamIdentifier.java @@ -48,14 +48,10 @@ public class StreamIdentifier { /** * Pattern for a serialized {@link StreamIdentifier}. The valid format is - * {@code ::[:]} where - * {@code region} is the id representation of a {@link Region} and is - * optional. + * {@code ::}. */ private static final Pattern STREAM_IDENTIFIER_PATTERN = Pattern.compile( - // `?::` has two parts: `?:` starts a non-capturing group, and - // `:` is the first character in the group (i.e., ":") - "(?[0-9]+):(?[^:]+):(?[0-9]+)(?::(?[-a-z0-9]+))?"); + "(?[0-9]+):(?[^:]+):(?[0-9]+)"); /** * Pattern for a stream ARN. The valid format is @@ -96,7 +92,6 @@ public class StreamIdentifier { final StringBuilder sb = new StringBuilder(accountIdOptional.get()).append(delimiter) .append(streamName).append(delimiter); streamCreationEpochOptional.ifPresent(sb::append); - streamARNOptional.flatMap(Arn::region).ifPresent(region -> sb.append(delimiter).append(region)); return sb.toString(); } @@ -121,8 +116,10 @@ public class StreamIdentifier { * Create a multi stream instance for StreamIdentifier from serialized stream identifier. * * @param serializationOrArn serialized {@link StreamIdentifier} or AWS ARN of a Kinesis stream - * @param kinesisRegion This nullable region is used to construct the optional StreamARN + * @param kinesisRegion Kinesis client endpoint, and also where the stream(s) to be + * processed are located. A null will default to the caller's region. * + * @see #multiStreamInstance(String) * @see #serialize() */ public static StreamIdentifier multiStreamInstance(String serializationOrArn, Region kinesisRegion) { @@ -142,6 +139,8 @@ public class StreamIdentifier { * Create a single stream instance for StreamIdentifier from stream name. * * @param streamNameOrArn stream name or AWS ARN of a Kinesis stream + * + * @see #singleStreamInstance(String, Region) */ public static StreamIdentifier singleStreamInstance(String streamNameOrArn) { return singleStreamInstance(streamNameOrArn, null); @@ -152,7 +151,10 @@ public class StreamIdentifier { * This method also constructs the optional StreamARN based on the region info. * * @param streamNameOrArn stream name or AWS ARN of a Kinesis stream - * @param kinesisRegion (optional) region used to construct the ARN + * @param kinesisRegion Kinesis client endpoint, and also where the stream(s) to be + * processed are located. A null will default to the caller's region. + * + * @see #singleStreamInstance(String) */ public static StreamIdentifier singleStreamInstance(String streamNameOrArn, Region kinesisRegion) { Validate.notEmpty(streamNameOrArn, "StreamName should not be empty"); @@ -172,7 +174,8 @@ public class StreamIdentifier { * Deserializes a StreamIdentifier from {@link #STREAM_IDENTIFIER_PATTERN}. * * @param input input string (e.g., ARN, serialized instance) to convert into an instance - * @param kinesisRegion (optional) region used to construct the ARN + * @param kinesisRegion Kinesis client endpoint, and also where the stream(s) to be + * processed are located. A null will default to the caller's region. * @return a StreamIdentifier instance if the pattern matched, otherwise null */ private static StreamIdentifier fromSerialization(final String input, final Region kinesisRegion) { @@ -185,31 +188,38 @@ public class StreamIdentifier { * Constructs a StreamIdentifier from {@link #STREAM_ARN_PATTERN}. * * @param input input string (e.g., ARN, serialized instance) to convert into an instance - * @param kinesisRegion (optional) region used to construct the ARN + * @param kinesisRegion Kinesis client endpoint, and also where the stream(s) to be + * processed are located. A null will default to the caller's region. * @return a StreamIdentifier instance if the pattern matched, otherwise null */ private static StreamIdentifier fromArn(final String input, final Region kinesisRegion) { final Matcher matcher = STREAM_ARN_PATTERN.matcher(input); - return matcher.matches() - ? toStreamIdentifier(matcher, "", kinesisRegion) : null; + if (matcher.matches()) { + final String arnRegion = matcher.group("region"); + final Region region = (arnRegion != null) ? Region.of(arnRegion) : kinesisRegion; + if ((kinesisRegion != null) && (region != kinesisRegion)) { + throw new IllegalArgumentException(String.format( + "Cannot create StreamIdentifier for a region other than %s: %s", kinesisRegion, input)); + } + return toStreamIdentifier(matcher, "", region); + } + return null; } private static StreamIdentifier toStreamIdentifier(final Matcher matcher, final String matchedEpoch, final Region kinesisRegion) { - final Optional accountId = Optional.of(matcher.group("accountId")); + final String accountId = matcher.group("accountId"); final String streamName = matcher.group("streamName"); final Optional creationEpoch = matchedEpoch.isEmpty() ? Optional.empty() : Optional.of(Long.valueOf(matchedEpoch)); - final String matchedRegion = matcher.group("region"); - final Region region = (matchedRegion != null) ? Region.of(matchedRegion) : kinesisRegion; - final Optional arn = StreamARNUtil.getStreamARN(streamName, region, accountId); + final Optional arn = StreamARNUtil.getStreamARN(streamName, kinesisRegion, accountId); if (!creationEpoch.isPresent() && !arn.isPresent()) { throw new IllegalArgumentException("Cannot create StreamIdentifier if missing both ARN and creation epoch"); } return StreamIdentifier.builder() - .accountIdOptional(accountId) + .accountIdOptional(Optional.of(accountId)) .streamName(streamName) .streamCreationEpochOptional(creationEpoch) .streamARNOptional(arn) diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/common/StreamARNUtilTest.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/common/StreamARNUtilTest.java index 595710c6..aef3974b 100644 --- a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/common/StreamARNUtilTest.java +++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/common/StreamARNUtilTest.java @@ -17,6 +17,7 @@ import software.amazon.awssdk.http.urlconnection.UrlConnectionHttpClient; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.sts.StsClient; import software.amazon.awssdk.services.sts.StsClientBuilder; +import software.amazon.awssdk.services.sts.model.GetCallerIdentityResponse; import java.lang.reflect.Field; import java.util.Optional; @@ -60,7 +61,8 @@ public class StreamARNUtilTest { public void setUp() throws Exception { MockitoAnnotations.initMocks(this); - setUpSupplierCache(); + spySupplierCache = spy(ORIGINAL_CACHE); + setUpSupplierCache(spySupplierCache); final Arn defaultArn = toArn(STS_RESPONSE_ARN_FORMAT, ACCOUNT_ID); doReturn(defaultArn).when(spySupplierCache).get(); @@ -84,12 +86,10 @@ public class StreamARNUtilTest { * one-and-done cache behavior, provide each test precise control over * return values, and enable the ability to verify interactions via Mockito. */ - private void setUpSupplierCache() throws Exception { - spySupplierCache = spy(ORIGINAL_CACHE); - + static void setUpSupplierCache(final SupplierCache cache) throws Exception { final Field f = StreamARNUtil.class.getDeclaredField("CALLER_IDENTITY_ARN"); f.setAccessible(true); - f.set(null, spySupplierCache); + f.set(null, cache); f.setAccessible(false); } @@ -124,6 +124,24 @@ public class StreamARNUtilTest { verify(spySupplierCache, times(2)).get(); } + @Test(expected = IllegalStateException.class) + public void testStsResponseWithoutAccountId() { + setUpSts(); + + final Arn arnWithoutAccountId = toArn(STS_RESPONSE_ARN_FORMAT, ""); + assertEquals(Optional.empty(), arnWithoutAccountId.accountId()); + + final GetCallerIdentityResponse identityResponse = GetCallerIdentityResponse.builder() + .arn(arnWithoutAccountId.toString()).build(); + when(mockStsClient.getCallerIdentity()).thenReturn(identityResponse); + + try { + StreamARNUtil.getStreamARN(STREAM_NAME, Region.US_EAST_1); + } finally { + verify(mockStsClient).getCallerIdentity(); + } + } + @Test public void testGetStreamARNReturnsEmptyOnInvalidKinesisRegion() { // Optional.empty() is expected when kinesis region is not set correctly @@ -143,7 +161,7 @@ public class StreamARNUtilTest { when(spySupplierCache.get()).thenReturn(cachedArn); final Optional actualStreamARNOptional = StreamARNUtil.getStreamARN(STREAM_NAME, Region.US_EAST_1, - Optional.of(providedAccountId)); + providedAccountId); final Arn expectedStreamARN = toArn(KINESIS_STREAM_ARN_FORMAT, providedAccountId, STREAM_NAME); verify(spySupplierCache).get(); @@ -152,17 +170,6 @@ public class StreamARNUtilTest { assertEquals(expectedStreamARN, actualStreamARNOptional.get()); } - @Test - public void testNoAccountId() { - final Arn arnWithoutAccountId = toArn(STS_RESPONSE_ARN_FORMAT, ""); - when(spySupplierCache.get()).thenReturn(arnWithoutAccountId); - assertEquals(Optional.empty(), arnWithoutAccountId.accountId()); - - final Optional actualArn = StreamARNUtil.getStreamARN(STREAM_NAME, Region.US_EAST_1); - assertTrue(actualArn.isPresent()); - assertEquals(Optional.empty(), actualArn.get().accountId()); - } - private static Optional getStreamArn() { final Optional actualArn = StreamARNUtil.getStreamARN(STREAM_NAME, Region.US_EAST_1); final Arn expectedArn = toArn(KINESIS_STREAM_ARN_FORMAT, ACCOUNT_ID, STREAM_NAME); diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/common/StreamIdentifierTest.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/common/StreamIdentifierTest.java index 115cab03..b3f4991b 100644 --- a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/common/StreamIdentifierTest.java +++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/common/StreamIdentifierTest.java @@ -1,7 +1,7 @@ package software.amazon.kinesis.common; import org.junit.Assert; -import org.junit.Before; +import org.junit.BeforeClass; import org.junit.Test; import org.junit.runner.RunWith; import org.powermock.core.classloader.annotations.PrepareForTest; @@ -14,9 +14,7 @@ import java.util.Optional; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNotNull; -import static org.mockito.Matchers.any; -import static org.mockito.Matchers.anyString; +import static org.junit.Assert.assertNotEquals; import static org.mockito.Mockito.when; import static org.powermock.api.mockito.PowerMockito.mockStatic; import static org.powermock.api.mockito.PowerMockito.verifyStatic; @@ -32,14 +30,9 @@ public class StreamIdentifierTest { private static final Arn DEFAULT_ARN = toArn(KINESIS_REGION); - @Before - public void setUp() { - mockStatic(StreamARNUtil.class); - - when(getStreamARN(anyString(), any(Region.class))).thenReturn(Optional.empty()); - when(getStreamARN(STREAM_NAME, KINESIS_REGION)).thenReturn(Optional.of(DEFAULT_ARN)); - when(getStreamARN(STREAM_NAME, KINESIS_REGION, Optional.of(TEST_ACCOUNT_ID))) - .thenReturn(Optional.of(DEFAULT_ARN)); + @BeforeClass + public static void setUpBeforeClass() throws Exception { + StreamARNUtilTest.setUpSupplierCache(new SupplierCache<>(() -> DEFAULT_ARN)); } /** @@ -47,16 +40,9 @@ public class StreamIdentifierTest { */ @Test public void testMultiStreamDeserializationSuccess() { - for (final String pattern : Arrays.asList( - // arn examples - toArn(KINESIS_REGION).toString(), - // serialization examples - "123456789012:stream-name:123", - "123456789012:stream-name:123:" + Region.US_ISOB_EAST_1 - )) { - final StreamIdentifier si = StreamIdentifier.multiStreamInstance(pattern); - assertNotNull(si); - } + final StreamIdentifier siSerialized = StreamIdentifier.multiStreamInstance(serialize()); + assertEquals(Optional.of(EPOCH), siSerialized.streamCreationEpochOptional()); + assertActualStreamIdentifierExpected(null, siSerialized); } /** @@ -73,17 +59,12 @@ public class StreamIdentifierTest { "arn:aws:kinesis:region:123456789012:stream/", // missing stream-name // serialization examples ":stream-name:123", // missing account id - "123456789012:stream-name", // missing delimiter before creation epoch - "accountId:stream-name:123", // non-numeric account id // "123456789:stream-name:123", // account id not 12 digits "123456789abc:stream-name:123", // 12char alphanumeric account id "123456789012::123", // missing stream name + "123456789012:stream-name", // missing delimiter and creation epoch "123456789012:stream-name:", // missing creation epoch - "123456789012:stream-name::", // missing creation epoch; ':' for optional region yet missing region - "123456789012:stream-name::us-east-1", // missing creation epoch "123456789012:stream-name:abc", // non-numeric creation epoch - "123456789012:stream-name:abc:", // non-numeric creation epoch with ':' yet missing region - "123456789012:stream-name:123:", // ':' for optional region yet missing region "" )) { try { @@ -102,18 +83,22 @@ public class StreamIdentifierTest { final StreamIdentifier multi = StreamIdentifier.multiStreamInstance(arn.toString()); assertEquals(single, multi); - assertEquals(Optional.of(TEST_ACCOUNT_ID), single.accountIdOptional()); - assertEquals(STREAM_NAME, single.streamName()); - assertEquals(Optional.of(arn), single.streamARNOptional()); + assertEquals(Optional.empty(), single.streamCreationEpochOptional()); + assertActualStreamIdentifierExpected(arn, single); } @Test(expected = IllegalArgumentException.class) public void testInstanceWithoutEpochOrArn() { - when(getStreamARN(STREAM_NAME, KINESIS_REGION, Optional.of(TEST_ACCOUNT_ID))) + mockStatic(StreamARNUtil.class); + when(getStreamARN(STREAM_NAME, KINESIS_REGION, TEST_ACCOUNT_ID)) .thenReturn(Optional.empty()); - final Arn arn = toArn(KINESIS_REGION); - StreamIdentifier.singleStreamInstance(arn.toString()); + try { + StreamIdentifier.singleStreamInstance(DEFAULT_ARN.toString()); + } finally { + verifyStatic(StreamARNUtil.class); + getStreamARN(STREAM_NAME, KINESIS_REGION, TEST_ACCOUNT_ID); + } } @Test @@ -130,57 +115,52 @@ public class StreamIdentifierTest { StreamIdentifier actualStreamIdentifier = StreamIdentifier.singleStreamInstance(STREAM_NAME, KINESIS_REGION); assertFalse(actualStreamIdentifier.streamCreationEpochOptional().isPresent()); assertFalse(actualStreamIdentifier.accountIdOptional().isPresent()); + assertEquals(STREAM_NAME, actualStreamIdentifier.streamName()); assertEquals(Optional.of(DEFAULT_ARN), actualStreamIdentifier.streamARNOptional()); } @Test public void testMultiStreamInstanceWithIdentifierSerialization() { - StreamIdentifier actualStreamIdentifier = StreamIdentifier.multiStreamInstance(serialize(KINESIS_REGION)); - assertActualStreamIdentifierExpected(actualStreamIdentifier); + StreamIdentifier actualStreamIdentifier = StreamIdentifier.multiStreamInstance(serialize()); + assertActualStreamIdentifierExpected(null, actualStreamIdentifier); + assertEquals(Optional.of(EPOCH), actualStreamIdentifier.streamCreationEpochOptional()); } - @Test - public void testMultiStreamInstanceWithRegionSerialized() { - Region serializedRegion = Region.US_GOV_EAST_1; - final Optional arn = Optional.of(toArn(serializedRegion)); + /** + * When KCL's Kinesis endpoint is a region, it lacks visibility to streams + * in other regions. Therefore, when the endpoint and ARN conflict, an + * Exception should be thrown. + */ + @Test(expected = IllegalArgumentException.class) + public void testConflictOnRegions() { + final Region arnRegion = Region.US_GOV_EAST_1; + assertNotEquals(arnRegion, KINESIS_REGION); - when(getStreamARN(STREAM_NAME, serializedRegion, Optional.of(TEST_ACCOUNT_ID))).thenReturn(arn); - - final String expectedSerialization = serialize(serializedRegion); - StreamIdentifier actualStreamIdentifier = StreamIdentifier.multiStreamInstance( - expectedSerialization, KINESIS_REGION); - assertActualStreamIdentifierExpected(arn, actualStreamIdentifier); - assertEquals(expectedSerialization, actualStreamIdentifier.serialize()); - verifyStatic(StreamARNUtil.class); - getStreamARN(STREAM_NAME, serializedRegion, Optional.of(TEST_ACCOUNT_ID)); + StreamIdentifier.multiStreamInstance(toArn(arnRegion).toString(), KINESIS_REGION); } @Test public void testMultiStreamInstanceWithoutRegionSerialized() { StreamIdentifier actualStreamIdentifier = StreamIdentifier.multiStreamInstance( - serialize(null), KINESIS_REGION); + serialize(), KINESIS_REGION); assertActualStreamIdentifierExpected(actualStreamIdentifier); } private void assertActualStreamIdentifierExpected(StreamIdentifier actual) { - assertActualStreamIdentifierExpected(Optional.of(DEFAULT_ARN), actual); + assertActualStreamIdentifierExpected(DEFAULT_ARN, actual); } - private void assertActualStreamIdentifierExpected(Optional expectedArn, StreamIdentifier actual) { + private void assertActualStreamIdentifierExpected(Arn expectedArn, StreamIdentifier actual) { assertEquals(STREAM_NAME, actual.streamName()); - assertEquals(Optional.of(EPOCH), actual.streamCreationEpochOptional()); assertEquals(Optional.of(TEST_ACCOUNT_ID), actual.accountIdOptional()); - assertEquals(expectedArn, actual.streamARNOptional()); + assertEquals(Optional.ofNullable(expectedArn), actual.streamARNOptional()); } /** * Creates a pattern that matches {@link StreamIdentifier} serialization. - * - * @param region (optional) region to serialize */ - private static String serialize(final Region region) { - return String.join(":", TEST_ACCOUNT_ID, STREAM_NAME, Long.toString(EPOCH)) + - ((region == null) ? "" : ':' + region.toString()); + private static String serialize() { + return String.join(":", TEST_ACCOUNT_ID, STREAM_NAME, Long.toString(EPOCH)); } private static Arn toArn(final Region region) {