Adding dedicated consumerArn support for streams in multistreaming mode

This commit is contained in:
Ashwin Giridharan 2020-06-03 00:48:11 -07:00
parent 113029e33c
commit f69398a2b2
9 changed files with 83 additions and 29 deletions

View file

@ -15,14 +15,15 @@
package software.amazon.kinesis.common; package software.amazon.kinesis.common;
import lombok.Value; import lombok.Data;
import lombok.experimental.Accessors; import lombok.experimental.Accessors;
@Value @Data
@Accessors(fluent = true) @Accessors(fluent = true)
public class StreamConfig { public class StreamConfig {
StreamIdentifier streamIdentifier; private final StreamIdentifier streamIdentifier;
InitialPositionInStreamExtended initialPositionInStreamExtended; private final InitialPositionInStreamExtended initialPositionInStreamExtended;
private String consumerArn;
} }

View file

@ -63,6 +63,7 @@ public class StreamIdentifier {
/** /**
* Create a multi stream instance for StreamIdentifier from serialized stream identifier. * Create a multi stream instance for StreamIdentifier from serialized stream identifier.
* The serialized stream identifier should be of the format account:stream:creationepoch
* @param streamIdentifierSer * @param streamIdentifierSer
* @return StreamIdentifier * @return StreamIdentifier
*/ */

View file

@ -891,7 +891,6 @@ public class Scheduler implements Runnable {
protected ShardConsumer buildConsumer(@NonNull final ShardInfo shardInfo, protected ShardConsumer buildConsumer(@NonNull final ShardInfo shardInfo,
@NonNull final ShardRecordProcessorFactory shardRecordProcessorFactory) { @NonNull final ShardRecordProcessorFactory shardRecordProcessorFactory) {
RecordsPublisher cache = retrievalConfig.retrievalFactory().createGetRecordsCache(shardInfo, metricsFactory);
ShardRecordProcessorCheckpointer checkpointer = coordinatorConfig.coordinatorFactory().createRecordProcessorCheckpointer(shardInfo, ShardRecordProcessorCheckpointer checkpointer = coordinatorConfig.coordinatorFactory().createRecordProcessorCheckpointer(shardInfo,
checkpoint); checkpoint);
// The only case where streamName is not available will be when multistreamtracker not set. In this case, // The only case where streamName is not available will be when multistreamtracker not set. In this case,
@ -902,6 +901,7 @@ public class Scheduler implements Runnable {
// to gracefully complete the reading. // to gracefully complete the reading.
final StreamConfig streamConfig = currentStreamConfigMap.getOrDefault(streamIdentifier, getDefaultStreamConfig(streamIdentifier)); final StreamConfig streamConfig = currentStreamConfigMap.getOrDefault(streamIdentifier, getDefaultStreamConfig(streamIdentifier));
Validate.notNull(streamConfig, "StreamConfig should not be null"); Validate.notNull(streamConfig, "StreamConfig should not be null");
RecordsPublisher cache = retrievalConfig.retrievalFactory().createGetRecordsCache(shardInfo, streamConfig, metricsFactory);
ShardConsumerArgument argument = new ShardConsumerArgument(shardInfo, ShardConsumerArgument argument = new ShardConsumerArgument(shardInfo,
streamConfig.streamIdentifier(), streamConfig.streamIdentifier(),
leaseCoordinator, leaseCoordinator,

View file

@ -121,6 +121,13 @@ public class RetrievalConfig {
return this; return this;
} }
public RetrievalConfig retrievalSpecificConfig(RetrievalSpecificConfig retrievalSpecificConfig) {
this.retrievalSpecificConfig = retrievalSpecificConfig;
validateFanoutConfig();
validatePollingConfig();
return this;
}
public RetrievalFactory retrievalFactory() { public RetrievalFactory retrievalFactory() {
if (retrievalFactory == null) { if (retrievalFactory == null) {
if (retrievalSpecificConfig == null) { if (retrievalSpecificConfig == null) {
@ -129,22 +136,36 @@ public class RetrievalConfig {
retrievalSpecificConfig = appStreamTracker.map(multiStreamTracker -> retrievalSpecificConfig, retrievalSpecificConfig = appStreamTracker.map(multiStreamTracker -> retrievalSpecificConfig,
streamConfig -> ((FanOutConfig) retrievalSpecificConfig).streamName(streamConfig.streamIdentifier().streamName())); streamConfig -> ((FanOutConfig) retrievalSpecificConfig).streamName(streamConfig.streamIdentifier().streamName()));
} }
retrievalFactory = retrievalSpecificConfig.retrievalFactory(); retrievalFactory = retrievalSpecificConfig.retrievalFactory();
} }
validateConfig();
return retrievalFactory; return retrievalFactory;
} }
private void validateConfig() { private void validateFanoutConfig() {
// If we are in multistream mode and if retrievalSpecificConfig is an instance of FanOutConfig and if consumerArn is set throw exception.
boolean isFanoutConfig = retrievalSpecificConfig instanceof FanOutConfig;
boolean isInvalidFanoutConfig = isFanoutConfig && appStreamTracker.map(
multiStreamTracker -> ((FanOutConfig) retrievalSpecificConfig).consumerArn() != null
|| ((FanOutConfig) retrievalSpecificConfig).streamName() != null,
streamConfig -> streamConfig.streamIdentifier() == null
|| streamConfig.streamIdentifier().streamName() == null);
if(isInvalidFanoutConfig) {
throw new IllegalArgumentException(
"Invalid config: Either in multi-stream mode with streamName/consumerArn configured or in single-stream mode with no streamName configured");
}
}
private void validatePollingConfig() {
boolean isPollingConfig = retrievalSpecificConfig instanceof PollingConfig; boolean isPollingConfig = retrievalSpecificConfig instanceof PollingConfig;
boolean isInvalidPollingConfig = isPollingConfig && appStreamTracker.map(multiStreamTracker -> boolean isInvalidPollingConfig = isPollingConfig && appStreamTracker.map(
multiStreamTracker ->
((PollingConfig) retrievalSpecificConfig).streamName() != null, ((PollingConfig) retrievalSpecificConfig).streamName() != null,
streamConfig -> streamConfig ->
streamConfig.streamIdentifier() == null || streamConfig.streamIdentifier().streamName() == null); streamConfig.streamIdentifier() == null || streamConfig.streamIdentifier().streamName() == null);
if(isInvalidPollingConfig) { if (isInvalidPollingConfig) {
throw new IllegalArgumentException("Invalid config: multistream enabled with streamName or single stream with no streamName"); throw new IllegalArgumentException(
"Invalid config: Either in multi-stream mode with streamName configured or in single-stream mode with no streamName configured");
} }
} }
} }

View file

@ -15,6 +15,7 @@
package software.amazon.kinesis.retrieval; package software.amazon.kinesis.retrieval;
import software.amazon.kinesis.common.StreamConfig;
import software.amazon.kinesis.leases.ShardInfo; import software.amazon.kinesis.leases.ShardInfo;
import software.amazon.kinesis.metrics.MetricsFactory; import software.amazon.kinesis.metrics.MetricsFactory;
@ -24,5 +25,10 @@ import software.amazon.kinesis.metrics.MetricsFactory;
public interface RetrievalFactory { public interface RetrievalFactory {
GetRecordsRetrievalStrategy createGetRecordsRetrievalStrategy(ShardInfo shardInfo, MetricsFactory metricsFactory); GetRecordsRetrievalStrategy createGetRecordsRetrievalStrategy(ShardInfo shardInfo, MetricsFactory metricsFactory);
@Deprecated
RecordsPublisher createGetRecordsCache(ShardInfo shardInfo, MetricsFactory metricsFactory); RecordsPublisher createGetRecordsCache(ShardInfo shardInfo, MetricsFactory metricsFactory);
default RecordsPublisher createGetRecordsCache(ShardInfo shardInfo, StreamConfig streamConfig, MetricsFactory metricsFactory) {
return createGetRecordsCache(shardInfo, metricsFactory);
}
} }

View file

@ -80,17 +80,11 @@ public class FanOutConfig implements RetrievalSpecificConfig {
*/ */
private long retryBackoffMillis = 1000; private long retryBackoffMillis = 1000;
@Override @Override public RetrievalFactory retrievalFactory() {
public RetrievalFactory retrievalFactory() { return new FanOutRetrievalFactory(kinesisClient, streamName, consumerArn, this::getOrCreateConsumerArn);
return new FanOutRetrievalFactory(kinesisClient, streamName, this::getOrCreateConsumerArn);
} }
// TODO : LTR. Need Stream Specific ConsumerArn to be passed from Customer
private String getOrCreateConsumerArn(String streamName) { private String getOrCreateConsumerArn(String streamName) {
if (consumerArn != null) {
return consumerArn;
}
FanOutConsumerRegistration registration = createConsumerRegistration(streamName); FanOutConsumerRegistration registration = createConsumerRegistration(streamName);
try { try {
return registration.getOrCreateStreamConsumerArn(); return registration.getOrCreateStreamConsumerArn();

View file

@ -19,6 +19,7 @@ import lombok.NonNull;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import software.amazon.awssdk.services.kinesis.KinesisAsyncClient; import software.amazon.awssdk.services.kinesis.KinesisAsyncClient;
import software.amazon.kinesis.annotations.KinesisClientInternalApi; import software.amazon.kinesis.annotations.KinesisClientInternalApi;
import software.amazon.kinesis.common.StreamConfig;
import software.amazon.kinesis.common.StreamIdentifier; import software.amazon.kinesis.common.StreamIdentifier;
import software.amazon.kinesis.leases.ShardInfo; import software.amazon.kinesis.leases.ShardInfo;
import software.amazon.kinesis.metrics.MetricsFactory; import software.amazon.kinesis.metrics.MetricsFactory;
@ -37,8 +38,8 @@ public class FanOutRetrievalFactory implements RetrievalFactory {
private final KinesisAsyncClient kinesisClient; private final KinesisAsyncClient kinesisClient;
private final String defaultStreamName; private final String defaultStreamName;
private final Function<String, String> consumerArnProvider; private final String defaultConsumerName;
private Map<String,String> streamToConsumerArnMap = new HashMap<>(); private final Function<String, String> consumerArnCreator;
@Override @Override
public GetRecordsRetrievalStrategy createGetRecordsRetrievalStrategy(final ShardInfo shardInfo, public GetRecordsRetrievalStrategy createGetRecordsRetrievalStrategy(final ShardInfo shardInfo,
@ -48,19 +49,27 @@ public class FanOutRetrievalFactory implements RetrievalFactory {
@Override @Override
public RecordsPublisher createGetRecordsCache(@NonNull final ShardInfo shardInfo, public RecordsPublisher createGetRecordsCache(@NonNull final ShardInfo shardInfo,
final StreamConfig streamConfig,
final MetricsFactory metricsFactory) { final MetricsFactory metricsFactory) {
final Optional<String> streamIdentifierStr = shardInfo.streamIdentifierSerOpt(); final Optional<String> streamIdentifierStr = shardInfo.streamIdentifierSerOpt();
final String streamName; final String streamName;
if(streamIdentifierStr.isPresent()) { if(streamIdentifierStr.isPresent()) {
streamName = StreamIdentifier.multiStreamInstance(streamIdentifierStr.get()).streamName(); streamName = StreamIdentifier.multiStreamInstance(streamIdentifierStr.get()).streamName();
return new FanOutRecordsPublisher(kinesisClient, shardInfo.shardId(), return new FanOutRecordsPublisher(kinesisClient, shardInfo.shardId(),
streamToConsumerArnMap.computeIfAbsent(streamName, consumerArnProvider::apply), getOrCreateConsumerArn(streamName, streamConfig.consumerArn()),
streamIdentifierStr.get()); streamIdentifierStr.get());
} else { } else {
streamName = defaultStreamName;
return new FanOutRecordsPublisher(kinesisClient, shardInfo.shardId(), return new FanOutRecordsPublisher(kinesisClient, shardInfo.shardId(),
streamToConsumerArnMap.computeIfAbsent(streamName, consumerArnProvider::apply)); getOrCreateConsumerArn(defaultStreamName, defaultConsumerName));
}
} }
@Override
public RecordsPublisher createGetRecordsCache(ShardInfo shardInfo, MetricsFactory metricsFactory) {
throw new UnsupportedOperationException("FanoutRetrievalFactory needs StreamConfig Info");
}
private String getOrCreateConsumerArn(String streamName, String consumerArn) {
return consumerArn != null ? consumerArn : consumerArnCreator.apply(streamName);
} }
} }

View file

@ -191,7 +191,7 @@ public class SchedulerTest {
when(leaseCoordinator.leaseRefresher()).thenReturn(dynamoDBLeaseRefresher); when(leaseCoordinator.leaseRefresher()).thenReturn(dynamoDBLeaseRefresher);
when(shardSyncTaskManager.shardDetector()).thenReturn(shardDetector); when(shardSyncTaskManager.shardDetector()).thenReturn(shardDetector);
when(shardSyncTaskManager.callShardSyncTask()).thenReturn(new TaskResult(null)); when(shardSyncTaskManager.callShardSyncTask()).thenReturn(new TaskResult(null));
when(retrievalFactory.createGetRecordsCache(any(ShardInfo.class), any(MetricsFactory.class))).thenReturn(recordsPublisher); when(retrievalFactory.createGetRecordsCache(any(ShardInfo.class), any(StreamConfig.class), any(MetricsFactory.class))).thenReturn(recordsPublisher);
when(shardDetector.streamIdentifier()).thenReturn(mock(StreamIdentifier.class)); when(shardDetector.streamIdentifier()).thenReturn(mock(StreamIdentifier.class));
scheduler = new Scheduler(checkpointConfig, coordinatorConfig, leaseManagementConfig, lifecycleConfig, scheduler = new Scheduler(checkpointConfig, coordinatorConfig, leaseManagementConfig, lifecycleConfig,

View file

@ -25,12 +25,14 @@ import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.runners.MockitoJUnitRunner; import org.mockito.runners.MockitoJUnitRunner;
import software.amazon.awssdk.services.kinesis.KinesisAsyncClient; import software.amazon.awssdk.services.kinesis.KinesisAsyncClient;
import software.amazon.kinesis.common.StreamConfig;
import software.amazon.kinesis.leases.ShardInfo; import software.amazon.kinesis.leases.ShardInfo;
import software.amazon.kinesis.leases.exceptions.DependencyException; import software.amazon.kinesis.leases.exceptions.DependencyException;
import software.amazon.kinesis.metrics.MetricsFactory; import software.amazon.kinesis.metrics.MetricsFactory;
@ -50,6 +52,13 @@ public class FanOutConfigTest {
private FanOutConsumerRegistration consumerRegistration; private FanOutConsumerRegistration consumerRegistration;
@Mock @Mock
private KinesisAsyncClient kinesisClient; private KinesisAsyncClient kinesisClient;
@Mock
private StreamConfig streamConfig;
@Before
public void setup() {
when(streamConfig.consumerArn()).thenReturn(null);
}
@Test @Test
public void testNoRegisterIfConsumerArnSet() throws Exception { public void testNoRegisterIfConsumerArnSet() throws Exception {
@ -68,11 +77,24 @@ public class FanOutConfigTest {
ShardInfo shardInfo = mock(ShardInfo.class); ShardInfo shardInfo = mock(ShardInfo.class);
// doReturn(Optional.of(StreamIdentifier.singleStreamInstance(TEST_STREAM_NAME).serialize())).when(shardInfo).streamIdentifier(); // doReturn(Optional.of(StreamIdentifier.singleStreamInstance(TEST_STREAM_NAME).serialize())).when(shardInfo).streamIdentifier();
doReturn(Optional.empty()).when(shardInfo).streamIdentifierSerOpt(); doReturn(Optional.empty()).when(shardInfo).streamIdentifierSerOpt();
retrievalFactory.createGetRecordsCache(shardInfo, mock(MetricsFactory.class)); retrievalFactory.createGetRecordsCache(shardInfo, streamConfig, mock(MetricsFactory.class));
assertThat(retrievalFactory, not(nullValue())); assertThat(retrievalFactory, not(nullValue()));
verify(consumerRegistration).getOrCreateStreamConsumerArn(); verify(consumerRegistration).getOrCreateStreamConsumerArn();
} }
@Test
public void testRegisterNotCalledWhenConsumerArnSetInMultiStreamMode() throws Exception {
when(streamConfig.consumerArn()).thenReturn("consumerArn");
FanOutConfig config = new TestingConfig(kinesisClient).applicationName(TEST_APPLICATION_NAME)
.streamName(TEST_STREAM_NAME);
RetrievalFactory retrievalFactory = config.retrievalFactory();
ShardInfo shardInfo = mock(ShardInfo.class);
doReturn(Optional.of("account:stream:12345")).when(shardInfo).streamIdentifierSerOpt();
retrievalFactory.createGetRecordsCache(shardInfo, streamConfig, mock(MetricsFactory.class));
assertThat(retrievalFactory, not(nullValue()));
verify(consumerRegistration, never()).getOrCreateStreamConsumerArn();
}
@Test @Test
public void testDependencyExceptionInConsumerCreation() throws Exception { public void testDependencyExceptionInConsumerCreation() throws Exception {
FanOutConfig config = new TestingConfig(kinesisClient).applicationName(TEST_APPLICATION_NAME) FanOutConfig config = new TestingConfig(kinesisClient).applicationName(TEST_APPLICATION_NAME)
@ -94,7 +116,7 @@ public class FanOutConfigTest {
RetrievalFactory factory = config.retrievalFactory(); RetrievalFactory factory = config.retrievalFactory();
ShardInfo shardInfo = mock(ShardInfo.class); ShardInfo shardInfo = mock(ShardInfo.class);
doReturn(Optional.empty()).when(shardInfo).streamIdentifierSerOpt(); doReturn(Optional.empty()).when(shardInfo).streamIdentifierSerOpt();
factory.createGetRecordsCache(shardInfo, mock(MetricsFactory.class)); factory.createGetRecordsCache(shardInfo, streamConfig, mock(MetricsFactory.class));
assertThat(factory, not(nullValue())); assertThat(factory, not(nullValue()));
TestingConfig testingConfig = (TestingConfig) config; TestingConfig testingConfig = (TestingConfig) config;
@ -109,7 +131,7 @@ public class FanOutConfigTest {
RetrievalFactory factory = config.retrievalFactory(); RetrievalFactory factory = config.retrievalFactory();
ShardInfo shardInfo = mock(ShardInfo.class); ShardInfo shardInfo = mock(ShardInfo.class);
doReturn(Optional.empty()).when(shardInfo).streamIdentifierSerOpt(); doReturn(Optional.empty()).when(shardInfo).streamIdentifierSerOpt();
factory.createGetRecordsCache(shardInfo, mock(MetricsFactory.class)); factory.createGetRecordsCache(shardInfo, streamConfig, mock(MetricsFactory.class));
assertThat(factory, not(nullValue())); assertThat(factory, not(nullValue()));
TestingConfig testingConfig = (TestingConfig) config; TestingConfig testingConfig = (TestingConfig) config;
assertThat(testingConfig.stream, equalTo(TEST_STREAM_NAME)); assertThat(testingConfig.stream, equalTo(TEST_STREAM_NAME));
@ -123,7 +145,7 @@ public class FanOutConfigTest {
RetrievalFactory factory = config.retrievalFactory(); RetrievalFactory factory = config.retrievalFactory();
ShardInfo shardInfo = mock(ShardInfo.class); ShardInfo shardInfo = mock(ShardInfo.class);
doReturn(Optional.empty()).when(shardInfo).streamIdentifierSerOpt(); doReturn(Optional.empty()).when(shardInfo).streamIdentifierSerOpt();
factory.createGetRecordsCache(shardInfo, mock(MetricsFactory.class)); factory.createGetRecordsCache(shardInfo, streamConfig, mock(MetricsFactory.class));
assertThat(factory, not(nullValue())); assertThat(factory, not(nullValue()));
TestingConfig testingConfig = (TestingConfig) config; TestingConfig testingConfig = (TestingConfig) config;