Allow Configuring GetRecords Calls to Timeout. (#214)

It's now possible to configure GetRecords calls to timeout if they
take to long.  This can be used to terminate a long running request to
ensure that record processors continue to make progress

This feature was added with contributions from @pfifer, @sahilpalvia,
and @BtXin.
This commit is contained in:
Sahil Palvia 2017-09-18 10:53:08 -07:00 committed by Pfifer, Justin
parent 656b17ceaa
commit 244da44d29
14 changed files with 764 additions and 21 deletions

View file

@ -0,0 +1,126 @@
package com.amazonaws.services.kinesis.clientlibrary.lib.worker;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletionService;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorCompletionService;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import com.amazonaws.services.kinesis.metrics.impl.MetricsHelper;
import com.amazonaws.services.kinesis.metrics.impl.ThreadSafeMetricsDelegatingScope;
import com.amazonaws.services.kinesis.model.GetRecordsResult;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import lombok.NonNull;
import lombok.extern.apachecommons.CommonsLog;
/**
*
*/
@CommonsLog
public class AsynchronousGetRecordsRetrievalStrategy implements GetRecordsRetrievalStrategy {
private static final int TIME_TO_KEEP_ALIVE = 5;
private static final int CORE_THREAD_POOL_COUNT = 1;
private final KinesisDataFetcher dataFetcher;
private final ExecutorService executorService;
private final int retryGetRecordsInSeconds;
private final String shardId;
final CompletionService<GetRecordsResult> completionService;
public AsynchronousGetRecordsRetrievalStrategy(@NonNull final KinesisDataFetcher dataFetcher,
final int retryGetRecordsInSeconds, final int maxGetRecordsThreadPool, String shardId) {
this(dataFetcher, buildExector(maxGetRecordsThreadPool, shardId), retryGetRecordsInSeconds, shardId);
}
public AsynchronousGetRecordsRetrievalStrategy(final KinesisDataFetcher dataFetcher,
final ExecutorService executorService, final int retryGetRecordsInSeconds, String shardId) {
this(dataFetcher, executorService, retryGetRecordsInSeconds, new ExecutorCompletionService<>(executorService),
shardId);
}
AsynchronousGetRecordsRetrievalStrategy(KinesisDataFetcher dataFetcher, ExecutorService executorService,
int retryGetRecordsInSeconds, CompletionService<GetRecordsResult> completionService, String shardId) {
this.dataFetcher = dataFetcher;
this.executorService = executorService;
this.retryGetRecordsInSeconds = retryGetRecordsInSeconds;
this.completionService = completionService;
this.shardId = shardId;
}
@Override
public GetRecordsResult getRecords(final int maxRecords) {
if (executorService.isShutdown()) {
throw new IllegalStateException("Strategy has been shutdown");
}
GetRecordsResult result = null;
Set<Future<GetRecordsResult>> futures = new HashSet<>();
Callable<GetRecordsResult> retrieverCall = createRetrieverCallable(maxRecords);
while (true) {
try {
futures.add(completionService.submit(retrieverCall));
} catch (RejectedExecutionException e) {
log.warn("Out of resources, unable to start additional requests.");
}
try {
Future<GetRecordsResult> resultFuture = completionService.poll(retryGetRecordsInSeconds,
TimeUnit.SECONDS);
if (resultFuture != null) {
result = resultFuture.get();
break;
}
} catch (ExecutionException e) {
log.error("ExecutionException thrown while trying to get records", e);
} catch (InterruptedException e) {
log.error("Thread was interrupted", e);
break;
}
}
futures.stream().peek(f -> f.cancel(true)).filter(Future::isCancelled).forEach(f -> {
try {
completionService.take();
} catch (InterruptedException e) {
log.error("Exception thrown while trying to empty the threadpool.");
}
});
return result;
}
private Callable<GetRecordsResult> createRetrieverCallable(int maxRecords) {
ThreadSafeMetricsDelegatingScope metricsScope = new ThreadSafeMetricsDelegatingScope(MetricsHelper.getMetricsScope());
return () -> {
try {
MetricsHelper.setMetricsScope(metricsScope);
return dataFetcher.getRecords(maxRecords);
} finally {
MetricsHelper.unsetMetricsScope();
}
};
}
@Override
public void shutdown() {
executorService.shutdownNow();
}
@Override
public boolean isShutdown() {
return executorService.isShutdown();
}
private static ExecutorService buildExector(int maxGetRecordsThreadPool, String shardId) {
String threadNameFormat = "get-records-worker-" + shardId + "-%d";
return new ThreadPoolExecutor(CORE_THREAD_POOL_COUNT, maxGetRecordsThreadPool, TIME_TO_KEEP_ALIVE,
TimeUnit.SECONDS, new LinkedBlockingQueue<>(1),
new ThreadFactoryBuilder().setDaemon(true).setNameFormat(threadNameFormat).build(),
new ThreadPoolExecutor.AbortPolicy());
}
}

View file

@ -14,6 +14,8 @@
*/
package com.amazonaws.services.kinesis.clientlibrary.lib.worker;
import java.util.Optional;
/**
* Top level container for all the possible states a {@link ShardConsumer} can be in. The logic for creation of tasks,
* and state transitions is contained within the {@link ConsumerState} objects.
@ -309,7 +311,8 @@ class ConsumerStates {
public ITask createTask(ShardConsumer consumer) {
return new ProcessTask(consumer.getShardInfo(), consumer.getStreamConfig(), consumer.getRecordProcessor(),
consumer.getRecordProcessorCheckpointer(), consumer.getDataFetcher(),
consumer.getTaskBackoffTimeMillis(), consumer.isSkipShardSyncAtWorkerInitializationIfLeasesExist());
consumer.getTaskBackoffTimeMillis(), consumer.isSkipShardSyncAtWorkerInitializationIfLeasesExist(),
consumer.getRetryGetRecordsInSeconds(), consumer.getMaxGetRecordsThreadPool());
}
@Override

View file

@ -0,0 +1,33 @@
package com.amazonaws.services.kinesis.clientlibrary.lib.worker;
import com.amazonaws.services.kinesis.model.GetRecordsResult;
/**
* Represents a strategy to retrieve records from Kinesis. Allows for variations on how records are retrieved from
* Kinesis.
*/
public interface GetRecordsRetrievalStrategy {
/**
* Gets a set of records from Kinesis.
*
* @param maxRecords
* passed to Kinesis, and can be used to restrict the number of records returned from Kinesis.
* @return the resulting records.
* @throws IllegalStateException
* if the strategy has been shutdown.
*/
GetRecordsResult getRecords(int maxRecords);
/**
* Releases any resources used by the strategy. Once the strategy is shutdown it is no longer safe to call
* {@link #getRecords(int)}.
*/
void shutdown();
/**
* Returns whether this strategy has been shutdown.
*
* @return true if the strategy has been shutdown, false otherwise.
*/
boolean isShutdown();
}

View file

@ -217,6 +217,12 @@ public class KinesisClientLibConfiguration {
@Getter
private Optional<Integer> timeoutInSeconds = Optional.empty();
@Getter
private Optional<Integer> retryGetRecordsInSeconds = Optional.empty();
@Getter
private Optional<Integer> maxGetRecordsThreadPool = Optional.empty();
@Getter
private int maxLeaseRenewalThreads = DEFAULT_MAX_LEASE_RENEWAL_THREADS;
@ -1111,6 +1117,27 @@ public class KinesisClientLibConfiguration {
return this;
}
/**
* @param retryGetRecordsInSeconds the time in seconds to wait before the worker retries to get a record.
* @return this configuration object.
*/
public KinesisClientLibConfiguration withRetryGetRecordsInSeconds(final int retryGetRecordsInSeconds) {
checkIsValuePositive("retryGetRecordsInSeconds", retryGetRecordsInSeconds);
this.retryGetRecordsInSeconds = Optional.of(retryGetRecordsInSeconds);
return this;
}
/**
*@param maxGetRecordsThreadPool the max number of threads in the getRecords thread pool.
*@return this configuration object
*/
public KinesisClientLibConfiguration withMaxGetRecordsThreadPool(final int maxGetRecordsThreadPool) {
checkIsValuePositive("maxGetRecordsThreadPool", maxGetRecordsThreadPool);
this.maxGetRecordsThreadPool = Optional.of(maxGetRecordsThreadPool);
return this;
}
/**
* @param timeoutInSeconds The timeout in seconds to wait for the MultiLangProtocol to wait for
*/

View file

@ -18,6 +18,7 @@ import java.math.BigInteger;
import java.util.Collections;
import java.util.List;
import java.util.ListIterator;
import java.util.Optional;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
@ -62,6 +63,19 @@ class ProcessTask implements ITask {
private final Shard shard;
private final ThrottlingReporter throttlingReporter;
private final GetRecordsRetrievalStrategy getRecordsRetrievalStrategy;
private static final GetRecordsRetrievalStrategy makeStrategy(KinesisDataFetcher dataFetcher,
Optional<Integer> retryGetRecordsInSeconds,
Optional<Integer> maxGetRecordsThreadPool,
ShardInfo shardInfo) {
Optional<GetRecordsRetrievalStrategy> getRecordsRetrievalStrategy = retryGetRecordsInSeconds.flatMap(retry ->
maxGetRecordsThreadPool.map(max ->
new AsynchronousGetRecordsRetrievalStrategy(dataFetcher, retry, max, shardInfo.getShardId())));
return getRecordsRetrievalStrategy.orElse(new SynchronousGetRecordsRetrievalStrategy(dataFetcher));
}
/**
* @param shardInfo
* contains information about the shard
@ -77,11 +91,38 @@ class ProcessTask implements ITask {
* backoff time when catching exceptions
*/
public ProcessTask(ShardInfo shardInfo, StreamConfig streamConfig, IRecordProcessor recordProcessor,
RecordProcessorCheckpointer recordProcessorCheckpointer, KinesisDataFetcher dataFetcher,
long backoffTimeMillis, boolean skipShardSyncAtWorkerInitializationIfLeasesExist) {
RecordProcessorCheckpointer recordProcessorCheckpointer, KinesisDataFetcher dataFetcher,
long backoffTimeMillis, boolean skipShardSyncAtWorkerInitializationIfLeasesExist) {
this(shardInfo, streamConfig, recordProcessor, recordProcessorCheckpointer, dataFetcher, backoffTimeMillis,
skipShardSyncAtWorkerInitializationIfLeasesExist, Optional.empty(), Optional.empty());
}
/**
* @param shardInfo
* contains information about the shard
* @param streamConfig
* Stream configuration
* @param recordProcessor
* Record processor used to process the data records for the shard
* @param recordProcessorCheckpointer
* Passed to the RecordProcessor so it can checkpoint progress
* @param dataFetcher
* Kinesis data fetcher (used to fetch records from Kinesis)
* @param backoffTimeMillis
* backoff time when catching exceptions
* @param retryGetRecordsInSeconds
* time in seconds to wait before the worker retries to get a record.
* @param maxGetRecordsThreadPool
* max number of threads in the getRecords thread pool.
*/
public ProcessTask(ShardInfo shardInfo, StreamConfig streamConfig, IRecordProcessor recordProcessor,
RecordProcessorCheckpointer recordProcessorCheckpointer, KinesisDataFetcher dataFetcher,
long backoffTimeMillis, boolean skipShardSyncAtWorkerInitializationIfLeasesExist,
Optional<Integer> retryGetRecordsInSeconds, Optional<Integer> maxGetRecordsThreadPool) {
this(shardInfo, streamConfig, recordProcessor, recordProcessorCheckpointer, dataFetcher, backoffTimeMillis,
skipShardSyncAtWorkerInitializationIfLeasesExist,
new ThrottlingReporter(MAX_CONSECUTIVE_THROTTLES, shardInfo.getShardId()));
new ThrottlingReporter(MAX_CONSECUTIVE_THROTTLES, shardInfo.getShardId()),
makeStrategy(dataFetcher, retryGetRecordsInSeconds, maxGetRecordsThreadPool, shardInfo));
}
/**
@ -103,7 +144,7 @@ class ProcessTask implements ITask {
public ProcessTask(ShardInfo shardInfo, StreamConfig streamConfig, IRecordProcessor recordProcessor,
RecordProcessorCheckpointer recordProcessorCheckpointer, KinesisDataFetcher dataFetcher,
long backoffTimeMillis, boolean skipShardSyncAtWorkerInitializationIfLeasesExist,
ThrottlingReporter throttlingReporter) {
ThrottlingReporter throttlingReporter, GetRecordsRetrievalStrategy getRecordsRetrievalStrategy) {
super();
this.shardInfo = shardInfo;
this.recordProcessor = recordProcessor;
@ -113,6 +154,7 @@ class ProcessTask implements ITask {
this.backoffTimeMillis = backoffTimeMillis;
this.throttlingReporter = throttlingReporter;
IKinesisProxy kinesisProxy = this.streamConfig.getStreamProxy();
this.getRecordsRetrievalStrategy = getRecordsRetrievalStrategy;
// If skipShardSyncAtWorkerInitializationIfLeasesExist is set, we will not get the shard for
// this ProcessTask. In this case, duplicate KPL user records in the event of resharding will
// not be dropped during deaggregation of Amazon Kinesis records. This is only applicable if
@ -368,7 +410,7 @@ class ProcessTask implements ITask {
* @return list of data records from Kinesis
*/
private GetRecordsResult getRecordsResultAndRecordMillisBehindLatest() {
final GetRecordsResult getRecordsResult = dataFetcher.getRecords(streamConfig.getMaxRecords());
final GetRecordsResult getRecordsResult = getRecordsRetrievalStrategy.getRecords(streamConfig.getMaxRecords());
if (getRecordsResult == null) {
// Stream no longer exists

View file

@ -15,10 +15,12 @@
package com.amazonaws.services.kinesis.clientlibrary.lib.worker;
import java.util.Optional;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.RejectedExecutionException;
import lombok.Getter;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
@ -53,6 +55,10 @@ class ShardConsumer {
private final boolean cleanupLeasesOfCompletedShards;
private final long taskBackoffTimeMillis;
private final boolean skipShardSyncAtWorkerInitializationIfLeasesExist;
@Getter
private final Optional<Integer> retryGetRecordsInSeconds;
@Getter
private final Optional<Integer> maxGetRecordsThreadPool;
private ITask currentTask;
private long currentTaskSubmitTime;
@ -93,6 +99,38 @@ class ShardConsumer {
IMetricsFactory metricsFactory,
long backoffTimeMillis,
boolean skipShardSyncAtWorkerInitializationIfLeasesExist) {
this(shardInfo, streamConfig, checkpoint,recordProcessor, leaseManager, parentShardPollIntervalMillis,
cleanupLeasesOfCompletedShards, executorService, metricsFactory, backoffTimeMillis,
skipShardSyncAtWorkerInitializationIfLeasesExist, Optional.empty(), Optional.empty());
}
/**
* @param shardInfo Shard information
* @param streamConfig Stream configuration to use
* @param checkpoint Checkpoint tracker
* @param recordProcessor Record processor used to process the data records for the shard
* @param leaseManager Used to create leases for new shards
* @param parentShardPollIntervalMillis Wait for this long if parent shards are not done (or we get an exception)
* @param executorService ExecutorService used to execute process tasks for this shard
* @param metricsFactory IMetricsFactory used to construct IMetricsScopes for this shard
* @param backoffTimeMillis backoff interval when we encounter exceptions
* @param retryGetRecordsInSeconds time in seconds to wait before the worker retries to get a record.
* @param maxGetRecordsThreadPool max number of threads in the getRecords thread pool.
*/
// CHECKSTYLE:IGNORE ParameterNumber FOR NEXT 10 LINES
ShardConsumer(ShardInfo shardInfo,
StreamConfig streamConfig,
ICheckpoint checkpoint,
IRecordProcessor recordProcessor,
ILeaseManager<KinesisClientLease> leaseManager,
long parentShardPollIntervalMillis,
boolean cleanupLeasesOfCompletedShards,
ExecutorService executorService,
IMetricsFactory metricsFactory,
long backoffTimeMillis,
boolean skipShardSyncAtWorkerInitializationIfLeasesExist,
Optional<Integer> retryGetRecordsInSeconds,
Optional<Integer> maxGetRecordsThreadPool) {
this.streamConfig = streamConfig;
this.recordProcessor = recordProcessor;
this.executorService = executorService;
@ -111,6 +149,8 @@ class ShardConsumer {
this.cleanupLeasesOfCompletedShards = cleanupLeasesOfCompletedShards;
this.taskBackoffTimeMillis = backoffTimeMillis;
this.skipShardSyncAtWorkerInitializationIfLeasesExist = skipShardSyncAtWorkerInitializationIfLeasesExist;
this.retryGetRecordsInSeconds = retryGetRecordsInSeconds;
this.maxGetRecordsThreadPool = maxGetRecordsThreadPool;
}
/**

View file

@ -0,0 +1,31 @@
package com.amazonaws.services.kinesis.clientlibrary.lib.worker;
import com.amazonaws.services.kinesis.model.GetRecordsResult;
import lombok.Data;
import lombok.NonNull;
/**
*
*/
@Data
public class SynchronousGetRecordsRetrievalStrategy implements GetRecordsRetrievalStrategy {
@NonNull
private final KinesisDataFetcher dataFetcher;
@Override
public GetRecordsResult getRecords(final int maxRecords) {
return dataFetcher.getRecords(maxRecords);
}
@Override
public void shutdown() {
//
// Does nothing as this retriever doesn't manage any resources
//
}
@Override
public boolean isShutdown() {
return false;
}
}

View file

@ -17,6 +17,7 @@ package com.amazonaws.services.kinesis.clientlibrary.lib.worker;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap;
@ -85,6 +86,9 @@ public class Worker implements Runnable {
private final long taskBackoffTimeMillis;
private final long failoverTimeMillis;
private final Optional<Integer> retryGetRecordsInSeconds;
private final Optional<Integer> maxGetRecordsThreadPool;
// private final KinesisClientLeaseManager leaseManager;
private final KinesisClientLibLeaseCoordinator leaseCoordinator;
private final ShardSyncTaskManager controlServer;
@ -266,7 +270,9 @@ public class Worker implements Runnable {
config.getTaskBackoffTimeMillis(),
config.getFailoverTimeMillis(),
config.getSkipShardSyncAtWorkerInitializationIfLeasesExist(),
config.getShardPrioritizationStrategy());
config.getShardPrioritizationStrategy(),
config.getRetryGetRecordsInSeconds(),
config.getMaxGetRecordsThreadPool());
// If a region name was explicitly specified, use it as the region for Amazon Kinesis and Amazon DynamoDB.
if (config.getRegionName() != null) {
@ -333,6 +339,56 @@ public class Worker implements Runnable {
KinesisClientLibLeaseCoordinator leaseCoordinator, ExecutorService execService,
IMetricsFactory metricsFactory, long taskBackoffTimeMillis, long failoverTimeMillis,
boolean skipShardSyncAtWorkerInitializationIfLeasesExist, ShardPrioritization shardPrioritization) {
this(applicationName, recordProcessorFactory, streamConfig, initialPositionInStream, parentShardPollIntervalMillis,
shardSyncIdleTimeMillis, cleanupLeasesUponShardCompletion, checkpoint, leaseCoordinator, execService,
metricsFactory, taskBackoffTimeMillis, failoverTimeMillis, skipShardSyncAtWorkerInitializationIfLeasesExist,
shardPrioritization, Optional.empty(), Optional.empty());
}
/**
* @param applicationName
* Name of the Kinesis application
* @param recordProcessorFactory
* Used to get record processor instances for processing data from shards
* @param streamConfig
* Stream configuration
* @param initialPositionInStream
* One of LATEST, TRIM_HORIZON, or AT_TIMESTAMP. The KinesisClientLibrary will start fetching data from
* this location in the stream when an application starts up for the first time and there are no
* checkpoints. If there are checkpoints, we start from the checkpoint position.
* @param parentShardPollIntervalMillis
* Wait for this long between polls to check if parent shards are done
* @param shardSyncIdleTimeMillis
* Time between tasks to sync leases and Kinesis shards
* @param cleanupLeasesUponShardCompletion
* Clean up shards we've finished processing (don't wait till they expire in Kinesis)
* @param checkpoint
* Used to get/set checkpoints
* @param leaseCoordinator
* Lease coordinator (coordinates currently owned leases)
* @param execService
* ExecutorService to use for processing records (support for multi-threaded consumption)
* @param metricsFactory
* Metrics factory used to emit metrics
* @param taskBackoffTimeMillis
* Backoff period when tasks encounter an exception
* @param shardPrioritization
* Provides prioritization logic to decide which available shards process first
* @param retryGetRecordsInSeconds
* Time in seconds to wait before the worker retries to get a record.
* @param maxGetRecordsThreadPool
* Max number of threads in the getRecords thread pool.
*/
// NOTE: This has package level access solely for testing
// CHECKSTYLE:IGNORE ParameterNumber FOR NEXT 10 LINES
Worker(String applicationName, IRecordProcessorFactory recordProcessorFactory, StreamConfig streamConfig,
InitialPositionInStreamExtended initialPositionInStream, long parentShardPollIntervalMillis,
long shardSyncIdleTimeMillis, boolean cleanupLeasesUponShardCompletion, ICheckpoint checkpoint,
KinesisClientLibLeaseCoordinator leaseCoordinator, ExecutorService execService,
IMetricsFactory metricsFactory, long taskBackoffTimeMillis, long failoverTimeMillis,
boolean skipShardSyncAtWorkerInitializationIfLeasesExist, ShardPrioritization shardPrioritization,
Optional<Integer> retryGetRecordsInSeconds, Optional<Integer> maxGetRecordsThreadPool) {
this.applicationName = applicationName;
this.recordProcessorFactory = recordProcessorFactory;
this.streamConfig = streamConfig;
@ -351,8 +407,11 @@ public class Worker implements Runnable {
this.failoverTimeMillis = failoverTimeMillis;
this.skipShardSyncAtWorkerInitializationIfLeasesExist = skipShardSyncAtWorkerInitializationIfLeasesExist;
this.shardPrioritization = shardPrioritization;
this.retryGetRecordsInSeconds = retryGetRecordsInSeconds;
this.maxGetRecordsThreadPool = maxGetRecordsThreadPool;
}
/**
* @return the applicationName
*/
@ -786,7 +845,7 @@ public class Worker implements Runnable {
return new ShardConsumer(shardInfo, streamConfig, checkpointTracker, recordProcessor,
leaseCoordinator.getLeaseManager(), parentShardPollIntervalMillis, cleanupLeasesUponShardCompletion,
executorService, metricsFactory, taskBackoffTimeMillis,
skipShardSyncAtWorkerInitializationIfLeasesExist);
skipShardSyncAtWorkerInitializationIfLeasesExist, retryGetRecordsInSeconds, maxGetRecordsThreadPool);
}
@ -1213,7 +1272,9 @@ public class Worker implements Runnable {
config.getTaskBackoffTimeMillis(),
config.getFailoverTimeMillis(),
config.getSkipShardSyncAtWorkerInitializationIfLeasesExist(),
shardPrioritization);
shardPrioritization,
config.getRetryGetRecordsInSeconds(),
config.getMaxGetRecordsThreadPool());
}

View file

@ -22,6 +22,7 @@ import static org.junit.Assert.fail;
import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.util.Optional;
import java.util.Set;
import org.apache.commons.lang.StringUtils;
@ -60,6 +61,8 @@ public class KinesisClientLibConfiguratorTest {
assertEquals(config.getApplicationName(), "b");
assertEquals(config.getStreamName(), "a");
assertEquals(config.getWorkerIdentifier(), "123");
assertEquals(config.getMaxGetRecordsThreadPool(), Optional.empty());
assertEquals(config.getRetryGetRecordsInSeconds(), Optional.empty());
}
@Test
@ -107,7 +110,9 @@ public class KinesisClientLibConfiguratorTest {
"workerId = w123",
"maxRecords = 10",
"metricsMaxQueueSize = 20",
"applicationName = kinesis"
"applicationName = kinesis",
"retryGetRecordsInSeconds = 2",
"maxGetRecordsThreadPool = 1"
}, '\n'));
assertEquals(config.getApplicationName(), "kinesis");
@ -115,6 +120,8 @@ public class KinesisClientLibConfiguratorTest {
assertEquals(config.getWorkerIdentifier(), "w123");
assertEquals(config.getMaxRecords(), 10);
assertEquals(config.getMetricsMaxQueueSize(), 20);
assertEquals(config.getRetryGetRecordsInSeconds(), Optional.of(2));
assertEquals(config.getMaxGetRecordsThreadPool(), Optional.of(1));
}
@Test
@ -202,6 +209,42 @@ public class KinesisClientLibConfiguratorTest {
assertEquals(config.getInitialPositionInStream(), InitialPositionInStream.TRIM_HORIZON);
}
@Test
public void testEmptyOptionalVariables() {
KinesisClientLibConfiguration config =
getConfiguration(StringUtils.join(new String[] {
"streamName = a",
"applicationName = b",
"AWSCredentialsProvider = ABCD," + credentialName1,
"workerId = 123",
"initialPositionInStream = TriM_Horizon",
"maxGetRecordsThreadPool = 1"
}, '\n'));
assertEquals(config.getMaxGetRecordsThreadPool(), Optional.of(1));
assertEquals(config.getRetryGetRecordsInSeconds(), Optional.empty());
}
@Test
public void testWithZeroValue() {
String test = StringUtils.join(new String[]{
"streamName = a",
"applicationName = b",
"AWSCredentialsProvider = ABCD," + credentialName1,
"workerId = 123",
"initialPositionInStream = TriM_Horizon",
"maxGetRecordsThreadPool = 0",
"retryGetRecordsInSeconds = 0"
}, '\n');
InputStream input = new ByteArrayInputStream(test.getBytes());
try {
configurator.getConfiguration(input);
} catch (Exception e) {
fail("Don't expect to fail on invalid variable value");
}
}
@Test
public void testWithInvalidIntValue() {
String test = StringUtils.join(new String[] {

View file

@ -0,0 +1,156 @@
/*
* Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Amazon Software License (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/asl/
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/
package com.amazonaws.services.kinesis.clientlibrary.lib.worker;
import com.amazonaws.services.kinesis.clientlibrary.proxies.IKinesisProxy;
import com.amazonaws.services.kinesis.model.GetRecordsResult;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.junit.After;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.runners.MockitoJUnitRunner;
import java.util.concurrent.CompletionService;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.RejectedExecutionHandler;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.atLeast;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@RunWith(MockitoJUnitRunner.class)
public class AsynchronousGetRecordsRetrievalStrategyIntegrationTest {
private static final int CORE_POOL_SIZE = 1;
private static final int MAX_POOL_SIZE = 2;
private static final int TIME_TO_LIVE = 5;
private static final int RETRY_GET_RECORDS_IN_SECONDS = 2;
private static final int SLEEP_GET_RECORDS_IN_SECONDS = 10;
@Mock
private IKinesisProxy mockKinesisProxy;
@Mock
private ShardInfo mockShardInfo;
private AsynchronousGetRecordsRetrievalStrategy getRecordsRetrivalStrategy;
private KinesisDataFetcher dataFetcher;
private GetRecordsResult result;
private ExecutorService executorService;
private RejectedExecutionHandler rejectedExecutionHandler;
private int numberOfRecords = 10;
private CompletionService<GetRecordsResult> completionService;
@Before
public void setup() {
dataFetcher = spy(new KinesisDataFetcherForTests(mockKinesisProxy, mockShardInfo));
rejectedExecutionHandler = spy(new ThreadPoolExecutor.AbortPolicy());
executorService = spy(new ThreadPoolExecutor(
CORE_POOL_SIZE,
MAX_POOL_SIZE,
TIME_TO_LIVE,
TimeUnit.SECONDS,
new LinkedBlockingQueue<>(1),
new ThreadFactoryBuilder().setDaemon(true).setNameFormat("getrecords-worker-%d").build(),
rejectedExecutionHandler));
getRecordsRetrivalStrategy = new AsynchronousGetRecordsRetrievalStrategy(dataFetcher, executorService, RETRY_GET_RECORDS_IN_SECONDS, "shardId-0001");
completionService = spy(getRecordsRetrivalStrategy.completionService);
result = null;
}
@Test
public void oneRequestMultithreadTest() {
GetRecordsResult getRecordsResult = getRecordsRetrivalStrategy.getRecords(numberOfRecords);
verify(dataFetcher, atLeast(getLeastNumberOfCalls())).getRecords(eq(numberOfRecords));
verify(executorService, atLeast(getLeastNumberOfCalls())).execute(any());
assertNull(getRecordsResult);
}
@Test
public void multiRequestTest() {
result = mock(GetRecordsResult.class);
GetRecordsResult getRecordsResult = getRecordsRetrivalStrategy.getRecords(numberOfRecords);
verify(dataFetcher, atLeast(getLeastNumberOfCalls())).getRecords(numberOfRecords);
verify(executorService, atLeast(getLeastNumberOfCalls())).execute(any());
assertEquals(result, getRecordsResult);
result = null;
getRecordsResult = getRecordsRetrivalStrategy.getRecords(numberOfRecords);
assertNull(getRecordsResult);
}
@Test
@Ignore
public void testInterrupted() throws InterruptedException, ExecutionException {
Future<GetRecordsResult> mockFuture = mock(Future.class);
when(completionService.submit(any())).thenReturn(mockFuture);
when(completionService.poll()).thenReturn(mockFuture);
doThrow(InterruptedException.class).when(mockFuture).get();
GetRecordsResult getRecordsResult = getRecordsRetrivalStrategy.getRecords(numberOfRecords);
verify(mockFuture).get();
assertNull(getRecordsResult);
}
private int getLeastNumberOfCalls() {
int leastNumberOfCalls = 0;
for (int i = MAX_POOL_SIZE; i > 0; i--) {
if (i * RETRY_GET_RECORDS_IN_SECONDS <= SLEEP_GET_RECORDS_IN_SECONDS) {
leastNumberOfCalls = i;
break;
}
}
return leastNumberOfCalls;
}
@After
public void shutdown() {
getRecordsRetrivalStrategy.shutdown();
verify(executorService).shutdownNow();
}
private class KinesisDataFetcherForTests extends KinesisDataFetcher {
public KinesisDataFetcherForTests(final IKinesisProxy kinesisProxy, final ShardInfo shardInfo) {
super(kinesisProxy, shardInfo);
}
@Override
public GetRecordsResult getRecords(final int maxRecords) {
try {
Thread.sleep(SLEEP_GET_RECORDS_IN_SECONDS * 1000);
} catch (InterruptedException e) {
// Do nothing
}
return result;
}
}
}

View file

@ -0,0 +1,137 @@
package com.amazonaws.services.kinesis.clientlibrary.lib.worker;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyBoolean;
import static org.mockito.Matchers.anyLong;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import java.util.concurrent.CompletionService;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.TimeUnit;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.runners.MockitoJUnitRunner;
import com.amazonaws.services.kinesis.model.GetRecordsResult;
/**
*
*/
@RunWith(MockitoJUnitRunner.class)
public class AsynchronousGetRecordsRetrievalStrategyTest {
private static final long RETRY_GET_RECORDS_IN_SECONDS = 5;
private static final String SHARD_ID = "ShardId-0001";
@Mock
private KinesisDataFetcher dataFetcher;
@Mock
private ExecutorService executorService;
@Mock
private CompletionService<GetRecordsResult> completionService;
@Mock
private Future<GetRecordsResult> successfulFuture;
@Mock
private Future<GetRecordsResult> blockedFuture;
@Mock
private GetRecordsResult expectedResults;
@Test
public void testSingleSuccessfulRequestFuture() throws Exception {
AsynchronousGetRecordsRetrievalStrategy strategy = new AsynchronousGetRecordsRetrievalStrategy(dataFetcher,
executorService, (int) RETRY_GET_RECORDS_IN_SECONDS, completionService, SHARD_ID);
when(executorService.isShutdown()).thenReturn(false);
when(completionService.submit(any())).thenReturn(successfulFuture);
when(completionService.poll(anyLong(), any())).thenReturn(successfulFuture);
when(successfulFuture.get()).thenReturn(expectedResults);
GetRecordsResult result = strategy.getRecords(10);
verify(executorService).isShutdown();
verify(completionService).submit(any());
verify(completionService).poll(eq(RETRY_GET_RECORDS_IN_SECONDS), eq(TimeUnit.SECONDS));
verify(successfulFuture).get();
verify(successfulFuture).cancel(eq(true));
verify(successfulFuture).isCancelled();
verify(completionService, never()).take();
assertThat(result, equalTo(expectedResults));
}
@Test
public void testBlockedAndSuccessfulFuture() throws Exception {
AsynchronousGetRecordsRetrievalStrategy strategy = new AsynchronousGetRecordsRetrievalStrategy(dataFetcher,
executorService, (int) RETRY_GET_RECORDS_IN_SECONDS, completionService, SHARD_ID);
when(executorService.isShutdown()).thenReturn(false);
when(completionService.submit(any())).thenReturn(blockedFuture).thenReturn(successfulFuture);
when(completionService.poll(anyLong(), any())).thenReturn(null).thenReturn(successfulFuture);
when(successfulFuture.get()).thenReturn(expectedResults);
when(successfulFuture.cancel(anyBoolean())).thenReturn(false);
when(blockedFuture.cancel(anyBoolean())).thenReturn(true);
when(successfulFuture.isCancelled()).thenReturn(false);
when(blockedFuture.isCancelled()).thenReturn(true);
GetRecordsResult actualResults = strategy.getRecords(10);
verify(completionService, times(2)).submit(any());
verify(completionService, times(2)).poll(eq(RETRY_GET_RECORDS_IN_SECONDS), eq(TimeUnit.SECONDS));
verify(successfulFuture).get();
verify(blockedFuture, never()).get();
verify(successfulFuture).cancel(eq(true));
verify(blockedFuture).cancel(eq(true));
verify(successfulFuture).isCancelled();
verify(blockedFuture).isCancelled();
verify(completionService).take();
assertThat(actualResults, equalTo(expectedResults));
}
@Test(expected = IllegalStateException.class)
public void testStrategyIsShutdown() throws Exception {
AsynchronousGetRecordsRetrievalStrategy strategy = new AsynchronousGetRecordsRetrievalStrategy(dataFetcher,
executorService, (int) RETRY_GET_RECORDS_IN_SECONDS, completionService, SHARD_ID);
when(executorService.isShutdown()).thenReturn(true);
strategy.getRecords(10);
}
@Test
public void testPoolOutOfResources() throws Exception {
AsynchronousGetRecordsRetrievalStrategy strategy = new AsynchronousGetRecordsRetrievalStrategy(dataFetcher,
executorService, (int) RETRY_GET_RECORDS_IN_SECONDS, completionService, SHARD_ID);
when(executorService.isShutdown()).thenReturn(false);
when(completionService.submit(any())).thenReturn(blockedFuture).thenThrow(new RejectedExecutionException("Rejected!")).thenReturn(successfulFuture);
when(completionService.poll(anyLong(), any())).thenReturn(null).thenReturn(null).thenReturn(successfulFuture);
when(successfulFuture.get()).thenReturn(expectedResults);
when(successfulFuture.cancel(anyBoolean())).thenReturn(false);
when(blockedFuture.cancel(anyBoolean())).thenReturn(true);
when(successfulFuture.isCancelled()).thenReturn(false);
when(blockedFuture.isCancelled()).thenReturn(true);
GetRecordsResult actualResult = strategy.getRecords(10);
verify(completionService, times(3)).submit(any());
verify(completionService, times(3)).poll(eq(RETRY_GET_RECORDS_IN_SECONDS), eq(TimeUnit.SECONDS));
verify(successfulFuture).cancel(eq(true));
verify(blockedFuture).cancel(eq(true));
verify(successfulFuture).isCancelled();
verify(blockedFuture).isCancelled();
verify(completionService).take();
assertThat(actualResult, equalTo(expectedResults));
}
}

View file

@ -17,6 +17,7 @@ package com.amazonaws.services.kinesis.clientlibrary.lib.worker;
import static com.amazonaws.services.kinesis.clientlibrary.lib.worker.ConsumerStates.ConsumerState;
import static com.amazonaws.services.kinesis.clientlibrary.lib.worker.ConsumerStates.ShardConsumerState;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.CoreMatchers.nullValue;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.mockito.Mockito.never;
@ -25,6 +26,7 @@ import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import java.lang.reflect.Field;
import java.util.Optional;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
@ -152,7 +154,10 @@ public class ConsumerStatesTest {
}
@Test
public void processingStateTest() {
public void processingStateTestSynchronous() {
when(consumer.getMaxGetRecordsThreadPool()).thenReturn(Optional.empty());
when(consumer.getRetryGetRecordsInSeconds()).thenReturn(Optional.empty());
ConsumerState state = ShardConsumerState.PROCESSING.getConsumerState();
ITask task = state.createTask(consumer);
@ -163,6 +168,38 @@ public class ConsumerStatesTest {
assertThat(task, procTask(KinesisDataFetcher.class, "dataFetcher", equalTo(dataFetcher)));
assertThat(task, procTask(StreamConfig.class, "streamConfig", equalTo(streamConfig)));
assertThat(task, procTask(Long.class, "backoffTimeMillis", equalTo(taskBackoffTimeMillis)));
assertThat(task, procTask(GetRecordsRetrievalStrategy.class, "getRecordsRetrievalStrategy", instanceOf(SynchronousGetRecordsRetrievalStrategy.class) ));
assertThat(state.successTransition(), equalTo(ShardConsumerState.PROCESSING.getConsumerState()));
assertThat(state.shutdownTransition(ShutdownReason.ZOMBIE),
equalTo(ShardConsumerState.SHUTTING_DOWN.getConsumerState()));
assertThat(state.shutdownTransition(ShutdownReason.TERMINATE),
equalTo(ShardConsumerState.SHUTTING_DOWN.getConsumerState()));
assertThat(state.shutdownTransition(ShutdownReason.REQUESTED),
equalTo(ShardConsumerState.SHUTDOWN_REQUESTED.getConsumerState()));
assertThat(state.getState(), equalTo(ShardConsumerState.PROCESSING));
assertThat(state.getTaskType(), equalTo(TaskType.PROCESS));
}
@Test
public void processingStateTestAsynchronous() {
when(consumer.getMaxGetRecordsThreadPool()).thenReturn(Optional.of(1));
when(consumer.getRetryGetRecordsInSeconds()).thenReturn(Optional.of(2));
ConsumerState state = ShardConsumerState.PROCESSING.getConsumerState();
ITask task = state.createTask(consumer);
assertThat(task, procTask(ShardInfo.class, "shardInfo", equalTo(shardInfo)));
assertThat(task, procTask(IRecordProcessor.class, "recordProcessor", equalTo(recordProcessor)));
assertThat(task, procTask(RecordProcessorCheckpointer.class, "recordProcessorCheckpointer",
equalTo(recordProcessorCheckpointer)));
assertThat(task, procTask(KinesisDataFetcher.class, "dataFetcher", equalTo(dataFetcher)));
assertThat(task, procTask(StreamConfig.class, "streamConfig", equalTo(streamConfig)));
assertThat(task, procTask(Long.class, "backoffTimeMillis", equalTo(taskBackoffTimeMillis)));
assertThat(task, procTask(GetRecordsRetrievalStrategy.class, "getRecordsRetrievalStrategy", instanceOf(AsynchronousGetRecordsRetrievalStrategy.class) ));
assertThat(state.successTransition(), equalTo(ShardConsumerState.PROCESSING.getConsumerState()));

View file

@ -117,6 +117,7 @@ public class KinesisDataFetcherTest {
ICheckpoint checkpoint = mock(ICheckpoint.class);
KinesisDataFetcher fetcher = new KinesisDataFetcher(kinesis, SHARD_INFO);
GetRecordsRetrievalStrategy getRecordsRetrievalStrategy = new SynchronousGetRecordsRetrievalStrategy(fetcher);
String iteratorA = "foo";
String iteratorB = "bar";
@ -138,10 +139,10 @@ public class KinesisDataFetcherTest {
fetcher.initialize(seqA, null);
fetcher.advanceIteratorTo(seqA, null);
Assert.assertEquals(recordsA, fetcher.getRecords(MAX_RECORDS).getRecords());
Assert.assertEquals(recordsA, getRecordsRetrievalStrategy.getRecords(MAX_RECORDS).getRecords());
fetcher.advanceIteratorTo(seqB, null);
Assert.assertEquals(recordsB, fetcher.getRecords(MAX_RECORDS).getRecords());
Assert.assertEquals(recordsB, getRecordsRetrievalStrategy.getRecords(MAX_RECORDS).getRecords());
}
@Test
@ -181,8 +182,9 @@ public class KinesisDataFetcherTest {
// Create data fectcher and initialize it with latest type checkpoint
KinesisDataFetcher dataFetcher = new KinesisDataFetcher(mockProxy, SHARD_INFO);
dataFetcher.initialize(SentinelCheckpoint.LATEST.toString(), INITIAL_POSITION_LATEST);
GetRecordsRetrievalStrategy getRecordsRetrievalStrategy = new SynchronousGetRecordsRetrievalStrategy(dataFetcher);
// Call getRecords of dataFetcher which will throw an exception
dataFetcher.getRecords(maxRecords);
getRecordsRetrievalStrategy.getRecords(maxRecords);
// Test shard has reached the end
Assert.assertTrue("Shard should reach the end", dataFetcher.isShardEndReached());
@ -206,8 +208,9 @@ public class KinesisDataFetcherTest {
when(checkpoint.getCheckpoint(SHARD_ID)).thenReturn(new ExtendedSequenceNumber(seqNo));
KinesisDataFetcher fetcher = new KinesisDataFetcher(kinesis, SHARD_INFO);
GetRecordsRetrievalStrategy getRecordsRetrievalStrategy = new SynchronousGetRecordsRetrievalStrategy(fetcher);
fetcher.initialize(seqNo, initialPositionInStream);
List<Record> actualRecords = fetcher.getRecords(MAX_RECORDS).getRecords();
List<Record> actualRecords = getRecordsRetrievalStrategy.getRecords(MAX_RECORDS).getRecords();
Assert.assertEquals(expectedRecords, actualRecords);
}

View file

@ -19,7 +19,7 @@ import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.anyInt;
import static org.mockito.Matchers.anyString;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.never;
@ -76,6 +76,8 @@ public class ProcessTaskTest {
private @Mock RecordProcessorCheckpointer mockCheckpointer;
@Mock
private ThrottlingReporter throttlingReporter;
@Mock
private GetRecordsRetrievalStrategy mockGetRecordsRetrievalStrategy;
private List<Record> processedRecords;
private ExtendedSequenceNumber newLargestPermittedCheckpointValue;
@ -94,19 +96,20 @@ public class ProcessTaskTest {
final ShardInfo shardInfo = new ShardInfo(shardId, null, null, null);
processTask = new ProcessTask(
shardInfo, config, mockRecordProcessor, mockCheckpointer, mockDataFetcher, taskBackoffTimeMillis,
KinesisClientLibConfiguration.DEFAULT_SKIP_SHARD_SYNC_AT_STARTUP_IF_LEASES_EXIST, throttlingReporter);
KinesisClientLibConfiguration.DEFAULT_SKIP_SHARD_SYNC_AT_STARTUP_IF_LEASES_EXIST, throttlingReporter, mockGetRecordsRetrievalStrategy);
}
@Test
public void testProcessTaskWithProvisionedThroughputExceededException() {
// Set data fetcher to throw exception
doReturn(false).when(mockDataFetcher).isShardEndReached();
doThrow(new ProvisionedThroughputExceededException("Test Exception")).when(mockDataFetcher)
doThrow(new ProvisionedThroughputExceededException("Test Exception")).when(mockGetRecordsRetrievalStrategy)
.getRecords(maxRecords);
TaskResult result = processTask.call();
verify(throttlingReporter).throttled();
verify(throttlingReporter, never()).success();
verify(mockGetRecordsRetrievalStrategy).getRecords(eq(maxRecords));
assertTrue("Result should contain ProvisionedThroughputExceededException",
result.getException() instanceof ProvisionedThroughputExceededException);
}
@ -114,9 +117,10 @@ public class ProcessTaskTest {
@Test
public void testProcessTaskWithNonExistentStream() {
// Data fetcher returns a null Result when the stream does not exist
doReturn(null).when(mockDataFetcher).getRecords(maxRecords);
doReturn(null).when(mockGetRecordsRetrievalStrategy).getRecords(maxRecords);
TaskResult result = processTask.call();
verify(mockGetRecordsRetrievalStrategy).getRecords(eq(maxRecords));
assertNull("Task should not throw an exception", result.getException());
}
@ -300,14 +304,14 @@ public class ProcessTaskTest {
private void testWithRecords(List<Record> records,
ExtendedSequenceNumber lastCheckpointValue,
ExtendedSequenceNumber largestPermittedCheckpointValue) {
when(mockDataFetcher.getRecords(anyInt())).thenReturn(
when(mockGetRecordsRetrievalStrategy.getRecords(anyInt())).thenReturn(
new GetRecordsResult().withRecords(records));
when(mockCheckpointer.getLastCheckpointValue()).thenReturn(lastCheckpointValue);
when(mockCheckpointer.getLargestPermittedCheckpointValue()).thenReturn(largestPermittedCheckpointValue);
processTask.call();
verify(throttlingReporter).success();
verify(throttlingReporter, never()).throttled();
verify(mockGetRecordsRetrievalStrategy).getRecords(anyInt());
ArgumentCaptor<ProcessRecordsInput> priCaptor = ArgumentCaptor.forClass(ProcessRecordsInput.class);
verify(mockRecordProcessor).processRecords(priCaptor.capture());
processedRecords = priCaptor.getValue().getRecords();