Adding dedicated consumerArn support for streams in multistreaming mode
This commit is contained in:
parent
113029e33c
commit
f69398a2b2
9 changed files with 83 additions and 29 deletions
|
|
@ -15,14 +15,15 @@
|
|||
|
||||
package software.amazon.kinesis.common;
|
||||
|
||||
import lombok.Value;
|
||||
import lombok.Data;
|
||||
import lombok.experimental.Accessors;
|
||||
|
||||
@Value
|
||||
@Data
|
||||
@Accessors(fluent = true)
|
||||
public class StreamConfig {
|
||||
StreamIdentifier streamIdentifier;
|
||||
InitialPositionInStreamExtended initialPositionInStreamExtended;
|
||||
private final StreamIdentifier streamIdentifier;
|
||||
private final InitialPositionInStreamExtended initialPositionInStreamExtended;
|
||||
private String consumerArn;
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -63,6 +63,7 @@ public class StreamIdentifier {
|
|||
|
||||
/**
|
||||
* Create a multi stream instance for StreamIdentifier from serialized stream identifier.
|
||||
* The serialized stream identifier should be of the format account:stream:creationepoch
|
||||
* @param streamIdentifierSer
|
||||
* @return StreamIdentifier
|
||||
*/
|
||||
|
|
|
|||
|
|
@ -891,7 +891,6 @@ public class Scheduler implements Runnable {
|
|||
|
||||
protected ShardConsumer buildConsumer(@NonNull final ShardInfo shardInfo,
|
||||
@NonNull final ShardRecordProcessorFactory shardRecordProcessorFactory) {
|
||||
RecordsPublisher cache = retrievalConfig.retrievalFactory().createGetRecordsCache(shardInfo, metricsFactory);
|
||||
ShardRecordProcessorCheckpointer checkpointer = coordinatorConfig.coordinatorFactory().createRecordProcessorCheckpointer(shardInfo,
|
||||
checkpoint);
|
||||
// The only case where streamName is not available will be when multistreamtracker not set. In this case,
|
||||
|
|
@ -902,6 +901,7 @@ public class Scheduler implements Runnable {
|
|||
// to gracefully complete the reading.
|
||||
final StreamConfig streamConfig = currentStreamConfigMap.getOrDefault(streamIdentifier, getDefaultStreamConfig(streamIdentifier));
|
||||
Validate.notNull(streamConfig, "StreamConfig should not be null");
|
||||
RecordsPublisher cache = retrievalConfig.retrievalFactory().createGetRecordsCache(shardInfo, streamConfig, metricsFactory);
|
||||
ShardConsumerArgument argument = new ShardConsumerArgument(shardInfo,
|
||||
streamConfig.streamIdentifier(),
|
||||
leaseCoordinator,
|
||||
|
|
|
|||
|
|
@ -121,6 +121,13 @@ public class RetrievalConfig {
|
|||
return this;
|
||||
}
|
||||
|
||||
public RetrievalConfig retrievalSpecificConfig(RetrievalSpecificConfig retrievalSpecificConfig) {
|
||||
this.retrievalSpecificConfig = retrievalSpecificConfig;
|
||||
validateFanoutConfig();
|
||||
validatePollingConfig();
|
||||
return this;
|
||||
}
|
||||
|
||||
public RetrievalFactory retrievalFactory() {
|
||||
if (retrievalFactory == null) {
|
||||
if (retrievalSpecificConfig == null) {
|
||||
|
|
@ -129,22 +136,36 @@ public class RetrievalConfig {
|
|||
retrievalSpecificConfig = appStreamTracker.map(multiStreamTracker -> retrievalSpecificConfig,
|
||||
streamConfig -> ((FanOutConfig) retrievalSpecificConfig).streamName(streamConfig.streamIdentifier().streamName()));
|
||||
}
|
||||
|
||||
retrievalFactory = retrievalSpecificConfig.retrievalFactory();
|
||||
}
|
||||
validateConfig();
|
||||
return retrievalFactory;
|
||||
}
|
||||
|
||||
private void validateConfig() {
|
||||
private void validateFanoutConfig() {
|
||||
// If we are in multistream mode and if retrievalSpecificConfig is an instance of FanOutConfig and if consumerArn is set throw exception.
|
||||
boolean isFanoutConfig = retrievalSpecificConfig instanceof FanOutConfig;
|
||||
boolean isInvalidFanoutConfig = isFanoutConfig && appStreamTracker.map(
|
||||
multiStreamTracker -> ((FanOutConfig) retrievalSpecificConfig).consumerArn() != null
|
||||
|| ((FanOutConfig) retrievalSpecificConfig).streamName() != null,
|
||||
streamConfig -> streamConfig.streamIdentifier() == null
|
||||
|| streamConfig.streamIdentifier().streamName() == null);
|
||||
if(isInvalidFanoutConfig) {
|
||||
throw new IllegalArgumentException(
|
||||
"Invalid config: Either in multi-stream mode with streamName/consumerArn configured or in single-stream mode with no streamName configured");
|
||||
}
|
||||
}
|
||||
|
||||
private void validatePollingConfig() {
|
||||
boolean isPollingConfig = retrievalSpecificConfig instanceof PollingConfig;
|
||||
boolean isInvalidPollingConfig = isPollingConfig && appStreamTracker.map(multiStreamTracker ->
|
||||
boolean isInvalidPollingConfig = isPollingConfig && appStreamTracker.map(
|
||||
multiStreamTracker ->
|
||||
((PollingConfig) retrievalSpecificConfig).streamName() != null,
|
||||
streamConfig ->
|
||||
streamConfig.streamIdentifier() == null || streamConfig.streamIdentifier().streamName() == null);
|
||||
|
||||
if(isInvalidPollingConfig) {
|
||||
throw new IllegalArgumentException("Invalid config: multistream enabled with streamName or single stream with no streamName");
|
||||
if (isInvalidPollingConfig) {
|
||||
throw new IllegalArgumentException(
|
||||
"Invalid config: Either in multi-stream mode with streamName configured or in single-stream mode with no streamName configured");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@
|
|||
|
||||
package software.amazon.kinesis.retrieval;
|
||||
|
||||
import software.amazon.kinesis.common.StreamConfig;
|
||||
import software.amazon.kinesis.leases.ShardInfo;
|
||||
import software.amazon.kinesis.metrics.MetricsFactory;
|
||||
|
||||
|
|
@ -24,5 +25,10 @@ import software.amazon.kinesis.metrics.MetricsFactory;
|
|||
public interface RetrievalFactory {
|
||||
GetRecordsRetrievalStrategy createGetRecordsRetrievalStrategy(ShardInfo shardInfo, MetricsFactory metricsFactory);
|
||||
|
||||
@Deprecated
|
||||
RecordsPublisher createGetRecordsCache(ShardInfo shardInfo, MetricsFactory metricsFactory);
|
||||
|
||||
default RecordsPublisher createGetRecordsCache(ShardInfo shardInfo, StreamConfig streamConfig, MetricsFactory metricsFactory) {
|
||||
return createGetRecordsCache(shardInfo, metricsFactory);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -80,17 +80,11 @@ public class FanOutConfig implements RetrievalSpecificConfig {
|
|||
*/
|
||||
private long retryBackoffMillis = 1000;
|
||||
|
||||
@Override
|
||||
public RetrievalFactory retrievalFactory() {
|
||||
return new FanOutRetrievalFactory(kinesisClient, streamName, this::getOrCreateConsumerArn);
|
||||
@Override public RetrievalFactory retrievalFactory() {
|
||||
return new FanOutRetrievalFactory(kinesisClient, streamName, consumerArn, this::getOrCreateConsumerArn);
|
||||
}
|
||||
|
||||
// TODO : LTR. Need Stream Specific ConsumerArn to be passed from Customer
|
||||
private String getOrCreateConsumerArn(String streamName) {
|
||||
if (consumerArn != null) {
|
||||
return consumerArn;
|
||||
}
|
||||
|
||||
FanOutConsumerRegistration registration = createConsumerRegistration(streamName);
|
||||
try {
|
||||
return registration.getOrCreateStreamConsumerArn();
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ import lombok.NonNull;
|
|||
import lombok.RequiredArgsConstructor;
|
||||
import software.amazon.awssdk.services.kinesis.KinesisAsyncClient;
|
||||
import software.amazon.kinesis.annotations.KinesisClientInternalApi;
|
||||
import software.amazon.kinesis.common.StreamConfig;
|
||||
import software.amazon.kinesis.common.StreamIdentifier;
|
||||
import software.amazon.kinesis.leases.ShardInfo;
|
||||
import software.amazon.kinesis.metrics.MetricsFactory;
|
||||
|
|
@ -37,8 +38,8 @@ public class FanOutRetrievalFactory implements RetrievalFactory {
|
|||
|
||||
private final KinesisAsyncClient kinesisClient;
|
||||
private final String defaultStreamName;
|
||||
private final Function<String, String> consumerArnProvider;
|
||||
private Map<String,String> streamToConsumerArnMap = new HashMap<>();
|
||||
private final String defaultConsumerName;
|
||||
private final Function<String, String> consumerArnCreator;
|
||||
|
||||
@Override
|
||||
public GetRecordsRetrievalStrategy createGetRecordsRetrievalStrategy(final ShardInfo shardInfo,
|
||||
|
|
@ -48,19 +49,27 @@ public class FanOutRetrievalFactory implements RetrievalFactory {
|
|||
|
||||
@Override
|
||||
public RecordsPublisher createGetRecordsCache(@NonNull final ShardInfo shardInfo,
|
||||
final StreamConfig streamConfig,
|
||||
final MetricsFactory metricsFactory) {
|
||||
final Optional<String> streamIdentifierStr = shardInfo.streamIdentifierSerOpt();
|
||||
final String streamName;
|
||||
if(streamIdentifierStr.isPresent()) {
|
||||
streamName = StreamIdentifier.multiStreamInstance(streamIdentifierStr.get()).streamName();
|
||||
return new FanOutRecordsPublisher(kinesisClient, shardInfo.shardId(),
|
||||
streamToConsumerArnMap.computeIfAbsent(streamName, consumerArnProvider::apply),
|
||||
getOrCreateConsumerArn(streamName, streamConfig.consumerArn()),
|
||||
streamIdentifierStr.get());
|
||||
} else {
|
||||
streamName = defaultStreamName;
|
||||
return new FanOutRecordsPublisher(kinesisClient, shardInfo.shardId(),
|
||||
streamToConsumerArnMap.computeIfAbsent(streamName, consumerArnProvider::apply));
|
||||
getOrCreateConsumerArn(defaultStreamName, defaultConsumerName));
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public RecordsPublisher createGetRecordsCache(ShardInfo shardInfo, MetricsFactory metricsFactory) {
|
||||
throw new UnsupportedOperationException("FanoutRetrievalFactory needs StreamConfig Info");
|
||||
}
|
||||
|
||||
private String getOrCreateConsumerArn(String streamName, String consumerArn) {
|
||||
return consumerArn != null ? consumerArn : consumerArnCreator.apply(streamName);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -191,7 +191,7 @@ public class SchedulerTest {
|
|||
when(leaseCoordinator.leaseRefresher()).thenReturn(dynamoDBLeaseRefresher);
|
||||
when(shardSyncTaskManager.shardDetector()).thenReturn(shardDetector);
|
||||
when(shardSyncTaskManager.callShardSyncTask()).thenReturn(new TaskResult(null));
|
||||
when(retrievalFactory.createGetRecordsCache(any(ShardInfo.class), any(MetricsFactory.class))).thenReturn(recordsPublisher);
|
||||
when(retrievalFactory.createGetRecordsCache(any(ShardInfo.class), any(StreamConfig.class), any(MetricsFactory.class))).thenReturn(recordsPublisher);
|
||||
when(shardDetector.streamIdentifier()).thenReturn(mock(StreamIdentifier.class));
|
||||
|
||||
scheduler = new Scheduler(checkpointConfig, coordinatorConfig, leaseManagementConfig, lifecycleConfig,
|
||||
|
|
|
|||
|
|
@ -25,12 +25,14 @@ import static org.mockito.Mockito.never;
|
|||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.runners.MockitoJUnitRunner;
|
||||
|
||||
import software.amazon.awssdk.services.kinesis.KinesisAsyncClient;
|
||||
import software.amazon.kinesis.common.StreamConfig;
|
||||
import software.amazon.kinesis.leases.ShardInfo;
|
||||
import software.amazon.kinesis.leases.exceptions.DependencyException;
|
||||
import software.amazon.kinesis.metrics.MetricsFactory;
|
||||
|
|
@ -50,6 +52,13 @@ public class FanOutConfigTest {
|
|||
private FanOutConsumerRegistration consumerRegistration;
|
||||
@Mock
|
||||
private KinesisAsyncClient kinesisClient;
|
||||
@Mock
|
||||
private StreamConfig streamConfig;
|
||||
|
||||
@Before
|
||||
public void setup() {
|
||||
when(streamConfig.consumerArn()).thenReturn(null);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNoRegisterIfConsumerArnSet() throws Exception {
|
||||
|
|
@ -68,11 +77,24 @@ public class FanOutConfigTest {
|
|||
ShardInfo shardInfo = mock(ShardInfo.class);
|
||||
// doReturn(Optional.of(StreamIdentifier.singleStreamInstance(TEST_STREAM_NAME).serialize())).when(shardInfo).streamIdentifier();
|
||||
doReturn(Optional.empty()).when(shardInfo).streamIdentifierSerOpt();
|
||||
retrievalFactory.createGetRecordsCache(shardInfo, mock(MetricsFactory.class));
|
||||
retrievalFactory.createGetRecordsCache(shardInfo, streamConfig, mock(MetricsFactory.class));
|
||||
assertThat(retrievalFactory, not(nullValue()));
|
||||
verify(consumerRegistration).getOrCreateStreamConsumerArn();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRegisterNotCalledWhenConsumerArnSetInMultiStreamMode() throws Exception {
|
||||
when(streamConfig.consumerArn()).thenReturn("consumerArn");
|
||||
FanOutConfig config = new TestingConfig(kinesisClient).applicationName(TEST_APPLICATION_NAME)
|
||||
.streamName(TEST_STREAM_NAME);
|
||||
RetrievalFactory retrievalFactory = config.retrievalFactory();
|
||||
ShardInfo shardInfo = mock(ShardInfo.class);
|
||||
doReturn(Optional.of("account:stream:12345")).when(shardInfo).streamIdentifierSerOpt();
|
||||
retrievalFactory.createGetRecordsCache(shardInfo, streamConfig, mock(MetricsFactory.class));
|
||||
assertThat(retrievalFactory, not(nullValue()));
|
||||
verify(consumerRegistration, never()).getOrCreateStreamConsumerArn();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDependencyExceptionInConsumerCreation() throws Exception {
|
||||
FanOutConfig config = new TestingConfig(kinesisClient).applicationName(TEST_APPLICATION_NAME)
|
||||
|
|
@ -94,7 +116,7 @@ public class FanOutConfigTest {
|
|||
RetrievalFactory factory = config.retrievalFactory();
|
||||
ShardInfo shardInfo = mock(ShardInfo.class);
|
||||
doReturn(Optional.empty()).when(shardInfo).streamIdentifierSerOpt();
|
||||
factory.createGetRecordsCache(shardInfo, mock(MetricsFactory.class));
|
||||
factory.createGetRecordsCache(shardInfo, streamConfig, mock(MetricsFactory.class));
|
||||
assertThat(factory, not(nullValue()));
|
||||
|
||||
TestingConfig testingConfig = (TestingConfig) config;
|
||||
|
|
@ -109,7 +131,7 @@ public class FanOutConfigTest {
|
|||
RetrievalFactory factory = config.retrievalFactory();
|
||||
ShardInfo shardInfo = mock(ShardInfo.class);
|
||||
doReturn(Optional.empty()).when(shardInfo).streamIdentifierSerOpt();
|
||||
factory.createGetRecordsCache(shardInfo, mock(MetricsFactory.class));
|
||||
factory.createGetRecordsCache(shardInfo, streamConfig, mock(MetricsFactory.class));
|
||||
assertThat(factory, not(nullValue()));
|
||||
TestingConfig testingConfig = (TestingConfig) config;
|
||||
assertThat(testingConfig.stream, equalTo(TEST_STREAM_NAME));
|
||||
|
|
@ -123,7 +145,7 @@ public class FanOutConfigTest {
|
|||
RetrievalFactory factory = config.retrievalFactory();
|
||||
ShardInfo shardInfo = mock(ShardInfo.class);
|
||||
doReturn(Optional.empty()).when(shardInfo).streamIdentifierSerOpt();
|
||||
factory.createGetRecordsCache(shardInfo, mock(MetricsFactory.class));
|
||||
factory.createGetRecordsCache(shardInfo, streamConfig, mock(MetricsFactory.class));
|
||||
assertThat(factory, not(nullValue()));
|
||||
|
||||
TestingConfig testingConfig = (TestingConfig) config;
|
||||
|
|
|
|||
Loading…
Reference in a new issue