* fixed memory leak in `StreamARNUtil` (new class)
* substantial DRY
* added more, and enhanced recently-provided, unit tests
This commit is contained in:
stair 2023-04-18 16:29:24 -04:00 committed by GitHub
parent 52e34dbe8f
commit 0fd94acb2b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 187 additions and 135 deletions

View file

@ -1,86 +1,68 @@
package software.amazon.kinesis.common;
import com.google.common.base.Joiner;
import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import software.amazon.awssdk.arns.Arn;
import software.amazon.awssdk.awscore.exception.AwsServiceException;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.http.SdkHttpClient;
import software.amazon.awssdk.http.urlconnection.UrlConnectionHttpClient;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.sts.StsClient;
import software.amazon.awssdk.services.sts.model.GetCallerIdentityResponse;
import java.util.HashMap;
import java.util.Optional;
@Slf4j
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public final class StreamARNUtil {
private static final HashMap<String, Arn> streamARNCache = new HashMap<>();
/**
* Caches an {@link Arn} constructed from a {@link StsClient#getCallerIdentity()} call.
*/
private static final SupplierCache<Arn> CALLER_IDENTITY_ARN = new SupplierCache<>(() -> {
try (final SdkHttpClient httpClient = UrlConnectionHttpClient.builder().build();
final StsClient stsClient = StsClient.builder().httpClient(httpClient).build()) {
final GetCallerIdentityResponse response = stsClient.getCallerIdentity();
return Arn.fromString(response.arn());
} catch (AwsServiceException | SdkClientException e) {
log.warn("Unable to get sts caller identity to build stream arn", e);
return null;
}
});
/**
* This static method attempts to retrieve the stream ARN using the stream name, region, and accountId returned by STS
* It is designed to fail gracefully, returning Optional.empty() if any errors occur.
* @param streamName: stream name
* @param kinesisRegion: kinesisRegion is a nullable parameter used to construct the stream arn
* @return
*
* @param streamName stream name
* @param kinesisRegion kinesisRegion is a nullable parameter used to construct the stream arn
*/
public static Optional<Arn> getStreamARN(String streamName, Region kinesisRegion) {
return getStreamARN(streamName, kinesisRegion, Optional.empty());
}
public static Optional<Arn> getStreamARN(String streamName, Region kinesisRegion, @NonNull Optional<String> accountId) {
if (kinesisRegion == null || StringUtils.isEmpty(kinesisRegion.toString())) {
if (kinesisRegion == null) {
return Optional.empty();
}
// Consult the cache before contacting STS
String key = getCacheKey(streamName, kinesisRegion, accountId);
if (streamARNCache.containsKey(key)) {
return Optional.of(streamARNCache.get(key));
}
Optional<Arn> stsCallerArn = getStsCallerArn();
if (!stsCallerArn.isPresent() || !stsCallerArn.get().accountId().isPresent()) {
final Arn identityArn = CALLER_IDENTITY_ARN.get();
if (identityArn == null) {
return Optional.empty();
}
accountId = accountId.isPresent() ? accountId : stsCallerArn.get().accountId();
Arn kinesisStreamArn = Arn.builder()
.partition(stsCallerArn.get().partition())
// the provided accountId takes precedence
final String chosenAccountId = accountId.orElse(identityArn.accountId().orElse(""));
return Optional.of(Arn.builder()
.partition(identityArn.partition())
.service("kinesis")
.region(kinesisRegion.toString())
.accountId(accountId.get())
.accountId(chosenAccountId)
.resource("stream/" + streamName)
.build();
// Update the cache
streamARNCache.put(key, kinesisStreamArn);
return Optional.of(kinesisStreamArn);
}
private static Optional<Arn> getStsCallerArn() {
GetCallerIdentityResponse response;
try {
response = getStsClient().getCallerIdentity();
} catch (AwsServiceException | SdkClientException e) {
log.warn("Unable to get sts caller identity to build stream arn", e);
return Optional.empty();
}
return Optional.of(Arn.fromString(response.arn()));
}
private static StsClient getStsClient() {
return StsClient.builder()
.httpClient(UrlConnectionHttpClient.builder().build())
.build();
}
private static String getCacheKey(
String streamName, @NonNull Region kinesisRegion, @NonNull Optional<String> accountId) {
return Joiner.on(":").join(streamName, kinesisRegion.toString(), accountId.orElse(""));
.build());
}
}

View file

@ -8,106 +8,173 @@ import org.mockito.MockitoAnnotations;
import org.powermock.api.mockito.PowerMockito;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import org.powermock.reflect.Whitebox;
import software.amazon.awssdk.arns.Arn;
import software.amazon.awssdk.awscore.exception.AwsServiceException;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.http.SdkHttpClient;
import software.amazon.awssdk.http.urlconnection.UrlConnectionHttpClient;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.sts.StsClient;
import software.amazon.awssdk.services.sts.model.GetCallerIdentityResponse;
import software.amazon.awssdk.services.sts.StsClientBuilder;
import java.lang.reflect.Field;
import java.util.Optional;
import java.util.function.Supplier;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when;
@RunWith(PowerMockRunner.class)
@PrepareForTest({ StreamARNUtil.class })
@PrepareForTest({ StreamARNUtil.class, StsClient.class, UrlConnectionHttpClient.class })
public class StreamARNUtilTest {
private static final String STS_RESPONSE_ARN_FORMAT = "arn:aws:sts::%s:assumed-role/Admin/alias";
private static final String KINESIS_STREAM_ARN_FORMAT = "arn:aws:kinesis:us-east-1:%s:stream/%s";
// To prevent clashes in the stream arn cache with identical names,
// we're using the test name as the stream name (key)
private static final Supplier<String> streamNameProvider = () -> Thread.currentThread().getStackTrace()[2].getMethodName();
/**
* Original {@link SupplierCache} that is constructed on class load.
*/
private static final SupplierCache<Arn> ORIGINAL_CACHE = Whitebox.getInternalState(
StreamARNUtil.class, "CALLER_IDENTITY_ARN");
private static final String ACCOUNT_ID = "12345";
private static final String STREAM_NAME = StreamARNUtilTest.class.getSimpleName();
@Mock
private StsClientBuilder mockStsClientBuilder;
@Mock
private StsClient mockStsClient;
private SupplierCache<Arn> spySupplierCache;
@Before
public void setUp() throws Exception {
MockitoAnnotations.initMocks(this);
PowerMockito.spy(StreamARNUtil.class);
PowerMockito.doReturn(mockStsClient).when(StreamARNUtil.class, "getStsClient");
setUpSupplierCache();
final Arn defaultArn = toArn(STS_RESPONSE_ARN_FORMAT, ACCOUNT_ID);
doReturn(defaultArn).when(spySupplierCache).get();
}
private void setUpSts() {
PowerMockito.mockStatic(StsClient.class);
PowerMockito.mockStatic(UrlConnectionHttpClient.class);
when(UrlConnectionHttpClient.builder()).thenReturn(mock(UrlConnectionHttpClient.Builder.class));
when(StsClient.builder()).thenReturn(mockStsClientBuilder);
when(mockStsClientBuilder.httpClient(any(SdkHttpClient.class))).thenReturn(mockStsClientBuilder);
when(mockStsClientBuilder.build()).thenReturn(mockStsClient);
// bypass the spy so the Sts clients are called
when(spySupplierCache.get()).thenCallRealMethod();
}
/**
* Wrap and embed the original {@link SupplierCache} with a spy to avoid
* one-and-done cache behavior, provide each test precise control over
* return values, and enable the ability to verify interactions via Mockito.
*/
private void setUpSupplierCache() throws Exception {
spySupplierCache = spy(ORIGINAL_CACHE);
final Field f = StreamARNUtil.class.getDeclaredField("CALLER_IDENTITY_ARN");
f.setAccessible(true);
f.set(null, spySupplierCache);
f.setAccessible(false);
}
@Test
public void testGetStreamARNHappyCase() {
String streamName = streamNameProvider.get();
String accountId = "123456789012";
when(mockStsClient.getCallerIdentity())
.thenReturn(GetCallerIdentityResponse.builder().arn(String.format(STS_RESPONSE_ARN_FORMAT, accountId)).build());
getStreamArn();
Optional<Arn> actualStreamARNOptional = StreamARNUtil.getStreamARN(streamName, Region.US_EAST_1);
String expectedStreamARN = String.format(KINESIS_STREAM_ARN_FORMAT, accountId, streamName);
verify(mockStsClient, times(1)).getCallerIdentity();
assertTrue(actualStreamARNOptional.isPresent());
assertEquals(expectedStreamARN, actualStreamARNOptional.get().toString());
verify(spySupplierCache).get();
}
@Test
public void testGetStreamARNFromCache() {
String streamName = streamNameProvider.get();
String accountId = "123456789012";
when(mockStsClient.getCallerIdentity())
.thenReturn(GetCallerIdentityResponse.builder().arn(String.format(STS_RESPONSE_ARN_FORMAT, accountId)).build());
final Optional<Arn> actualStreamARNOptional1 = getStreamArn();
final Optional<Arn> actualStreamARNOptional2 = getStreamArn();
Optional<Arn> actualStreamARNOptional1 = StreamARNUtil.getStreamARN(streamName, Region.US_EAST_1);
Optional<Arn> actualStreamARNOptional2 = StreamARNUtil.getStreamARN(streamName, Region.US_EAST_1);
String expectedStreamARN = String.format(KINESIS_STREAM_ARN_FORMAT, accountId, streamName);
// Since the second ARN is obtained from the cache, hence there's only one sts call
verify(mockStsClient, times(1)).getCallerIdentity();
assertEquals(expectedStreamARN, actualStreamARNOptional1.get().toString());
verify(spySupplierCache, times(2)).get();
assertEquals(actualStreamARNOptional1, actualStreamARNOptional2);
}
@Test
public void testGetStreamARNReturnsEmptyOnSTSError() {
setUpSts();
// Optional.empty() is expected when there is an error with the STS call and STS returns empty Arn
String streamName = streamNameProvider.get();
when(mockStsClient.getCallerIdentity())
.thenThrow(AwsServiceException.builder().message("testAwsServiceException").build())
.thenThrow(SdkClientException.builder().message("testSdkClientException").build());
assertFalse(StreamARNUtil.getStreamARN(streamName, Region.US_EAST_1).isPresent());
assertFalse(StreamARNUtil.getStreamARN(streamName, Region.US_EAST_1).isPresent());
assertEquals(Optional.empty(), StreamARNUtil.getStreamARN(STREAM_NAME, Region.US_EAST_1));
assertEquals(Optional.empty(), StreamARNUtil.getStreamARN(STREAM_NAME, Region.US_EAST_1));
verify(mockStsClient, times(2)).getCallerIdentity();
verify(spySupplierCache, times(2)).get();
}
@Test
public void testGetStreamARNReturnsEmptyOnInvalidKinesisRegion() {
// Optional.empty() is expected when kinesis region is not set correctly
String streamName = streamNameProvider.get();
Optional<Arn> actualStreamARNOptional = StreamARNUtil.getStreamARN(streamName, null);
verify(mockStsClient, times(0)).getCallerIdentity();
assertFalse(actualStreamARNOptional.isPresent());
Optional<Arn> actualStreamARNOptional = StreamARNUtil.getStreamARN(STREAM_NAME, null);
assertEquals(Optional.empty(), actualStreamARNOptional);
verifyZeroInteractions(mockStsClient);
verifyZeroInteractions(spySupplierCache);
}
@Test
public void testGetStreamARNWithProvidedAccountIDAndIgnoredSTSResult() throws Exception {
public void testGetStreamARNWithProvidedAccountIDAndIgnoredSTSResult() {
// If the account id is provided in the StreamIdentifier, it will override the result (account id) returned by sts
String streamName = streamNameProvider.get();
String stsAccountId = "111111111111";
String providedAccountId = "222222222222";
when(mockStsClient.getCallerIdentity())
.thenReturn(GetCallerIdentityResponse.builder().arn(String.format(STS_RESPONSE_ARN_FORMAT, stsAccountId)).build());
final String cachedAccountId = "111111111111";
final String providedAccountId = "222222222222";
Optional<Arn> actualStreamARNOptional = StreamARNUtil.getStreamARN(streamName, Region.US_EAST_1, Optional.of(providedAccountId));
String expectedStreamARN = String.format(KINESIS_STREAM_ARN_FORMAT, providedAccountId, streamName);
verify(mockStsClient, times(1)).getCallerIdentity();
final Arn cachedArn = toArn(STS_RESPONSE_ARN_FORMAT, cachedAccountId);
when(spySupplierCache.get()).thenReturn(cachedArn);
final Optional<Arn> actualStreamARNOptional = StreamARNUtil.getStreamARN(STREAM_NAME, Region.US_EAST_1,
Optional.of(providedAccountId));
final Arn expectedStreamARN = toArn(KINESIS_STREAM_ARN_FORMAT, providedAccountId, STREAM_NAME);
verify(spySupplierCache).get();
verifyZeroInteractions(mockStsClient);
assertTrue(actualStreamARNOptional.isPresent());
assertEquals(expectedStreamARN, actualStreamARNOptional.get().toString());
assertEquals(expectedStreamARN, actualStreamARNOptional.get());
}
@Test
public void testNoAccountId() {
final Arn arnWithoutAccountId = toArn(STS_RESPONSE_ARN_FORMAT, "");
when(spySupplierCache.get()).thenReturn(arnWithoutAccountId);
assertEquals(Optional.empty(), arnWithoutAccountId.accountId());
final Optional<Arn> actualArn = StreamARNUtil.getStreamARN(STREAM_NAME, Region.US_EAST_1);
assertTrue(actualArn.isPresent());
assertEquals(Optional.empty(), actualArn.get().accountId());
}
private static Optional<Arn> getStreamArn() {
final Optional<Arn> actualArn = StreamARNUtil.getStreamARN(STREAM_NAME, Region.US_EAST_1);
final Arn expectedArn = toArn(KINESIS_STREAM_ARN_FORMAT, ACCOUNT_ID, STREAM_NAME);
assertTrue(actualArn.isPresent());
assertEquals(expectedArn, actualArn.get());
return actualArn;
}
private static Arn toArn(final String format, final Object... params) {
return Arn.fromString(String.format(format, params));
}
}

View file

@ -1,7 +1,6 @@
package software.amazon.kinesis.common;
import com.google.common.base.Joiner;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.powermock.core.classloader.annotations.PrepareForTest;
@ -11,84 +10,88 @@ import software.amazon.awssdk.regions.Region;
import java.util.Optional;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyString;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.when;
import static org.powermock.api.mockito.PowerMockito.mockStatic;
@RunWith(PowerMockRunner.class)
@PrepareForTest(StreamARNUtil.class)
public class StreamIdentifierTest {
private static final String streamName = "streamName";
private static final Region kinesisRegion = Region.US_WEST_1;
private static final String accountId = "111111111111";
private static final String epoch = "1680616058";
private static final String STREAM_NAME = "streamName";
private static final Region KINESIS_REGION = Region.US_WEST_1;
private static final String TEST_ACCOUNT_ID = "111111111111";
private static final String EPOCH = "1680616058";
private static final Arn DEFAULT_ARN = Arn.builder().partition("aws").service("kinesis")
.region(KINESIS_REGION.toString()).accountId(TEST_ACCOUNT_ID).resource("stream/" + STREAM_NAME).build();
@Before
public void setUp() {
mockStatic(StreamARNUtil.class);
when(StreamARNUtil.getStreamARN(anyString(), any(Region.class))).thenReturn(Optional.empty());
when(StreamARNUtil.getStreamARN(STREAM_NAME, KINESIS_REGION)).thenReturn(Optional.of(DEFAULT_ARN));
when(StreamARNUtil.getStreamARN(STREAM_NAME, KINESIS_REGION, Optional.of(TEST_ACCOUNT_ID)))
.thenReturn(Optional.of(DEFAULT_ARN));
}
@Test
public void testSingleStreamInstanceWithName() {
StreamIdentifier actualStreamIdentifier = StreamIdentifier.singleStreamInstance(streamName);
Assert.assertFalse(actualStreamIdentifier.streamCreationEpochOptional().isPresent());
Assert.assertFalse(actualStreamIdentifier.accountIdOptional().isPresent());
Assert.assertFalse(actualStreamIdentifier.streamARNOptional().isPresent());
Assert.assertEquals(streamName, actualStreamIdentifier.streamName());
StreamIdentifier actualStreamIdentifier = StreamIdentifier.singleStreamInstance(STREAM_NAME);
assertFalse(actualStreamIdentifier.streamCreationEpochOptional().isPresent());
assertFalse(actualStreamIdentifier.accountIdOptional().isPresent());
assertFalse(actualStreamIdentifier.streamARNOptional().isPresent());
assertEquals(STREAM_NAME, actualStreamIdentifier.streamName());
}
@Test
public void testSingleStreamInstanceWithNameAndRegion() {
Optional<Arn> arn = Optional.of(Arn.builder().partition("aws").service("kinesis")
.region(kinesisRegion.toString()).accountId("123").resource("stream/" + streamName).build());
mockStatic(StreamARNUtil.class);
when(StreamARNUtil.getStreamARN(eq(streamName), eq(kinesisRegion))).thenReturn(arn);
StreamIdentifier actualStreamIdentifier = StreamIdentifier.singleStreamInstance(streamName, kinesisRegion);
Assert.assertFalse(actualStreamIdentifier.streamCreationEpochOptional().isPresent());
Assert.assertFalse(actualStreamIdentifier.accountIdOptional().isPresent());
Assert.assertTrue(actualStreamIdentifier.streamARNOptional().isPresent());
Assert.assertEquals(arn, actualStreamIdentifier.streamARNOptional());
StreamIdentifier actualStreamIdentifier = StreamIdentifier.singleStreamInstance(STREAM_NAME, KINESIS_REGION);
assertFalse(actualStreamIdentifier.streamCreationEpochOptional().isPresent());
assertFalse(actualStreamIdentifier.accountIdOptional().isPresent());
assertTrue(actualStreamIdentifier.streamARNOptional().isPresent());
assertEquals(DEFAULT_ARN, actualStreamIdentifier.streamARNOptional().get());
}
@Test
public void testMultiStreamInstanceWithIdentifierSerialization() {
String epoch = "1680616058";
Optional<Arn> arn = Optional.ofNullable(Arn.builder().partition("aws").service("kinesis")
.accountId(accountId).region(kinesisRegion.toString()).resource("stream/" + streamName).build());
mockStatic(StreamARNUtil.class);
when(StreamARNUtil.getStreamARN(eq(streamName), any(), any())).thenReturn(arn);
StreamIdentifier actualStreamIdentifier = StreamIdentifier.multiStreamInstance(
Joiner.on(":").join(accountId, streamName, epoch, kinesisRegion));
assertActualStreamIdentifierExpected(arn, actualStreamIdentifier);
String.join(":", TEST_ACCOUNT_ID, STREAM_NAME, EPOCH, KINESIS_REGION.toString()));
assertActualStreamIdentifierExpected(DEFAULT_ARN, actualStreamIdentifier);
}
@Test
public void testMultiStreamInstanceWithRegionSerialized() {
Region serializedRegion = Region.US_GOV_EAST_1;
Optional<Arn> arn = Optional.ofNullable(Arn.builder().partition("aws").service("kinesis")
.accountId(accountId).region(serializedRegion.toString()).resource("stream/" + streamName).build());
.accountId(TEST_ACCOUNT_ID).region(serializedRegion.toString()).resource("stream/" + STREAM_NAME).build());
mockStatic(StreamARNUtil.class);
when(StreamARNUtil.getStreamARN(eq(streamName), eq(serializedRegion), any())).thenReturn(arn);
when(StreamARNUtil.getStreamARN(eq(STREAM_NAME), eq(serializedRegion), any(Optional.class))).thenReturn(arn);
StreamIdentifier actualStreamIdentifier = StreamIdentifier.multiStreamInstance(
Joiner.on(":").join(accountId, streamName, epoch, serializedRegion), kinesisRegion);
String.join(":", TEST_ACCOUNT_ID, STREAM_NAME, EPOCH, serializedRegion.toString()), KINESIS_REGION);
assertActualStreamIdentifierExpected(arn, actualStreamIdentifier);
}
@Test
public void testMultiStreamInstanceWithoutRegionSerialized() {
Optional<Arn> arn = Optional.ofNullable(Arn.builder().partition("aws").service("kinesis")
.accountId(accountId).region(kinesisRegion.toString()).resource("stream/" + streamName).build());
mockStatic(StreamARNUtil.class);
when(StreamARNUtil.getStreamARN(eq(streamName), eq(kinesisRegion), any())).thenReturn(arn);
StreamIdentifier actualStreamIdentifier = StreamIdentifier.multiStreamInstance(
Joiner.on(":").join(accountId, streamName, epoch), kinesisRegion);
assertActualStreamIdentifierExpected(arn, actualStreamIdentifier);
String.join(":", TEST_ACCOUNT_ID, STREAM_NAME, EPOCH), KINESIS_REGION);
assertActualStreamIdentifierExpected(DEFAULT_ARN, actualStreamIdentifier);
}
private void assertActualStreamIdentifierExpected(Optional<Arn> expectedARN, StreamIdentifier actual) {
Assert.assertTrue(actual.streamCreationEpochOptional().isPresent());
Assert.assertTrue(actual.accountIdOptional().isPresent());
Assert.assertTrue(actual.streamARNOptional().isPresent());
Assert.assertEquals(expectedARN, actual.streamARNOptional());
assertActualStreamIdentifierExpected(expectedARN.get(), actual);
}
private void assertActualStreamIdentifierExpected(Arn expectedARN, StreamIdentifier actual) {
assertTrue(actual.streamCreationEpochOptional().isPresent());
assertTrue(actual.accountIdOptional().isPresent());
assertTrue(actual.streamARNOptional().isPresent());
assertEquals(Optional.of(expectedARN), actual.streamARNOptional());
}
}