From ac7d975c5fd2c56e26b6f6aeac4ce0c59119df46 Mon Sep 17 00:00:00 2001 From: Ethan Katnic Date: Thu, 5 Sep 2024 09:39:22 -0700 Subject: [PATCH] Specify sts provider path for conversion to kclStsProvider --- .../AwsCredentialsProviderPropertyValueDecoder.java | 12 ++++++++---- ...SCredentialsProviderPropertyValueDecoderTest.java | 4 +++- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/amazon-kinesis-client-multilang/src/main/java/software/amazon/kinesis/multilang/config/AwsCredentialsProviderPropertyValueDecoder.java b/amazon-kinesis-client-multilang/src/main/java/software/amazon/kinesis/multilang/config/AwsCredentialsProviderPropertyValueDecoder.java index 777448f8..fc0ff58f 100644 --- a/amazon-kinesis-client-multilang/src/main/java/software/amazon/kinesis/multilang/config/AwsCredentialsProviderPropertyValueDecoder.java +++ b/amazon-kinesis-client-multilang/src/main/java/software/amazon/kinesis/multilang/config/AwsCredentialsProviderPropertyValueDecoder.java @@ -27,6 +27,8 @@ import java.util.stream.Stream; import lombok.extern.slf4j.Slf4j; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.auth.credentials.AwsCredentialsProviderChain; +import software.amazon.awssdk.services.sts.auth.StsAssumeRoleCredentialsProvider; +import software.amazon.kinesis.multilang.auth.KclStsAssumeRoleCredentialsProvider; /** * Get AwsCredentialsProvider property. @@ -184,12 +186,14 @@ class AwsCredentialsProviderPropertyValueDecoder implements IPropertyValueDecode * or null if the class cannot be resolved or does not extend AwsCredentialsProvider. */ private static Class getClass(String providerName) { - String className = providerName.replace( - "software.amazon.awssdk.auth.credentials.StsAssumeRoleCredentialsProvider", - "software.amazon.kinesis.multilang.auth.KclStsAssumeRoleCredentialsProvider"); + // Convert any form of StsAssumeRoleCredentialsProvider string to KclStsAssumeRoleCredentialsProvider + if (providerName.equals(StsAssumeRoleCredentialsProvider.class.getSimpleName()) + || providerName.equals(StsAssumeRoleCredentialsProvider.class.getName())) { + providerName = KclStsAssumeRoleCredentialsProvider.class.getName(); + } final Class clazz; try { - final Class c = Class.forName(className); + final Class c = Class.forName(providerName); if (!AwsCredentialsProvider.class.isAssignableFrom(c)) { return null; } diff --git a/amazon-kinesis-client-multilang/src/test/java/software/amazon/kinesis/multilang/config/AWSCredentialsProviderPropertyValueDecoderTest.java b/amazon-kinesis-client-multilang/src/test/java/software/amazon/kinesis/multilang/config/AWSCredentialsProviderPropertyValueDecoderTest.java index 71f97d18..f56c5407 100644 --- a/amazon-kinesis-client-multilang/src/test/java/software/amazon/kinesis/multilang/config/AWSCredentialsProviderPropertyValueDecoderTest.java +++ b/amazon-kinesis-client-multilang/src/test/java/software/amazon/kinesis/multilang/config/AWSCredentialsProviderPropertyValueDecoderTest.java @@ -25,6 +25,7 @@ import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; import software.amazon.awssdk.auth.credentials.AwsCredentials; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.auth.credentials.AwsCredentialsProviderChain; +import software.amazon.awssdk.services.sts.auth.StsAssumeRoleCredentialsProvider; import software.amazon.kinesis.multilang.auth.KclStsAssumeRoleCredentialsProvider; import static org.hamcrest.CoreMatchers.equalTo; @@ -121,7 +122,8 @@ public class AWSCredentialsProviderPropertyValueDecoderTest { for (final String className : Arrays.asList( KclStsAssumeRoleCredentialsProvider.class.getName(), // fully-qualified name KclStsAssumeRoleCredentialsProvider.class.getSimpleName(), // name-only; needs prefix - "software.amazon.awssdk.auth.credentials.StsAssumeRoleCredentialsProvider")) { + StsAssumeRoleCredentialsProvider.class.getName(), // user passes full sts package path + StsAssumeRoleCredentialsProvider.class.getSimpleName())) { final AwsCredentialsProvider provider = decoder.decodeValue(className + "|arn|sessionName"); assertNotNull(className, provider); }