diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/common/StreamConfig.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/common/StreamConfig.java index 8856a4a0..b1057f13 100644 --- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/common/StreamConfig.java +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/common/StreamConfig.java @@ -15,14 +15,15 @@ package software.amazon.kinesis.common; -import lombok.Value; +import lombok.Data; import lombok.experimental.Accessors; -@Value +@Data @Accessors(fluent = true) public class StreamConfig { - StreamIdentifier streamIdentifier; - InitialPositionInStreamExtended initialPositionInStreamExtended; + private final StreamIdentifier streamIdentifier; + private final InitialPositionInStreamExtended initialPositionInStreamExtended; + private String consumerArn; } diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/common/StreamIdentifier.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/common/StreamIdentifier.java index 7a416c7a..1259a609 100644 --- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/common/StreamIdentifier.java +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/common/StreamIdentifier.java @@ -63,6 +63,7 @@ public class StreamIdentifier { /** * Create a multi stream instance for StreamIdentifier from serialized stream identifier. + * The serialized stream identifier should be of the format account:stream:creationepoch * @param streamIdentifierSer * @return StreamIdentifier */ diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/coordinator/Scheduler.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/coordinator/Scheduler.java index e2f2f852..e196920d 100644 --- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/coordinator/Scheduler.java +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/coordinator/Scheduler.java @@ -891,7 +891,6 @@ public class Scheduler implements Runnable { protected ShardConsumer buildConsumer(@NonNull final ShardInfo shardInfo, @NonNull final ShardRecordProcessorFactory shardRecordProcessorFactory) { - RecordsPublisher cache = retrievalConfig.retrievalFactory().createGetRecordsCache(shardInfo, metricsFactory); ShardRecordProcessorCheckpointer checkpointer = coordinatorConfig.coordinatorFactory().createRecordProcessorCheckpointer(shardInfo, checkpoint); // The only case where streamName is not available will be when multistreamtracker not set. In this case, @@ -902,6 +901,7 @@ public class Scheduler implements Runnable { // to gracefully complete the reading. final StreamConfig streamConfig = currentStreamConfigMap.getOrDefault(streamIdentifier, getDefaultStreamConfig(streamIdentifier)); Validate.notNull(streamConfig, "StreamConfig should not be null"); + RecordsPublisher cache = retrievalConfig.retrievalFactory().createGetRecordsCache(shardInfo, streamConfig, metricsFactory); ShardConsumerArgument argument = new ShardConsumerArgument(shardInfo, streamConfig.streamIdentifier(), leaseCoordinator, diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/RetrievalConfig.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/RetrievalConfig.java index 5f22411a..63ae7b5f 100644 --- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/RetrievalConfig.java +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/RetrievalConfig.java @@ -121,6 +121,13 @@ public class RetrievalConfig { return this; } + public RetrievalConfig retrievalSpecificConfig(RetrievalSpecificConfig retrievalSpecificConfig) { + this.retrievalSpecificConfig = retrievalSpecificConfig; + validateFanoutConfig(); + validatePollingConfig(); + return this; + } + public RetrievalFactory retrievalFactory() { if (retrievalFactory == null) { if (retrievalSpecificConfig == null) { @@ -129,22 +136,36 @@ public class RetrievalConfig { retrievalSpecificConfig = appStreamTracker.map(multiStreamTracker -> retrievalSpecificConfig, streamConfig -> ((FanOutConfig) retrievalSpecificConfig).streamName(streamConfig.streamIdentifier().streamName())); } - retrievalFactory = retrievalSpecificConfig.retrievalFactory(); } - validateConfig(); return retrievalFactory; } - private void validateConfig() { + private void validateFanoutConfig() { + // If we are in multistream mode and if retrievalSpecificConfig is an instance of FanOutConfig and if consumerArn is set throw exception. + boolean isFanoutConfig = retrievalSpecificConfig instanceof FanOutConfig; + boolean isInvalidFanoutConfig = isFanoutConfig && appStreamTracker.map( + multiStreamTracker -> ((FanOutConfig) retrievalSpecificConfig).consumerArn() != null + || ((FanOutConfig) retrievalSpecificConfig).streamName() != null, + streamConfig -> streamConfig.streamIdentifier() == null + || streamConfig.streamIdentifier().streamName() == null); + if(isInvalidFanoutConfig) { + throw new IllegalArgumentException( + "Invalid config: Either in multi-stream mode with streamName/consumerArn configured or in single-stream mode with no streamName configured"); + } + } + + private void validatePollingConfig() { boolean isPollingConfig = retrievalSpecificConfig instanceof PollingConfig; - boolean isInvalidPollingConfig = isPollingConfig && appStreamTracker.map(multiStreamTracker -> + boolean isInvalidPollingConfig = isPollingConfig && appStreamTracker.map( + multiStreamTracker -> ((PollingConfig) retrievalSpecificConfig).streamName() != null, streamConfig -> streamConfig.streamIdentifier() == null || streamConfig.streamIdentifier().streamName() == null); - if(isInvalidPollingConfig) { - throw new IllegalArgumentException("Invalid config: multistream enabled with streamName or single stream with no streamName"); + if (isInvalidPollingConfig) { + throw new IllegalArgumentException( + "Invalid config: Either in multi-stream mode with streamName configured or in single-stream mode with no streamName configured"); } } } diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/RetrievalFactory.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/RetrievalFactory.java index 4c8f6b68..5703e1af 100644 --- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/RetrievalFactory.java +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/RetrievalFactory.java @@ -15,6 +15,7 @@ package software.amazon.kinesis.retrieval; +import software.amazon.kinesis.common.StreamConfig; import software.amazon.kinesis.leases.ShardInfo; import software.amazon.kinesis.metrics.MetricsFactory; @@ -24,5 +25,10 @@ import software.amazon.kinesis.metrics.MetricsFactory; public interface RetrievalFactory { GetRecordsRetrievalStrategy createGetRecordsRetrievalStrategy(ShardInfo shardInfo, MetricsFactory metricsFactory); + @Deprecated RecordsPublisher createGetRecordsCache(ShardInfo shardInfo, MetricsFactory metricsFactory); + + default RecordsPublisher createGetRecordsCache(ShardInfo shardInfo, StreamConfig streamConfig, MetricsFactory metricsFactory) { + return createGetRecordsCache(shardInfo, metricsFactory); + } } diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutConfig.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutConfig.java index fafe7e18..9318b996 100644 --- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutConfig.java +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutConfig.java @@ -80,17 +80,11 @@ public class FanOutConfig implements RetrievalSpecificConfig { */ private long retryBackoffMillis = 1000; - @Override - public RetrievalFactory retrievalFactory() { - return new FanOutRetrievalFactory(kinesisClient, streamName, this::getOrCreateConsumerArn); + @Override public RetrievalFactory retrievalFactory() { + return new FanOutRetrievalFactory(kinesisClient, streamName, consumerArn, this::getOrCreateConsumerArn); } - // TODO : LTR. Need Stream Specific ConsumerArn to be passed from Customer private String getOrCreateConsumerArn(String streamName) { - if (consumerArn != null) { - return consumerArn; - } - FanOutConsumerRegistration registration = createConsumerRegistration(streamName); try { return registration.getOrCreateStreamConsumerArn(); diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutConsumerRegistration.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutConsumerRegistration.java index 0519390c..9bcdd83c 100644 --- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutConsumerRegistration.java +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutConsumerRegistration.java @@ -76,7 +76,7 @@ public class FanOutConsumerRegistration implements ConsumerRegistration { try { response = describeStreamConsumer(); } catch (ResourceNotFoundException e) { - log.info("StreamConsumer not found, need to create it."); + log.info("{} : StreamConsumer not found, need to create it.", streamName); } // 2. If not, register consumer @@ -92,7 +92,7 @@ public class FanOutConsumerRegistration implements ConsumerRegistration { break; } catch (LimitExceededException e) { // TODO: Figure out internal service exceptions - log.debug("RegisterStreamConsumer call got throttled will retry."); + log.debug("{} : RegisterStreamConsumer call got throttled will retry.", streamName); finalException = e; } retries--; @@ -104,7 +104,7 @@ public class FanOutConsumerRegistration implements ConsumerRegistration { } } catch (ResourceInUseException e) { // Consumer is present, call DescribeStreamConsumer - log.debug("Got ResourceInUseException consumer exists, will call DescribeStreamConsumer again."); + log.debug("{} : Got ResourceInUseException consumer exists, will call DescribeStreamConsumer again.", streamName); response = describeStreamConsumer(); } } @@ -160,17 +160,17 @@ public class FanOutConsumerRegistration implements ConsumerRegistration { while (!ConsumerStatus.ACTIVE.equals(status) && retries > 0) { status = describeStreamConsumer().consumerDescription().consumerStatus(); retries--; - log.info(String.format("Waiting for StreamConsumer %s to have ACTIVE status...", streamConsumerName)); + log.info("{} : Waiting for StreamConsumer {} to have ACTIVE status...", streamName, streamConsumerName); Thread.sleep(retryBackoffMillis); } } catch (InterruptedException ie) { - log.debug("Thread was interrupted while fetching StreamConsumer status, moving on."); + log.debug("{} : Thread was interrupted while fetching StreamConsumer status, moving on.", streamName); } if (!ConsumerStatus.ACTIVE.equals(status)) { final String message = String.format( - "Status of StreamConsumer %s, was not ACTIVE after all retries. Was instead %s.", - streamConsumerName, status); + "%s : Status of StreamConsumer %s, was not ACTIVE after all retries. Was instead %s.", + streamName, streamConsumerName, status); log.error(message); throw new IllegalStateException(message); } @@ -211,7 +211,7 @@ public class FanOutConsumerRegistration implements ConsumerRegistration { throw new DependencyException(e); } } catch (LimitExceededException e) { - log.info("Throttled while calling {} API, will backoff.", apiName); + log.info("{} : Throttled while calling {} API, will backoff.", streamName, apiName); try { Thread.sleep(retryBackoffMillis + (long) (Math.random() * 100)); } catch (InterruptedException ie) { @@ -224,7 +224,7 @@ public class FanOutConsumerRegistration implements ConsumerRegistration { if (finalException == null) { throw new IllegalStateException( - String.format("Finished all retries and no exception was caught while calling %s", apiName)); + String.format("%s : Finished all retries and no exception was caught while calling %s", streamName, apiName)); } throw finalException; diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutRetrievalFactory.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutRetrievalFactory.java index 719d2e54..5796862b 100644 --- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutRetrievalFactory.java +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutRetrievalFactory.java @@ -19,6 +19,7 @@ import lombok.NonNull; import lombok.RequiredArgsConstructor; import software.amazon.awssdk.services.kinesis.KinesisAsyncClient; import software.amazon.kinesis.annotations.KinesisClientInternalApi; +import software.amazon.kinesis.common.StreamConfig; import software.amazon.kinesis.common.StreamIdentifier; import software.amazon.kinesis.leases.ShardInfo; import software.amazon.kinesis.metrics.MetricsFactory; @@ -37,8 +38,8 @@ public class FanOutRetrievalFactory implements RetrievalFactory { private final KinesisAsyncClient kinesisClient; private final String defaultStreamName; - private final Function consumerArnProvider; - private Map streamToConsumerArnMap = new HashMap<>(); + private final String defaultConsumerName; + private final Function consumerArnCreator; @Override public GetRecordsRetrievalStrategy createGetRecordsRetrievalStrategy(final ShardInfo shardInfo, @@ -48,19 +49,27 @@ public class FanOutRetrievalFactory implements RetrievalFactory { @Override public RecordsPublisher createGetRecordsCache(@NonNull final ShardInfo shardInfo, + final StreamConfig streamConfig, final MetricsFactory metricsFactory) { final Optional streamIdentifierStr = shardInfo.streamIdentifierSerOpt(); final String streamName; if(streamIdentifierStr.isPresent()) { streamName = StreamIdentifier.multiStreamInstance(streamIdentifierStr.get()).streamName(); return new FanOutRecordsPublisher(kinesisClient, shardInfo.shardId(), - streamToConsumerArnMap.computeIfAbsent(streamName, consumerArnProvider::apply), + getOrCreateConsumerArn(streamName, streamConfig.consumerArn()), streamIdentifierStr.get()); } else { - streamName = defaultStreamName; return new FanOutRecordsPublisher(kinesisClient, shardInfo.shardId(), - streamToConsumerArnMap.computeIfAbsent(streamName, consumerArnProvider::apply)); + getOrCreateConsumerArn(defaultStreamName, defaultConsumerName)); } + } + @Override + public RecordsPublisher createGetRecordsCache(ShardInfo shardInfo, MetricsFactory metricsFactory) { + throw new UnsupportedOperationException("FanoutRetrievalFactory needs StreamConfig Info"); + } + + private String getOrCreateConsumerArn(String streamName, String consumerArn) { + return consumerArn != null ? consumerArn : consumerArnCreator.apply(streamName); } } diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/coordinator/SchedulerTest.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/coordinator/SchedulerTest.java index e5a76ce3..1d24ae68 100644 --- a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/coordinator/SchedulerTest.java +++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/coordinator/SchedulerTest.java @@ -191,7 +191,7 @@ public class SchedulerTest { when(leaseCoordinator.leaseRefresher()).thenReturn(dynamoDBLeaseRefresher); when(shardSyncTaskManager.shardDetector()).thenReturn(shardDetector); when(shardSyncTaskManager.callShardSyncTask()).thenReturn(new TaskResult(null)); - when(retrievalFactory.createGetRecordsCache(any(ShardInfo.class), any(MetricsFactory.class))).thenReturn(recordsPublisher); + when(retrievalFactory.createGetRecordsCache(any(ShardInfo.class), any(StreamConfig.class), any(MetricsFactory.class))).thenReturn(recordsPublisher); when(shardDetector.streamIdentifier()).thenReturn(mock(StreamIdentifier.class)); scheduler = new Scheduler(checkpointConfig, coordinatorConfig, leaseManagementConfig, lifecycleConfig, diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/fanout/FanOutConfigTest.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/fanout/FanOutConfigTest.java index 21228c75..4fee3d08 100644 --- a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/fanout/FanOutConfigTest.java +++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/fanout/FanOutConfigTest.java @@ -25,12 +25,14 @@ import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.runners.MockitoJUnitRunner; import software.amazon.awssdk.services.kinesis.KinesisAsyncClient; +import software.amazon.kinesis.common.StreamConfig; import software.amazon.kinesis.leases.ShardInfo; import software.amazon.kinesis.leases.exceptions.DependencyException; import software.amazon.kinesis.metrics.MetricsFactory; @@ -50,6 +52,13 @@ public class FanOutConfigTest { private FanOutConsumerRegistration consumerRegistration; @Mock private KinesisAsyncClient kinesisClient; + @Mock + private StreamConfig streamConfig; + + @Before + public void setup() { + when(streamConfig.consumerArn()).thenReturn(null); + } @Test public void testNoRegisterIfConsumerArnSet() throws Exception { @@ -68,7 +77,32 @@ public class FanOutConfigTest { ShardInfo shardInfo = mock(ShardInfo.class); // doReturn(Optional.of(StreamIdentifier.singleStreamInstance(TEST_STREAM_NAME).serialize())).when(shardInfo).streamIdentifier(); doReturn(Optional.empty()).when(shardInfo).streamIdentifierSerOpt(); - retrievalFactory.createGetRecordsCache(shardInfo, mock(MetricsFactory.class)); + retrievalFactory.createGetRecordsCache(shardInfo, streamConfig, mock(MetricsFactory.class)); + assertThat(retrievalFactory, not(nullValue())); + verify(consumerRegistration).getOrCreateStreamConsumerArn(); + } + + @Test + public void testRegisterNotCalledWhenConsumerArnSetInMultiStreamMode() throws Exception { + when(streamConfig.consumerArn()).thenReturn("consumerArn"); + FanOutConfig config = new TestingConfig(kinesisClient).applicationName(TEST_APPLICATION_NAME) + .streamName(TEST_STREAM_NAME); + RetrievalFactory retrievalFactory = config.retrievalFactory(); + ShardInfo shardInfo = mock(ShardInfo.class); + doReturn(Optional.of("account:stream:12345")).when(shardInfo).streamIdentifierSerOpt(); + retrievalFactory.createGetRecordsCache(shardInfo, streamConfig, mock(MetricsFactory.class)); + assertThat(retrievalFactory, not(nullValue())); + verify(consumerRegistration, never()).getOrCreateStreamConsumerArn(); + } + + @Test + public void testRegisterCalledWhenConsumerArnNotSetInMultiStreamMode() throws Exception { + FanOutConfig config = new TestingConfig(kinesisClient).applicationName(TEST_APPLICATION_NAME) + .streamName(TEST_STREAM_NAME); + RetrievalFactory retrievalFactory = config.retrievalFactory(); + ShardInfo shardInfo = mock(ShardInfo.class); + doReturn(Optional.of("account:stream:12345")).when(shardInfo).streamIdentifierSerOpt(); + retrievalFactory.createGetRecordsCache(shardInfo, streamConfig, mock(MetricsFactory.class)); assertThat(retrievalFactory, not(nullValue())); verify(consumerRegistration).getOrCreateStreamConsumerArn(); } @@ -94,7 +128,7 @@ public class FanOutConfigTest { RetrievalFactory factory = config.retrievalFactory(); ShardInfo shardInfo = mock(ShardInfo.class); doReturn(Optional.empty()).when(shardInfo).streamIdentifierSerOpt(); - factory.createGetRecordsCache(shardInfo, mock(MetricsFactory.class)); + factory.createGetRecordsCache(shardInfo, streamConfig, mock(MetricsFactory.class)); assertThat(factory, not(nullValue())); TestingConfig testingConfig = (TestingConfig) config; @@ -109,7 +143,7 @@ public class FanOutConfigTest { RetrievalFactory factory = config.retrievalFactory(); ShardInfo shardInfo = mock(ShardInfo.class); doReturn(Optional.empty()).when(shardInfo).streamIdentifierSerOpt(); - factory.createGetRecordsCache(shardInfo, mock(MetricsFactory.class)); + factory.createGetRecordsCache(shardInfo, streamConfig, mock(MetricsFactory.class)); assertThat(factory, not(nullValue())); TestingConfig testingConfig = (TestingConfig) config; assertThat(testingConfig.stream, equalTo(TEST_STREAM_NAME)); @@ -123,7 +157,7 @@ public class FanOutConfigTest { RetrievalFactory factory = config.retrievalFactory(); ShardInfo shardInfo = mock(ShardInfo.class); doReturn(Optional.empty()).when(shardInfo).streamIdentifierSerOpt(); - factory.createGetRecordsCache(shardInfo, mock(MetricsFactory.class)); + factory.createGetRecordsCache(shardInfo, streamConfig, mock(MetricsFactory.class)); assertThat(factory, not(nullValue())); TestingConfig testingConfig = (TestingConfig) config;