Capability to restart the subscription from the last processed batch (#492)
* Started to play around with restarting from the last processed batch. After a failure the KCL will now restart from what the ShardConsumer says it last processed.
* Extracted the InternalSubscriber to its own class, ShardConsumerSubscriber, to make testing easier. Added tests for the ShardConsumerSubscriber that verify error handling and other components of the class, and tests that verify the restart-from behavior.
* Moved the ProcessRecordsInputMatcher to its own class.
* Initial changes for restarting of the PrefetchRecordsPublisher.
* Removed the code coverage configuration.
* Switched to using explicit locks to deal with the blocking queue. When the blocking queue was full, the enqueueing thread would enter a fully parked state while continuing to hold the lock. The process now blocks for only a second when attempting to enqueue a response and, if that doesn't succeed, checks whether it has been reset before attempting again.
* Changed the locking around the restart, and how the fetcher gets updated. The restart now uses a reader/writer lock instead of a single lock with a yield. The fetcher is no longer reset through advanceIteratorTo, which would retrieve a new shard iterator; instead the resetIterator method takes the iterator to start from, the last accepted sequence number, and the initial position.
* Changed a test to ensure that PositionResetException is thrown. The test now waits for the queue to reach capacity before restarting the PrefetchRecordsPublisher, which should mostly ensure that calling restartFrom triggers a PositionResetException. Added @VisibleForTesting on the queue since it was already being used in testing.
* Moved to snapshot.
* Ensured that only one thread can be sending data at a time. In the test the TestPublisher is accessed from two threads, the test thread and the dispatch thread, and both may call send() under certain conditions. This change allows only one of the threads to actively send data at a time. TestPublisher#requested was changed to volatile to ensure that calling cancel can correctly set it to zero.
* Blocked the test until the blocking thread is in control. This test is somewhat of an odd case: it tests what happens when nothing is dispatched to the ShardConsumerSubscriber for some amount of time while data is queued for dispatch. To do this we block the executor's single thread with a lock so that items pile up in the queue; if the restart were to work incorrectly we would see lost data.
parent b2751f09d5
commit fd0cb8e98f
14 changed files with 1098 additions and 207 deletions
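In outline, the change works like this: the subscriber remembers the last batch it successfully handed to the consumer, and any restart re-subscribes the publisher from that batch instead of from the last checkpoint. A minimal sketch of the pattern (a hypothetical, simplified class; the real implementation is ShardConsumerSubscriber in the diff below):

import software.amazon.kinesis.retrieval.RecordsPublisher;
import software.amazon.kinesis.retrieval.RecordsRetrieved;

// Hedged sketch of restart-from-last-accepted; not the actual KCL class.
class RestartableSubscriberSketch {
    private final Object lockObject = new Object();
    private final RecordsPublisher recordsPublisher;
    private RecordsRetrieved lastAccepted; // last batch delivered to the consumer

    RestartableSubscriberSketch(RecordsPublisher recordsPublisher) {
        this.recordsPublisher = recordsPublisher;
    }

    void onNext(RecordsRetrieved batch) {
        // ... dispatch batch.processRecordsInput() to the record processor ...
        synchronized (lockObject) {
            lastAccepted = batch; // remember what was actually delivered
        }
    }

    void startSubscriptions() {
        synchronized (lockObject) {
            if (lastAccepted != null) {
                // Resume from the last delivered batch rather than re-reading
                // from the last checkpointed position.
                recordsPublisher.restartFrom(lastAccepted);
            }
            // ... re-subscribe this subscriber to recordsPublisher ...
        }
    }
}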
ShardConsumer.java
@@ -20,17 +20,12 @@ import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;

import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;

import com.google.common.annotations.VisibleForTesting;

import io.reactivex.Flowable;
import io.reactivex.Scheduler;
import io.reactivex.schedulers.Schedulers;
import lombok.AccessLevel;
import lombok.Getter;
import lombok.NonNull;
@@ -44,7 +39,6 @@ import software.amazon.kinesis.lifecycle.events.TaskExecutionListenerInput;
import software.amazon.kinesis.metrics.MetricsCollectingTaskDecorator;
import software.amazon.kinesis.metrics.MetricsFactory;
import software.amazon.kinesis.retrieval.RecordsPublisher;
import software.amazon.kinesis.retrieval.RetryableRetrievalException;

/**
 * Responsible for consuming data records of a (specified) shard.

@@ -60,7 +54,6 @@ public class ShardConsumer {
    public static final int MAX_TIME_BETWEEN_REQUEST_RESPONSE = 35000;
    private final RecordsPublisher recordsPublisher;
    private final ExecutorService executorService;
    private final Scheduler scheduler;
    private final ShardInfo shardInfo;
    private final ShardConsumerArgument shardConsumerArgument;
    @NonNull
@@ -72,9 +65,6 @@ public class ShardConsumer {
    private ConsumerTask currentTask;
    private TaskOutcome taskOutcome;

    private final AtomicReference<Throwable> processFailure = new AtomicReference<>(null);
    private final AtomicReference<Throwable> dispatchFailure = new AtomicReference<>(null);

    private CompletableFuture<Boolean> stateChangeFuture;
    private boolean needsInitialization = true;

@@ -94,7 +84,7 @@ public class ShardConsumer {
    private volatile ShutdownReason shutdownReason;
    private volatile ShutdownNotification shutdownNotification;

    private final InternalSubscriber subscriber;
    private final ShardConsumerSubscriber subscriber;

    public ShardConsumer(RecordsPublisher recordsPublisher, ExecutorService executorService, ShardInfo shardInfo,
            Optional<Long> logWarningForTaskAfterMillis, ShardConsumerArgument shardConsumerArgument,

@@ -119,8 +109,7 @@
        this.taskExecutionListener = taskExecutionListener;
        this.currentState = initialState;
        this.taskMetricsDecorator = taskMetricsDecorator;
        scheduler = Schedulers.from(executorService);
        subscriber = new InternalSubscriber();
        subscriber = new ShardConsumerSubscriber(recordsPublisher, executorService, bufferSize, this);
        this.bufferSize = bufferSize;

        if (this.shardInfo.isCompleted()) {
@@ -128,64 +117,8 @@
        }
    }

    private void startSubscriptions() {
        Flowable.fromPublisher(recordsPublisher).subscribeOn(scheduler).observeOn(scheduler, true, bufferSize)
                .subscribe(subscriber);
    }

    private final Object lockObject = new Object();
    private Instant lastRequestTime = null;

    private class InternalSubscriber implements Subscriber<ProcessRecordsInput> {

        private Subscription subscription;
        private volatile Instant lastDataArrival;

        @Override
        public void onSubscribe(Subscription s) {
            subscription = s;
            subscription.request(1);
        }

        @Override
        public void onNext(ProcessRecordsInput input) {
            try {
                synchronized (lockObject) {
                    lastRequestTime = null;
                }
                lastDataArrival = Instant.now();
                handleInput(input.toBuilder().cacheExitTime(Instant.now()).build(), subscription);
            } catch (Throwable t) {
                log.warn("{}: Caught exception from handleInput", shardInfo.shardId(), t);
                dispatchFailure.set(t);
            } finally {
                subscription.request(1);
                synchronized (lockObject) {
                    lastRequestTime = Instant.now();
                }
            }
        }

        @Override
        public void onError(Throwable t) {
            log.warn("{}: onError(). Cancelling subscription, and marking self as failed.", shardInfo.shardId(), t);
            subscription.cancel();
            processFailure.set(t);
        }

        @Override
        public void onComplete() {
            log.debug("{}: onComplete(): Received onComplete. Activity should be triggered externally", shardInfo.shardId());
        }

        public void cancel() {
            if (subscription != null) {
                subscription.cancel();
            }
        }
    }

    private synchronized void handleInput(ProcessRecordsInput input, Subscription subscription) {
    synchronized void handleInput(ProcessRecordsInput input, Subscription subscription) {
        if (isShutdownRequested()) {
            subscription.cancel();
            return;
@@ -240,50 +173,15 @@
    Throwable healthCheck() {
        logNoDataRetrievedAfterTime();
        logLongRunningTask();
        Throwable failure = processFailure.get();
        if (!processFailure.compareAndSet(failure, null) && failure != null) {
            log.error("{}: processFailure was updated while resetting, this shouldn't happen. " +
                    "Will retry on next health check", shardInfo.shardId());
            return null;
        }
        Throwable failure = subscriber.healthCheck(MAX_TIME_BETWEEN_REQUEST_RESPONSE);

        if (failure != null) {
            String logMessage = String.format("%s: Failure occurred in retrieval. Restarting data requests", shardInfo.shardId());
            if (failure instanceof RetryableRetrievalException) {
                log.debug(logMessage, failure.getCause());
            } else {
                log.warn(logMessage, failure);
            }
            startSubscriptions();
            return failure;
        }
        Throwable expectedDispatchFailure = dispatchFailure.get();
        if (expectedDispatchFailure != null) {
            if (!dispatchFailure.compareAndSet(expectedDispatchFailure, null)) {
                log.info("{}: Unable to reset the dispatch failure, this can happen if the record processor is failing aggressively.", shardInfo.shardId());
                return null;
            }
            log.warn("Exception occurred while dispatching incoming data. The incoming data has been skipped", expectedDispatchFailure);
            return expectedDispatchFailure;
        }
        synchronized (lockObject) {
            if (lastRequestTime != null) {
                Instant now = Instant.now();
                Duration timeSinceLastResponse = Duration.between(lastRequestTime, now);
                if (timeSinceLastResponse.toMillis() > MAX_TIME_BETWEEN_REQUEST_RESPONSE) {
                    log.error(
                            "{}: Last request was dispatched at {}, but no response as of {} ({}). Cancelling subscription, and restarting.",
                            shardInfo.shardId(), lastRequestTime, now, timeSinceLastResponse);
                    if (subscriber != null) {
                        subscriber.cancel();
                    }
                    //
                    // Set the last request time to now, we specifically don't null it out since we want it to trigger a
                    // restart if the subscription still doesn't start producing.
                    //
                    lastRequestTime = Instant.now();
                    startSubscriptions();
                }
            }
        }
        Throwable dispatchFailure = subscriber.getAndResetDispatchFailure();
        if (dispatchFailure != null) {
            log.warn("Exception occurred while dispatching incoming data. The incoming data has been skipped", dispatchFailure);
            return dispatchFailure;
        }

        return null;
@@ -306,10 +204,10 @@

    private void logNoDataRetrievedAfterTime() {
        logWarningForTaskAfterMillis.ifPresent(value -> {
            Instant lastDataArrival = subscriber.lastDataArrival;
            Instant lastDataArrival = subscriber.lastDataArrival();
            if (lastDataArrival != null) {
                Instant now = Instant.now();
                Duration timeSince = Duration.between(subscriber.lastDataArrival, now);
                Duration timeSince = Duration.between(subscriber.lastDataArrival(), now);
                if (timeSince.toMillis() > value) {
                    log.warn("Last time data arrived: {} ({})", lastDataArrival, timeSince);
                }
@@ -335,7 +233,7 @@

    @VisibleForTesting
    void subscribe() {
        startSubscriptions();
        subscriber.startSubscriptions();
    }

    @VisibleForTesting
ShardConsumerSubscriber.java (new file)
@@ -0,0 +1,183 @@
/*
 * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Amazon Software License (the "License").
 * You may not use this file except in compliance with the License.
 * A copy of the License is located at
 *
 *     http://aws.amazon.com/asl/
 *
 * or in the "license" file accompanying this file. This file is distributed
 * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
 * express or implied. See the License for the specific language governing
 * permissions and limitations under the License.
 */
package software.amazon.kinesis.lifecycle;

import com.google.common.annotations.VisibleForTesting;
import io.reactivex.Flowable;
import io.reactivex.Scheduler;
import io.reactivex.schedulers.Schedulers;
import lombok.AccessLevel;
import lombok.Getter;
import lombok.experimental.Accessors;
import lombok.extern.slf4j.Slf4j;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import software.amazon.kinesis.retrieval.RecordsPublisher;
import software.amazon.kinesis.retrieval.RecordsRetrieved;
import software.amazon.kinesis.retrieval.RetryableRetrievalException;

import java.time.Duration;
import java.time.Instant;
import java.util.concurrent.ExecutorService;

@Slf4j
@Accessors(fluent = true)
class ShardConsumerSubscriber implements Subscriber<RecordsRetrieved> {

    private final RecordsPublisher recordsPublisher;
    private final Scheduler scheduler;
    private final int bufferSize;
    private final ShardConsumer shardConsumer;

    @VisibleForTesting
    final Object lockObject = new Object();
    private Instant lastRequestTime = null;
    private RecordsRetrieved lastAccepted = null;

    private Subscription subscription;
    @Getter
    private volatile Instant lastDataArrival;
    @Getter
    private volatile Throwable dispatchFailure;
    @Getter(AccessLevel.PACKAGE)
    private volatile Throwable retrievalFailure;

    ShardConsumerSubscriber(RecordsPublisher recordsPublisher, ExecutorService executorService, int bufferSize,
            ShardConsumer shardConsumer) {
        this.recordsPublisher = recordsPublisher;
        this.scheduler = Schedulers.from(executorService);
        this.bufferSize = bufferSize;
        this.shardConsumer = shardConsumer;
    }

    void startSubscriptions() {
        synchronized (lockObject) {
            if (lastAccepted != null) {
                recordsPublisher.restartFrom(lastAccepted);
            }
            Flowable.fromPublisher(recordsPublisher).subscribeOn(scheduler).observeOn(scheduler, true, bufferSize)
                    .subscribe(this);
        }
    }

    Throwable healthCheck(long maxTimeBetweenRequests) {
        Throwable result = restartIfFailed();
        if (result == null) {
            restartIfRequestTimerExpired(maxTimeBetweenRequests);
        }
        return result;
    }

    Throwable getAndResetDispatchFailure() {
        synchronized (lockObject) {
            Throwable failure = dispatchFailure;
            dispatchFailure = null;
            return failure;
        }
    }

    private Throwable restartIfFailed() {
        Throwable oldFailure = null;
        if (retrievalFailure != null) {
            synchronized (lockObject) {
                String logMessage = String.format("%s: Failure occurred in retrieval. Restarting data requests", shardConsumer.shardInfo().shardId());
                if (retrievalFailure instanceof RetryableRetrievalException) {
                    log.debug(logMessage, retrievalFailure.getCause());
                } else {
                    log.warn(logMessage, retrievalFailure);
                }
                oldFailure = retrievalFailure;
                retrievalFailure = null;
            }
            startSubscriptions();
        }

        return oldFailure;
    }

    private void restartIfRequestTimerExpired(long maxTimeBetweenRequests) {
        synchronized (lockObject) {
            if (lastRequestTime != null) {
                Instant now = Instant.now();
                Duration timeSinceLastResponse = Duration.between(lastRequestTime, now);
                if (timeSinceLastResponse.toMillis() > maxTimeBetweenRequests) {
                    log.error(
                            "{}: Last request was dispatched at {}, but no response as of {} ({}). Cancelling subscription, and restarting.",
                            shardConsumer.shardInfo().shardId(), lastRequestTime, now, timeSinceLastResponse);
                    cancel();
                    //
                    // Set the last request time to now; we specifically don't null it out since we want it to
                    // trigger a restart if the subscription still doesn't start producing.
                    //
                    lastRequestTime = Instant.now();
                    startSubscriptions();
                }
            }
        }
    }

    @Override
    public void onSubscribe(Subscription s) {
        subscription = s;
        subscription.request(1);
    }

    @Override
    public void onNext(RecordsRetrieved input) {
        try {
            synchronized (lockObject) {
                lastRequestTime = null;
            }
            lastDataArrival = Instant.now();
            shardConsumer.handleInput(input.processRecordsInput().toBuilder().cacheExitTime(Instant.now()).build(),
                    subscription);

        } catch (Throwable t) {
            log.warn("{}: Caught exception from handleInput", shardConsumer.shardInfo().shardId(), t);
            synchronized (lockObject) {
                dispatchFailure = t;
            }
        } finally {
            subscription.request(1);
            synchronized (lockObject) {
                lastAccepted = input;
                lastRequestTime = Instant.now();
            }
        }
    }

    @Override
    public void onError(Throwable t) {
        synchronized (lockObject) {
            log.warn("{}: onError(). Cancelling subscription, and marking self as failed.",
                    shardConsumer.shardInfo().shardId(), t);
            subscription.cancel();
            retrievalFailure = t;
        }
    }

    @Override
    public void onComplete() {
        log.debug("{}: onComplete(): Received onComplete. Activity should be triggered externally",
                shardConsumer.shardInfo().shardId());
    }

    public void cancel() {
        if (subscription != null) {
            subscription.cancel();
        }
    }
}
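A hedged sketch of how the health check above might be driven (the scheduling here is hypothetical, assuming a ShardConsumerSubscriber instance named subscriber and an slf4j log; in the real KCL, ShardConsumer#healthCheck is invoked from its own lifecycle machinery):

import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;

// Hypothetical driver: poll healthCheck periodically. A non-null result means
// retrieval failed and the subscription was restarted from lastAccepted.
ScheduledExecutorService healthCheckExecutor = Executors.newSingleThreadScheduledExecutor();
healthCheckExecutor.scheduleAtFixedRate(() -> {
    Throwable failure = subscriber.healthCheck(ShardConsumer.MAX_TIME_BETWEEN_REQUEST_RESPONSE);
    if (failure != null) {
        log.warn("Subscription restarted after retrieval failure", failure);
    }
}, 1, 1, TimeUnit.SECONDS);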
RecordsPublisher.java
@@ -24,7 +24,7 @@ import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber;
/**
 * Provides a record publisher that will retrieve records from Kinesis for processing
 */
public interface RecordsPublisher extends Publisher<ProcessRecordsInput> {
public interface RecordsPublisher extends Publisher<RecordsRetrieved> {
    /**
     * Initializes the publisher with where to start processing. If there is a stored sequence number the publisher will
     * begin from that sequence number, otherwise it will use the initial position.

@@ -35,6 +35,12 @@ public interface RecordsPublisher extends Publisher<ProcessRecordsInput> {
     *            if there is no sequence number the initial position to use
     */
    void start(ExtendedSequenceNumber extendedSequenceNumber, InitialPositionInStreamExtended initialPositionInStreamExtended);

    /**
     * Restarts the publisher from the last accepted and processed record batch.
     * @param recordsRetrieved the processRecordsInput to restart from
     */
    void restartFrom(RecordsRetrieved recordsRetrieved);


    /**
RecordsRetrieved.java (new file)
@@ -0,0 +1,27 @@
/*
 * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Amazon Software License (the "License").
 * You may not use this file except in compliance with the License.
 * A copy of the License is located at
 *
 *     http://aws.amazon.com/asl/
 *
 * or in the "license" file accompanying this file. This file is distributed
 * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
 * express or implied. See the License for the specific language governing
 * permissions and limitations under the License.
 */
package software.amazon.kinesis.retrieval;

import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput;

public interface RecordsRetrieved {

    /**
     * Retrieves the records that have been received via one of the publishers
     *
     * @return the processRecordsInput received
     */
    ProcessRecordsInput processRecordsInput();
}
FanOutRecordsPublisher.java
@@ -24,8 +24,10 @@ import java.util.stream.Collectors;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;

import lombok.Data;
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import lombok.experimental.Accessors;
import lombok.extern.slf4j.Slf4j;
import software.amazon.awssdk.core.async.SdkPublisher;
import software.amazon.awssdk.services.kinesis.KinesisAsyncClient;

@@ -42,6 +44,7 @@ import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput;
import software.amazon.kinesis.retrieval.IteratorBuilder;
import software.amazon.kinesis.retrieval.KinesisClientRecord;
import software.amazon.kinesis.retrieval.RecordsPublisher;
import software.amazon.kinesis.retrieval.RecordsRetrieved;
import software.amazon.kinesis.retrieval.RetryableRetrievalException;
import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber;


@@ -67,7 +70,7 @@ public class FanOutRecordsPublisher implements RecordsPublisher {
    private InitialPositionInStreamExtended initialPositionInStreamExtended;
    private boolean isFirstConnection = true;

    private Subscriber<? super ProcessRecordsInput> subscriber;
    private Subscriber<? super RecordsRetrieved> subscriber;
    private long availableQueueSpace = 0;

    @Override

@@ -93,6 +96,24 @@ public class FanOutRecordsPublisher implements RecordsPublisher {
        }
    }

    @Override
    public void restartFrom(RecordsRetrieved recordsRetrieved) {
        synchronized (lockObject) {
            if (flow != null) {
                //
                // The flow should not be running at this time
                //
                flow.cancel();
            }
            flow = null;
            if (!(recordsRetrieved instanceof FanoutRecordsRetrieved)) {
                throw new IllegalArgumentException(
                        "Provided ProcessRecordsInput not created from the FanOutRecordsPublisher");
            }
            currentSequenceNumber = ((FanoutRecordsRetrieved) recordsRetrieved).continuationSequenceNumber();
        }
    }

    private boolean hasValidSubscriber() {
        return subscriber != null;
    }
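Note that restartFrom above only records the continuation sequence number; the next subscription is expected to resume there. A hedged illustration of what that resume looks like at the SubscribeToShard level (assumed AWS SDK v2 usage with hypothetical local variables; the real request construction lives elsewhere in this class):

import software.amazon.awssdk.services.kinesis.model.ShardIteratorType;
import software.amazon.awssdk.services.kinesis.model.StartingPosition;
import software.amazon.awssdk.services.kinesis.model.SubscribeToShardRequest;

// Assumed illustration: resume AFTER_SEQUENCE_NUMBER at the continuation
// sequence number stored by restartFrom, instead of at the checkpoint.
SubscribeToShardRequest request = SubscribeToShardRequest.builder()
        .consumerARN(consumerArn) // hypothetical variable
        .shardId(shardId)
        .startingPosition(StartingPosition.builder()
                .type(ShardIteratorType.AFTER_SEQUENCE_NUMBER)
                .sequenceNumber(currentSequenceNumber) // set by restartFrom
                .build())
        .build();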
@@ -174,8 +195,10 @@ public class FanOutRecordsPublisher implements RecordsPublisher {
                log.debug(
                        "{}: Could not call SubscribeToShard successfully because shard no longer exists. Marking shard for completion.",
                        shardId);
                FanoutRecordsRetrieved response = new FanoutRecordsRetrieved(
                        ProcessRecordsInput.builder().records(Collections.emptyList()).isAtShardEnd(true).build(), null);
                subscriber
                        .onNext(ProcessRecordsInput.builder().records(Collections.emptyList()).isAtShardEnd(true).build());
                        .onNext(response);
                subscriber.onComplete();
            } else {
                subscriber.onError(t);

@@ -257,9 +280,10 @@ public class FanOutRecordsPublisher implements RecordsPublisher {
            ProcessRecordsInput input = ProcessRecordsInput.builder().cacheEntryTime(Instant.now())
                    .millisBehindLatest(recordBatchEvent.millisBehindLatest())
                    .isAtShardEnd(recordBatchEvent.continuationSequenceNumber() == null).records(records).build();
            FanoutRecordsRetrieved recordsRetrieved = new FanoutRecordsRetrieved(input, recordBatchEvent.continuationSequenceNumber());

            try {
                subscriber.onNext(input);
                subscriber.onNext(recordsRetrieved);
                //
                // Only advance the currentSequenceNumber if we successfully dispatch the last received input
                //

@@ -311,7 +335,7 @@ public class FanOutRecordsPublisher implements RecordsPublisher {
    }

    @Override
    public void subscribe(Subscriber<? super ProcessRecordsInput> s) {
    public void subscribe(Subscriber<? super RecordsRetrieved> s) {
        synchronized (lockObject) {
            if (subscriber != null) {
                log.error(

@@ -444,6 +468,19 @@ public class FanOutRecordsPublisher implements RecordsPublisher {
        });
    }

    @Accessors(fluent = true)
    @Data
    static class FanoutRecordsRetrieved implements RecordsRetrieved {

        private final ProcessRecordsInput processRecordsInput;
        private final String continuationSequenceNumber;

        @Override
        public ProcessRecordsInput processRecordsInput() {
            return processRecordsInput;
        }
    }

    @RequiredArgsConstructor
    @Slf4j
    static class RecordFlow implements SubscribeToShardResponseHandler {
BlockingRecordsPublisher.java
@@ -27,6 +27,7 @@ import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput;
import software.amazon.kinesis.retrieval.GetRecordsRetrievalStrategy;
import software.amazon.kinesis.retrieval.KinesisClientRecord;
import software.amazon.kinesis.retrieval.RecordsPublisher;
import software.amazon.kinesis.retrieval.RecordsRetrieved;
import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber;

/**

@@ -38,7 +39,7 @@ public class BlockingRecordsPublisher implements RecordsPublisher {
    private final int maxRecordsPerCall;
    private final GetRecordsRetrievalStrategy getRecordsRetrievalStrategy;

    private Subscriber<? super ProcessRecordsInput> subscriber;
    private Subscriber<? super RecordsRetrieved> subscriber;

    public BlockingRecordsPublisher(final int maxRecordsPerCall,
            final GetRecordsRetrievalStrategy getRecordsRetrievalStrategy) {

@@ -70,7 +71,12 @@ public class BlockingRecordsPublisher implements RecordsPublisher {
    }

    @Override
    public void subscribe(Subscriber<? super ProcessRecordsInput> s) {
    public void subscribe(Subscriber<? super RecordsRetrieved> s) {
        subscriber = s;
    }

    @Override
    public void restartFrom(RecordsRetrieved recordsRetrieved) {
        throw new UnsupportedOperationException();
    }
}
KinesisDataFetcher.java
@@ -227,6 +227,12 @@ public class KinesisDataFetcher {
        advanceIteratorTo(lastKnownSequenceNumber, initialPositionInStream);
    }

    public void resetIterator(String shardIterator, String sequenceNumber, InitialPositionInStreamExtended initialPositionInStream) {
        this.nextIterator = shardIterator;
        this.lastKnownSequenceNumber = sequenceNumber;
        this.initialPositionInStream = initialPositionInStream;
    }

    private GetRecordsResponse getRecords(@NonNull final String nextIterator) {
        final AWSExceptionManager exceptionManager = createExceptionManager();
        GetRecordsRequest request = KinesisRequestsBuilder.getRecordsRequestBuilder().shardIterator(nextIterator)
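resetIterator makes the restart cheap. A hedged sketch of the intended call (the caller and the lastBatch variable here are hypothetical; the real call site is PrefetchRecordsPublisher#restartFrom below), reusing the iterator captured with the last delivered batch instead of fetching a fresh one via advanceIteratorTo:

// Reuse the shard iterator saved alongside the last delivered batch, avoiding
// an extra GetShardIterator round trip.
dataFetcher.resetIterator(
        lastBatch.shardIterator(),           // iterator captured with the batch
        lastBatch.lastBatchSequenceNumber(), // last sequence number handed out
        initialPositionInStreamExtended);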
PrefetchRecordsPublisher.java
@@ -20,14 +20,19 @@ import java.time.Instant;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.stream.Collectors;

import com.google.common.annotations.VisibleForTesting;
import org.apache.commons.lang3.Validate;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;

import lombok.Data;
import lombok.NonNull;
import lombok.experimental.Accessors;
import lombok.extern.slf4j.Slf4j;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.services.cloudwatch.model.StandardUnit;

@@ -44,6 +49,7 @@ import software.amazon.kinesis.metrics.ThreadSafeMetricsDelegatingFactory;
import software.amazon.kinesis.retrieval.GetRecordsRetrievalStrategy;
import software.amazon.kinesis.retrieval.KinesisClientRecord;
import software.amazon.kinesis.retrieval.RecordsPublisher;
import software.amazon.kinesis.retrieval.RecordsRetrieved;
import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber;

/**

@@ -58,7 +64,8 @@ import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber;
@KinesisClientInternalApi
public class PrefetchRecordsPublisher implements RecordsPublisher {
    private static final String EXPIRED_ITERATOR_METRIC = "ExpiredIterator";
    LinkedBlockingQueue<ProcessRecordsInput> getRecordsResultQueue;
    @VisibleForTesting
    LinkedBlockingQueue<PrefetchRecordsRetrieved> getRecordsResultQueue;
    private int maxPendingProcessRecordsInput;
    private int maxByteSize;
    private int maxRecordsCount;

@@ -75,9 +82,15 @@ public class PrefetchRecordsPublisher implements RecordsPublisher {
    private final KinesisDataFetcher dataFetcher;
    private final String shardId;

    private Subscriber<? super ProcessRecordsInput> subscriber;
    private Subscriber<? super RecordsRetrieved> subscriber;
    private final AtomicLong requestedResponses = new AtomicLong(0);

    private String highestSequenceNumber;
    private InitialPositionInStreamExtended initialPositionInStreamExtended;

    private final ReentrantReadWriteLock resetLock = new ReentrantReadWriteLock();
    private boolean wasReset = false;

    /**
     * Constructor for the PrefetchRecordsPublisher. This cache prefetches records from Kinesis and stores them in a
     * LinkedBlockingQueue.

@@ -124,6 +137,8 @@ public class PrefetchRecordsPublisher implements RecordsPublisher {
            throw new IllegalStateException("ExecutorService has been shutdown.");
        }

        this.initialPositionInStreamExtended = initialPositionInStreamExtended;
        highestSequenceNumber = extendedSequenceNumber.sequenceNumber();
        dataFetcher.initialize(extendedSequenceNumber, initialPositionInStreamExtended);

        if (!started) {

@@ -133,7 +148,7 @@ public class PrefetchRecordsPublisher implements RecordsPublisher {
        started = true;
    }

    ProcessRecordsInput getNextResult() {
    RecordsRetrieved getNextResult() {
        if (executorService.isShutdown()) {
            throw new IllegalStateException("Shutdown has been called on the cache, can't accept new requests.");
        }

@@ -141,14 +156,16 @@ public class PrefetchRecordsPublisher implements RecordsPublisher {
        if (!started) {
            throw new IllegalStateException("Cache has not been initialized, make sure to call start.");
        }
        ProcessRecordsInput result = null;
        PrefetchRecordsRetrieved result = null;
        try {
            result = getRecordsResultQueue.take().toBuilder().cacheExitTime(Instant.now()).build();
            prefetchCounters.removed(result);
            result = getRecordsResultQueue.take().prepareForPublish();
            prefetchCounters.removed(result.processRecordsInput);
            requestedResponses.decrementAndGet();

        } catch (InterruptedException e) {
            log.error("Interrupted while getting records from the cache", e);
        }

        return result;
    }

@@ -160,7 +177,28 @@ public class PrefetchRecordsPublisher implements RecordsPublisher {
    }

    @Override
    public void subscribe(Subscriber<? super ProcessRecordsInput> s) {
    public void restartFrom(RecordsRetrieved recordsRetrieved) {
        if (!(recordsRetrieved instanceof PrefetchRecordsRetrieved)) {
            throw new IllegalArgumentException(
                    "Provided RecordsRetrieved was not produced by the PrefetchRecordsPublisher");
        }
        PrefetchRecordsRetrieved prefetchRecordsRetrieved = (PrefetchRecordsRetrieved) recordsRetrieved;
        resetLock.writeLock().lock();
        try {
            getRecordsResultQueue.clear();
            prefetchCounters.reset();

            highestSequenceNumber = prefetchRecordsRetrieved.lastBatchSequenceNumber();
            dataFetcher.resetIterator(prefetchRecordsRetrieved.shardIterator(), highestSequenceNumber,
                    initialPositionInStreamExtended);
            wasReset = true;
        } finally {
            resetLock.writeLock().unlock();
        }
    }

    @Override
    public void subscribe(Subscriber<? super RecordsRetrieved> s) {
        subscriber = s;
        subscriber.onSubscribe(new Subscription() {
            @Override
@@ -176,9 +214,22 @@ public class PrefetchRecordsPublisher implements RecordsPublisher {
        });
    }

    private void addArrivedRecordsInput(ProcessRecordsInput processRecordsInput) throws InterruptedException {
        getRecordsResultQueue.put(processRecordsInput);
        prefetchCounters.added(processRecordsInput);
    private void addArrivedRecordsInput(PrefetchRecordsRetrieved recordsRetrieved) throws InterruptedException {
        wasReset = false;
        while (!getRecordsResultQueue.offer(recordsRetrieved, idleMillisBetweenCalls, TimeUnit.MILLISECONDS)) {
            //
            // Unlocking the read lock, and then reacquiring the read lock, should allow any waiters on the write lock a
            // chance to run. If the write lock is acquired by restartFrom then the readLock will now block until
            // restartFrom(...) has completed. This is to ensure that if a reset has occurred we know to discard the
            // data we received and start a new fetch of data.
            //
            resetLock.readLock().unlock();
            resetLock.readLock().lock();
            if (wasReset) {
                throw new PositionResetException();
            }
        }
        prefetchCounters.added(recordsRetrieved.processRecordsInput);
    }

    private synchronized void drainQueueForRequests() {
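The enqueue loop above relies on a small read/write-lock idiom: periodically drop and reacquire the read lock so a writer (restartFrom) can cut in, then check a flag to see whether a reset happened. A standalone sketch of the same idiom with generic names (an assumption-level illustration, not the KCL types):

import java.util.concurrent.BlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.ReentrantReadWriteLock;

// Producer threads call enqueue() while holding resetLock.readLock(); reset()
// takes the write lock, clears the queue, and flags the producers to abort.
class ResettableEnqueueSketch<T> {
    private final ReentrantReadWriteLock resetLock = new ReentrantReadWriteLock();
    private final BlockingQueue<T> queue;
    private volatile boolean wasReset = false;

    ResettableEnqueueSketch(BlockingQueue<T> queue) {
        this.queue = queue;
    }

    // Assumes the caller already holds resetLock.readLock().
    void enqueue(T item, long waitMillis) throws InterruptedException {
        wasReset = false;
        while (!queue.offer(item, waitMillis, TimeUnit.MILLISECONDS)) {
            // Yield the read lock so a pending writer can run the reset.
            resetLock.readLock().unlock();
            resetLock.readLock().lock();
            if (wasReset) {
                // Stand-in for the KCL's PositionResetException.
                throw new IllegalStateException("position was reset");
            }
        }
    }

    void reset() {
        resetLock.writeLock().lock();
        try {
            queue.clear();
            wasReset = true;
        } finally {
            resetLock.writeLock().unlock();
        }
    }
}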
@@ -187,6 +238,34 @@ public class PrefetchRecordsPublisher implements RecordsPublisher {
        }
    }

    @Accessors(fluent = true)
    @Data
    static class PrefetchRecordsRetrieved implements RecordsRetrieved {

        final ProcessRecordsInput processRecordsInput;
        final String lastBatchSequenceNumber;
        final String shardIterator;

        PrefetchRecordsRetrieved prepareForPublish() {
            return new PrefetchRecordsRetrieved(processRecordsInput.toBuilder().cacheExitTime(Instant.now()).build(),
                    lastBatchSequenceNumber, shardIterator);
        }

    }

    private String calculateHighestSequenceNumber(ProcessRecordsInput processRecordsInput) {
        String result = this.highestSequenceNumber;
        if (processRecordsInput.records() != null && !processRecordsInput.records().isEmpty()) {
            result = processRecordsInput.records().get(processRecordsInput.records().size() - 1).sequenceNumber();
        }
        return result;
    }

    private static class PositionResetException extends RuntimeException {

    }


    private class DefaultGetRecordsCacheDaemon implements Runnable {
        volatile boolean isShutdown = false;
@@ -197,57 +276,78 @@ public class PrefetchRecordsPublisher implements RecordsPublisher {
                    log.warn("Prefetch thread was interrupted.");
                    break;
                }
                MetricsScope scope = MetricsUtil.createMetricsWithOperation(metricsFactory, operation);
                if (prefetchCounters.shouldGetNewRecords()) {
                    try {
                        sleepBeforeNextCall();
                        GetRecordsResponse getRecordsResult = getRecordsRetrievalStrategy.getRecords(maxRecordsPerCall);
                        lastSuccessfulCall = Instant.now();

                        final List<KinesisClientRecord> records = getRecordsResult.records().stream()
                                .map(KinesisClientRecord::fromRecord).collect(Collectors.toList());
                        ProcessRecordsInput processRecordsInput = ProcessRecordsInput.builder()
                                .records(records)
                                .millisBehindLatest(getRecordsResult.millisBehindLatest())
                                .cacheEntryTime(lastSuccessfulCall)
                                .isAtShardEnd(getRecordsRetrievalStrategy.getDataFetcher().isShardEndReached())
                                .build();
                        addArrivedRecordsInput(processRecordsInput);
                        drainQueueForRequests();
                    } catch (InterruptedException e) {
                        log.info("Thread was interrupted, indicating shutdown was called on the cache.");
                    } catch (ExpiredIteratorException e) {
                        log.info("ShardId {}: records threw ExpiredIteratorException - restarting"
                                + " after greatest seqNum passed to customer", shardId, e);

                        scope.addData(EXPIRED_ITERATOR_METRIC, 1, StandardUnit.COUNT, MetricsLevel.SUMMARY);

                        dataFetcher.restartIterator();
                    } catch (SdkClientException e) {
                        log.error("Exception thrown while fetching records from Kinesis", e);
                    } catch (Throwable e) {
                        log.error("Unexpected exception was thrown. This could probably be an issue or a bug." +
                                " Please search for the exception/error online to check what is going on. If the " +
                                "issue persists or is a recurring problem, feel free to open an issue at " +
                                "https://github.com/awslabs/amazon-kinesis-client.", e);
                    } finally {
                        MetricsUtil.endScope(scope);
                    }
                } else {
                    //
                    // The consumer isn't ready to receive new records, so allow the prefetch counters to pause.
                    //
                    try {
                        prefetchCounters.waitForConsumer();
                    } catch (InterruptedException ie) {
                        log.info("Thread was interrupted while waiting for the consumer. " +
                                "Shutdown has probably been started");
                    }
                resetLock.readLock().lock();
                try {
                    makeRetrievalAttempt();
                } catch (PositionResetException pre) {
                    log.debug("Position was reset while attempting to add item to queue.");
                } finally {
                    resetLock.readLock().unlock();
                }


            }
            callShutdownOnStrategy();
        }

        private void makeRetrievalAttempt() {
            MetricsScope scope = MetricsUtil.createMetricsWithOperation(metricsFactory, operation);
            if (prefetchCounters.shouldGetNewRecords()) {
                try {
                    sleepBeforeNextCall();
                    GetRecordsResponse getRecordsResult = getRecordsRetrievalStrategy.getRecords(maxRecordsPerCall);
                    lastSuccessfulCall = Instant.now();

                    final List<KinesisClientRecord> records = getRecordsResult.records().stream()
                            .map(KinesisClientRecord::fromRecord).collect(Collectors.toList());
                    ProcessRecordsInput processRecordsInput = ProcessRecordsInput.builder()
                            .records(records)
                            .millisBehindLatest(getRecordsResult.millisBehindLatest())
                            .cacheEntryTime(lastSuccessfulCall)
                            .isAtShardEnd(getRecordsRetrievalStrategy.getDataFetcher().isShardEndReached())
                            .build();

                    highestSequenceNumber = calculateHighestSequenceNumber(processRecordsInput);
                    PrefetchRecordsRetrieved recordsRetrieved = new PrefetchRecordsRetrieved(processRecordsInput,
                            highestSequenceNumber, getRecordsResult.nextShardIterator());
                    highestSequenceNumber = recordsRetrieved.lastBatchSequenceNumber;
                    addArrivedRecordsInput(recordsRetrieved);
                    drainQueueForRequests();
                } catch (PositionResetException pse) {
                    throw pse;
                } catch (InterruptedException e) {
                    log.info("Thread was interrupted, indicating shutdown was called on the cache.");
                } catch (ExpiredIteratorException e) {
                    log.info("ShardId {}: records threw ExpiredIteratorException - restarting"
                            + " after greatest seqNum passed to customer", shardId, e);

                    scope.addData(EXPIRED_ITERATOR_METRIC, 1, StandardUnit.COUNT, MetricsLevel.SUMMARY);

                    dataFetcher.restartIterator();
                } catch (SdkClientException e) {
                    log.error("Exception thrown while fetching records from Kinesis", e);
                } catch (Throwable e) {
                    log.error("Unexpected exception was thrown. This could probably be an issue or a bug." +
                            " Please search for the exception/error online to check what is going on. If the " +
                            "issue persists or is a recurring problem, feel free to open an issue at " +
                            "https://github.com/awslabs/amazon-kinesis-client.", e);
                } finally {
                    MetricsUtil.endScope(scope);
                }
            } else {
                //
                // The consumer isn't ready to receive new records, so allow the prefetch counters to pause.
                //
                try {
                    prefetchCounters.waitForConsumer();
                } catch (InterruptedException ie) {
                    log.info("Thread was interrupted while waiting for the consumer. " +
                            "Shutdown has probably been started");
                }
            }
        }

        private void callShutdownOnStrategy() {
            if (!getRecordsRetrievalStrategy.isShutdown()) {
                getRecordsRetrievalStrategy.shutdown();

@@ -302,6 +402,11 @@ public class PrefetchRecordsPublisher implements RecordsPublisher {
            return size < maxRecordsCount && byteSize < maxByteSize;
        }

        void reset() {
            size = 0;
            byteSize = 0;
        }

        @Override
        public String toString() {
            return String.format("{ Requests: %d, Records: %d, Bytes: %d }", getRecordsResultQueue.size(), size,
ShardConsumerSubscriberTest.java (new file)
@@ -0,0 +1,447 @@
/*
 * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Amazon Software License (the "License").
 * You may not use this file except in compliance with the License.
 * A copy of the License is located at
 *
 *     http://aws.amazon.com/asl/
 *
 * or in the "license" file accompanying this file. This file is distributed
 * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
 * express or implied. See the License for the specific language governing
 * permissions and limitations under the License.
 */
package software.amazon.kinesis.lifecycle;

import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.nullValue;
import static org.junit.Assert.assertThat;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.argThat;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static software.amazon.kinesis.utils.ProcessRecordsInputMatcher.eqProcessRecordsInput;

import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.stream.Stream;

import org.apache.commons.lang3.StringUtils;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TestName;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.runners.MockitoJUnitRunner;
import org.mockito.stubbing.Answer;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;

import com.google.common.util.concurrent.ThreadFactoryBuilder;

import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import software.amazon.kinesis.common.InitialPositionInStreamExtended;
import software.amazon.kinesis.leases.ShardInfo;
import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput;
import software.amazon.kinesis.retrieval.KinesisClientRecord;
import software.amazon.kinesis.retrieval.RecordsPublisher;
import software.amazon.kinesis.retrieval.RecordsRetrieved;
import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber;

@Slf4j
@RunWith(MockitoJUnitRunner.class)
public class ShardConsumerSubscriberTest {

    private final Object processedNotifier = new Object();

    private static final String TERMINAL_MARKER = "Terminal";

    @Mock
    private ShardConsumer shardConsumer;
    @Mock
    private RecordsRetrieved recordsRetrieved;

    private ProcessRecordsInput processRecordsInput;
    private TestPublisher recordsPublisher;

    private ExecutorService executorService;
    private int bufferSize = 8;

    private ShardConsumerSubscriber subscriber;

    @Rule
    public TestName testName = new TestName();

    @Before
    public void before() {
        executorService = Executors.newFixedThreadPool(8, new ThreadFactoryBuilder()
                .setNameFormat("test-" + testName.getMethodName() + "-%04d").setDaemon(true).build());
        recordsPublisher = new TestPublisher();

        ShardInfo shardInfo = new ShardInfo("shard-001", "", Collections.emptyList(),
                ExtendedSequenceNumber.TRIM_HORIZON);
        when(shardConsumer.shardInfo()).thenReturn(shardInfo);

        processRecordsInput = ProcessRecordsInput.builder().records(Collections.emptyList())
                .cacheEntryTime(Instant.now()).build();

        subscriber = new ShardConsumerSubscriber(recordsPublisher, executorService, bufferSize, shardConsumer);
        when(recordsRetrieved.processRecordsInput()).thenReturn(processRecordsInput);
    }

    @After
    public void after() {
        executorService.shutdownNow();
    }

    @Test
    public void singleItemTest() throws Exception {
        addItemsToReturn(1);

        setupNotifierAnswer(1);

        synchronized (processedNotifier) {
            subscriber.startSubscriptions();
            processedNotifier.wait(5000);
        }

        verify(shardConsumer).handleInput(argThat(eqProcessRecordsInput(processRecordsInput)), any(Subscription.class));
    }

    @Test
    public void multipleItemTest() throws Exception {
        addItemsToReturn(100);

        setupNotifierAnswer(recordsPublisher.responses.size());

        synchronized (processedNotifier) {
            subscriber.startSubscriptions();
            processedNotifier.wait(5000);
        }

        verify(shardConsumer, times(100)).handleInput(argThat(eqProcessRecordsInput(processRecordsInput)),
                any(Subscription.class));
    }

    @Test
    public void consumerErrorSkipsEntryTest() throws Exception {
        addItemsToReturn(20);

        Throwable testException = new Throwable("ShardConsumerError");

        doAnswer(new Answer() {
            int expectedInvocations = recordsPublisher.responses.size();

            @Override
            public Object answer(InvocationOnMock invocation) throws Throwable {
                expectedInvocations--;
                if (expectedInvocations == 10) {
                    throw testException;
                }
                if (expectedInvocations <= 0) {
                    synchronized (processedNotifier) {
                        processedNotifier.notifyAll();
                    }
                }
                return null;
            }
        }).when(shardConsumer).handleInput(any(ProcessRecordsInput.class), any(Subscription.class));

        synchronized (processedNotifier) {
            subscriber.startSubscriptions();
            processedNotifier.wait(5000);
        }

        assertThat(subscriber.getAndResetDispatchFailure(), equalTo(testException));
        assertThat(subscriber.getAndResetDispatchFailure(), nullValue());

        verify(shardConsumer, times(20)).handleInput(argThat(eqProcessRecordsInput(processRecordsInput)),
                any(Subscription.class));

    }

    @Test
    public void onErrorStopsProcessingTest() throws Exception {
        Throwable expected = new Throwable("Wheee");
        addItemsToReturn(10);
        recordsPublisher.add(new ResponseItem(expected));
        addItemsToReturn(10);

        setupNotifierAnswer(10);

        synchronized (processedNotifier) {
            subscriber.startSubscriptions();
            processedNotifier.wait(5000);
        }

        for (int attempts = 0; attempts < 10; attempts++) {
            if (subscriber.retrievalFailure() != null) {
                break;
            }
            Thread.sleep(10);
        }

        verify(shardConsumer, times(10)).handleInput(argThat(eqProcessRecordsInput(processRecordsInput)),
                any(Subscription.class));
        assertThat(subscriber.retrievalFailure(), equalTo(expected));
    }

    @Test
    public void restartAfterErrorTest() throws Exception {
        Throwable expected = new Throwable("whee");
        addItemsToReturn(9);
        RecordsRetrieved edgeRecord = mock(RecordsRetrieved.class);
        when(edgeRecord.processRecordsInput()).thenReturn(processRecordsInput);
        recordsPublisher.add(new ResponseItem(edgeRecord));
        recordsPublisher.add(new ResponseItem(expected));
        addItemsToReturn(10);

        setupNotifierAnswer(10);

        synchronized (processedNotifier) {
            subscriber.startSubscriptions();
            processedNotifier.wait(5000);
        }

        for (int attempts = 0; attempts < 10; attempts++) {
            if (subscriber.retrievalFailure() != null) {
                break;
            }
            Thread.sleep(100);
        }

        setupNotifierAnswer(10);

        synchronized (processedNotifier) {
            assertThat(subscriber.healthCheck(100000), equalTo(expected));
            processedNotifier.wait(5000);
        }

        assertThat(recordsPublisher.restartedFrom, equalTo(edgeRecord));
        verify(shardConsumer, times(20)).handleInput(argThat(eqProcessRecordsInput(processRecordsInput)),
                any(Subscription.class));
    }

    @Test
    public void restartAfterRequestTimerExpiresTest() throws Exception {

        executorService = Executors.newFixedThreadPool(1, new ThreadFactoryBuilder()
                .setNameFormat("test-" + testName.getMethodName() + "-%04d").setDaemon(true).build());

        subscriber = new ShardConsumerSubscriber(recordsPublisher, executorService, bufferSize, shardConsumer);
        addUniqueItem(1);
        addTerminalMarker(1);

        CyclicBarrier barrier = new CyclicBarrier(2);

        List<ProcessRecordsInput> received = new ArrayList<>();
        doAnswer(a -> {
            ProcessRecordsInput input = a.getArgumentAt(0, ProcessRecordsInput.class);
            received.add(input);
            if (input.records().stream().anyMatch(r -> StringUtils.startsWith(r.partitionKey(), TERMINAL_MARKER))) {
                synchronized (processedNotifier) {
                    processedNotifier.notifyAll();
                }
            }
            return null;
        }).when(shardConsumer).handleInput(any(ProcessRecordsInput.class), any(Subscription.class));

        synchronized (processedNotifier) {
            subscriber.startSubscriptions();
            processedNotifier.wait(5000);
        }

        synchronized (processedNotifier) {
            executorService.execute(() -> {
                try {
                    //
                    // Notify the test as soon as we have started executing, then wait on the post add barrier.
                    //
                    synchronized (processedNotifier) {
                        processedNotifier.notifyAll();
                    }
                    barrier.await();
                } catch (Exception e) {
                    log.error("Exception while blocking thread", e);
                }
            });
            //
            // Wait for our blocking thread to control the thread in the executor.
            //
            processedNotifier.wait(5000);
        }

        Stream.iterate(2, i -> i + 1).limit(97).forEach(this::addUniqueItem);

        addTerminalMarker(2);

        synchronized (processedNotifier) {
            assertThat(subscriber.healthCheck(1), nullValue());
            barrier.await(500, TimeUnit.MILLISECONDS);

            processedNotifier.wait(5000);
        }

        verify(shardConsumer, times(100)).handleInput(argThat(eqProcessRecordsInput(processRecordsInput)),
                any(Subscription.class));

        assertThat(received.size(), equalTo(recordsPublisher.responses.size()));
        Stream.iterate(0, i -> i + 1).limit(received.size()).forEach(i -> assertThat(received.get(i),
                eqProcessRecordsInput(recordsPublisher.responses.get(i).recordsRetrieved.processRecordsInput())));

    }

    private void addUniqueItem(int id) {
        RecordsRetrieved r = mock(RecordsRetrieved.class, "Record-" + id);
        ProcessRecordsInput input = ProcessRecordsInput.builder().cacheEntryTime(Instant.now())
                .records(Collections.singletonList(KinesisClientRecord.builder().partitionKey("Record-" + id).build()))
                .build();
        when(r.processRecordsInput()).thenReturn(input);
        recordsPublisher.add(new ResponseItem(r));
    }

    private ProcessRecordsInput addTerminalMarker(int id) {
        RecordsRetrieved terminalResponse = mock(RecordsRetrieved.class, TERMINAL_MARKER + "-" + id);
        ProcessRecordsInput terminalInput = ProcessRecordsInput.builder()
                .records(Collections
                        .singletonList(KinesisClientRecord.builder().partitionKey(TERMINAL_MARKER + "-" + id).build()))
                .cacheEntryTime(Instant.now()).build();
        when(terminalResponse.processRecordsInput()).thenReturn(terminalInput);
        recordsPublisher.add(new ResponseItem(terminalResponse));

        return terminalInput;
    }

    private void addItemsToReturn(int count) {
        Stream.iterate(0, i -> i + 1).limit(count)
                .forEach(i -> recordsPublisher.add(new ResponseItem(recordsRetrieved)));
    }

    private void setupNotifierAnswer(int expected) {
        doAnswer(new Answer() {
            int seen = expected;

            @Override
            public Object answer(InvocationOnMock invocation) throws Throwable {
                seen--;
                if (seen == 0) {
                    synchronized (processedNotifier) {
                        processedNotifier.notifyAll();
                    }
                }
                return null;
            }
        }).when(shardConsumer).handleInput(any(ProcessRecordsInput.class), any(Subscription.class));
    }

    private class ResponseItem {
        private final RecordsRetrieved recordsRetrieved;
        private final Throwable throwable;
        private int throwCount = 1;

        public ResponseItem(@NonNull RecordsRetrieved recordsRetrieved) {
            this.recordsRetrieved = recordsRetrieved;
            this.throwable = null;
        }

        public ResponseItem(@NonNull Throwable throwable) {
            this.throwable = throwable;
            this.recordsRetrieved = null;
        }
    }

    private class TestPublisher implements RecordsPublisher {

        private final LinkedList<ResponseItem> responses = new LinkedList<>();
        private volatile long requested = 0;
        private int currentIndex = 0;
        private Subscriber<? super RecordsRetrieved> subscriber;
        private RecordsRetrieved restartedFrom;

        void add(ResponseItem... toAdd) {
            responses.addAll(Arrays.asList(toAdd));
            send();
        }

        void send() {
            send(0);
        }

        synchronized void send(long toRequest) {
            requested += toRequest;
            while (requested > 0 && currentIndex < responses.size()) {
                ResponseItem item = responses.get(currentIndex);
                currentIndex++;
                if (item.recordsRetrieved != null) {
                    subscriber.onNext(item.recordsRetrieved);
                } else {
                    if (item.throwCount > 0) {
                        item.throwCount--;
                        subscriber.onError(item.throwable);
                    } else {
                        continue;
                    }
                }
                requested--;
            }
        }

        @Override
        public void start(ExtendedSequenceNumber extendedSequenceNumber,
                InitialPositionInStreamExtended initialPositionInStreamExtended) {

        }

        @Override
        public void restartFrom(RecordsRetrieved recordsRetrieved) {
            restartedFrom = recordsRetrieved;
            currentIndex = -1;
            for (int i = 0; i < responses.size(); i++) {
                ResponseItem item = responses.get(i);
                if (recordsRetrieved.equals(item.recordsRetrieved)) {
                    currentIndex = i + 1;
                    break;
                }
            }

        }

        @Override
        public void shutdown() {

        }

        @Override
        public void subscribe(Subscriber<? super RecordsRetrieved> s) {
            subscriber = s;
            s.onSubscribe(new Subscription() {
                @Override
                public void request(long n) {
                    send(n);
                }

                @Override
                public void cancel() {
                    requested = 0;
                }
            });
        }
    }

}
@@ -70,6 +70,7 @@ import software.amazon.kinesis.leases.ShardInfo;
import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput;
import software.amazon.kinesis.lifecycle.events.TaskExecutionListenerInput;
import software.amazon.kinesis.retrieval.RecordsPublisher;
+import software.amazon.kinesis.retrieval.RecordsRetrieved;
import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber;

/**
@@ -161,7 +162,7 @@ public class ShardConsumerTest {
        final CyclicBarrier barrier = new CyclicBarrier(2);
        final CyclicBarrier requestBarrier = new CyclicBarrier(2);

-       Subscriber<? super ProcessRecordsInput> subscriber;
+       Subscriber<? super RecordsRetrieved> subscriber;
        final Subscription subscription = mock(Subscription.class);

        TestPublisher() {
@@ -193,7 +194,7 @@
        }

        @Override
-       public void subscribe(Subscriber<? super ProcessRecordsInput> s) {
+       public void subscribe(Subscriber<? super RecordsRetrieved> s) {
            subscriber = s;
            subscriber.onSubscribe(subscription);
            try {
@@ -203,6 +204,11 @@
            }
        }

+       @Override
+       public void restartFrom(RecordsRetrieved recordsRetrieved) {
+
+       }
+
        public void awaitSubscription() throws InterruptedException, BrokenBarrierException {
            barrier.await();
            barrier.reset();
@@ -219,10 +225,10 @@
        }

        public void publish() {
-           publish(processRecordsInput);
+           publish(() -> processRecordsInput);
        }

-       public void publish(ProcessRecordsInput input) {
+       public void publish(RecordsRetrieved input) {
            subscriber.onNext(input);
        }
    }
@@ -45,6 +45,7 @@ import software.amazon.kinesis.common.InitialPositionInStream;
import software.amazon.kinesis.common.InitialPositionInStreamExtended;
import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput;
import software.amazon.kinesis.retrieval.KinesisClientRecord;
+import software.amazon.kinesis.retrieval.RecordsRetrieved;
import software.amazon.kinesis.retrieval.RetryableRetrievalException;
import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber;

@@ -62,7 +63,7 @@ public class FanOutRecordsPublisherTest {
    @Mock
    private Subscription subscription;
    @Mock
-   private Subscriber<ProcessRecordsInput> subscriber;
+   private Subscriber<RecordsRetrieved> subscriber;

    private SubscribeToShardEvent batchEvent;

@@ -80,7 +81,7 @@

        List<ProcessRecordsInput> receivedInput = new ArrayList<>();

-       source.subscribe(new Subscriber<ProcessRecordsInput>() {
+       source.subscribe(new Subscriber<RecordsRetrieved>() {
            Subscription subscription;

            @Override
@@ -90,8 +91,8 @@
            }

            @Override
-           public void onNext(ProcessRecordsInput input) {
-               receivedInput.add(input);
+           public void onNext(RecordsRetrieved input) {
+               receivedInput.add(input.processRecordsInput());
                subscription.request(1);
            }

@@ -147,7 +148,7 @@

        List<ProcessRecordsInput> receivedInput = new ArrayList<>();

-       source.subscribe(new Subscriber<ProcessRecordsInput>() {
+       source.subscribe(new Subscriber<RecordsRetrieved>() {
            Subscription subscription;

            @Override
@@ -157,8 +158,8 @@
            }

            @Override
-           public void onNext(ProcessRecordsInput input) {
-               receivedInput.add(input);
+           public void onNext(RecordsRetrieved input) {
+               receivedInput.add(input.processRecordsInput());
                subscription.request(1);
            }

@@ -206,7 +207,7 @@

        ArgumentCaptor<FanOutRecordsPublisher.RecordFlow> flowCaptor = ArgumentCaptor
                .forClass(FanOutRecordsPublisher.RecordFlow.class);
-       ArgumentCaptor<ProcessRecordsInput> inputCaptor = ArgumentCaptor.forClass(ProcessRecordsInput.class);
+       ArgumentCaptor<RecordsRetrieved> inputCaptor = ArgumentCaptor.forClass(RecordsRetrieved.class);

        source.subscribe(subscriber);

@@ -219,7 +220,7 @@
        verify(subscriber).onNext(inputCaptor.capture());
        verify(subscriber).onComplete();

-       ProcessRecordsInput input = inputCaptor.getValue();
+       ProcessRecordsInput input = inputCaptor.getValue().processRecordsInput();
        assertThat(input.isAtShardEnd(), equalTo(true));
        assertThat(input.records().isEmpty(), equalTo(true));
    }
@@ -325,7 +326,7 @@
        }
    }

-   private static class NonFailingSubscriber implements Subscriber<ProcessRecordsInput> {
+   private static class NonFailingSubscriber implements Subscriber<RecordsRetrieved> {
        final List<ProcessRecordsInput> received = new ArrayList<>();
        Subscription subscription;

@@ -336,8 +337,8 @@
        }

        @Override
-       public void onNext(ProcessRecordsInput input) {
-           received.add(input);
+       public void onNext(RecordsRetrieved input) {
+           received.add(input.processRecordsInput());
            subscription.request(1);
        }

@@ -121,13 +121,13 @@ public class PrefetchRecordsPublisherIntegrationTest {
        getRecordsCache.start(extendedSequenceNumber, initialPosition);
        sleep(IDLE_MILLIS_BETWEEN_CALLS);

-       ProcessRecordsInput processRecordsInput1 = getRecordsCache.getNextResult();
+       ProcessRecordsInput processRecordsInput1 = getRecordsCache.getNextResult().processRecordsInput();

        assertTrue(processRecordsInput1.records().isEmpty());
        assertEquals(processRecordsInput1.millisBehindLatest(), new Long(1000));
        assertNotNull(processRecordsInput1.cacheEntryTime());

-       ProcessRecordsInput processRecordsInput2 = getRecordsCache.getNextResult();
+       ProcessRecordsInput processRecordsInput2 = getRecordsCache.getNextResult().processRecordsInput();

        assertNotEquals(processRecordsInput1, processRecordsInput2);
    }
@@ -139,8 +139,8 @@

        assertEquals(getRecordsCache.getRecordsResultQueue.size(), MAX_SIZE);

-       ProcessRecordsInput processRecordsInput1 = getRecordsCache.getNextResult();
-       ProcessRecordsInput processRecordsInput2 = getRecordsCache.getNextResult();
+       ProcessRecordsInput processRecordsInput1 = getRecordsCache.getNextResult().processRecordsInput();
+       ProcessRecordsInput processRecordsInput2 = getRecordsCache.getNextResult().processRecordsInput();

        assertNotEquals(processRecordsInput1, processRecordsInput2);
    }
@@ -179,9 +179,9 @@

        sleep(IDLE_MILLIS_BETWEEN_CALLS);

-       ProcessRecordsInput p1 = getRecordsCache.getNextResult();
+       ProcessRecordsInput p1 = getRecordsCache.getNextResult().processRecordsInput();

-       ProcessRecordsInput p2 = recordsPublisher2.getNextResult();
+       ProcessRecordsInput p2 = recordsPublisher2.getNextResult().processRecordsInput();

        assertNotEquals(p1, p2);
        assertTrue(p1.records().isEmpty());
@@ -207,7 +207,7 @@
        getRecordsCache.start(extendedSequenceNumber, initialPosition);
        sleep(IDLE_MILLIS_BETWEEN_CALLS);

-       ProcessRecordsInput processRecordsInput = getRecordsCache.getNextResult();
+       ProcessRecordsInput processRecordsInput = getRecordsCache.getNextResult().processRecordsInput();

        assertNotNull(processRecordsInput);
        assertTrue(processRecordsInput.records().isEmpty());
@@ -24,17 +24,21 @@ import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyInt;
import static org.mockito.Matchers.anyString;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.atLeast;
import static org.mockito.Mockito.atMost;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static software.amazon.kinesis.utils.ProcessRecordsInputMatcher.eqProcessRecordsInput;

import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
@@ -42,15 +46,18 @@ import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

import org.apache.commons.lang3.StringUtils;
import org.junit.After;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.runners.MockitoJUnitRunner;
import org.mockito.stubbing.Answer;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;

@@ -66,6 +73,7 @@ import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput;
import software.amazon.kinesis.metrics.NullMetricsFactory;
import software.amazon.kinesis.retrieval.GetRecordsRetrievalStrategy;
import software.amazon.kinesis.retrieval.KinesisClientRecord;
+import software.amazon.kinesis.retrieval.RecordsRetrieved;
import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber;

/**
@@ -92,7 +100,7 @@ public class PrefetchRecordsPublisherTest {

    private List<Record> records;
    private ExecutorService executorService;
-   private LinkedBlockingQueue<ProcessRecordsInput> spyQueue;
+   private LinkedBlockingQueue<PrefetchRecordsPublisher.PrefetchRecordsRetrieved> spyQueue;
    private PrefetchRecordsPublisher getRecordsCache;
    private String operation = "ProcessTask";
    private GetRecordsResponse getRecordsResponse;
@@ -131,7 +139,7 @@
                .map(KinesisClientRecord::fromRecord).collect(Collectors.toList());

        getRecordsCache.start(sequenceNumber, initialPosition);
-       ProcessRecordsInput result = getRecordsCache.getNextResult();
+       ProcessRecordsInput result = getRecordsCache.getNextResult().processRecordsInput();

        assertEquals(expectedRecords, result.records());

@@ -200,7 +208,7 @@
                .map(KinesisClientRecord::fromRecord).collect(Collectors.toList());

        getRecordsCache.start(sequenceNumber, initialPosition);
-       ProcessRecordsInput processRecordsInput = getRecordsCache.getNextResult();
+       ProcessRecordsInput processRecordsInput = getRecordsCache.getNextResult().processRecordsInput();

        verify(executorService).execute(any());
        assertEquals(expectedRecords, processRecordsInput.records());
@@ -209,7 +217,7 @@

        sleep(2000);

-       ProcessRecordsInput processRecordsInput2 = getRecordsCache.getNextResult();
+       ProcessRecordsInput processRecordsInput2 = getRecordsCache.getNextResult().processRecordsInput();
        assertNotEquals(processRecordsInput, processRecordsInput2);
        assertEquals(expectedRecords, processRecordsInput2.records());
        assertNotEquals(processRecordsInput2.timeSpentInCache(), Duration.ZERO);
@@ -276,7 +284,7 @@

        Object lock = new Object();

-       Subscriber<ProcessRecordsInput> subscriber = new Subscriber<ProcessRecordsInput>() {
+       Subscriber<RecordsRetrieved> subscriber = new Subscriber<RecordsRetrieved>() {
            Subscription sub;

            @Override
@@ -286,7 +294,7 @@
        }

        @Override
-       public void onNext(ProcessRecordsInput processRecordsInput) {
+       public void onNext(RecordsRetrieved recordsRetrieved) {
            receivedItems.incrementAndGet();
            if (receivedItems.get() >= expectedItems) {
                synchronized (lock) {
@@ -325,6 +333,87 @@
        assertThat(receivedItems.get(), equalTo(expectedItems));
    }

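    // Verifies that restartFrom(lastProcessed) discards everything still queued in the prefetch
    // cache and that delivery resumes with the batch immediately after the last processed one.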
    @Test
    public void testResetClearsRemainingData() {
        List<GetRecordsResponse> responses = Stream.iterate(0, i -> i + 1).limit(10).map(i -> {
            Record record = Record.builder().partitionKey("record-" + i).sequenceNumber("seq-" + i)
                    .data(SdkBytes.fromByteArray(new byte[] { 1, 2, 3 })).approximateArrivalTimestamp(Instant.now())
                    .build();
            String nextIterator = "shard-iter-" + (i + 1);
            return GetRecordsResponse.builder().records(record).nextShardIterator(nextIterator).build();
        }).collect(Collectors.toList());

        RetrieverAnswer retrieverAnswer = new RetrieverAnswer(responses);

        when(getRecordsRetrievalStrategy.getRecords(anyInt())).thenAnswer(retrieverAnswer);
        doAnswer(a -> {
            String resetTo = a.getArgumentAt(0, String.class);
            retrieverAnswer.resetIteratorTo(resetTo);
            return null;
        }).when(dataFetcher).resetIterator(anyString(), anyString(), any());

        getRecordsCache.start(sequenceNumber, initialPosition);

        RecordsRetrieved lastProcessed = getRecordsCache.getNextResult();
        RecordsRetrieved expected = getRecordsCache.getNextResult();

        //
        // Skip some of the records in the cache
        //
        getRecordsCache.getNextResult();
        getRecordsCache.getNextResult();

        verify(getRecordsRetrievalStrategy, atLeast(2)).getRecords(anyInt());

        while (getRecordsCache.getRecordsResultQueue.remainingCapacity() > 0) {
            Thread.yield();
        }

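        // The queue is now at capacity, so the prefetch thread is parked trying to enqueue its
        // next response; restarting here should surface as a PositionResetException in that
        // thread and clear out everything still queued.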
        getRecordsCache.restartFrom(lastProcessed);
        RecordsRetrieved postRestart = getRecordsCache.getNextResult();

        assertThat(postRestart.processRecordsInput(), eqProcessRecordsInput(expected.processRecordsInput()));
        verify(dataFetcher).resetIterator(eq(responses.get(0).nextShardIterator()),
                eq(responses.get(0).records().get(0).sequenceNumber()), any());
    }

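    // Replays the canned GetRecordsResponse list to the mocked retrieval strategy; when the test
    // resets the data fetcher, the replay is rewound to match the requested shard iterator.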
    private static class RetrieverAnswer implements Answer<GetRecordsResponse> {

        private final List<GetRecordsResponse> responses;
        private Iterator<GetRecordsResponse> iterator;

        public RetrieverAnswer(List<GetRecordsResponse> responses) {
            this.responses = responses;
            this.iterator = responses.iterator();
        }

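        // Rewinds the replay so that the next answer is the response a fetch with the given
        // shard iterator would return, i.e. the one after the response that produced that iterator.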
        public void resetIteratorTo(String nextIterator) {
            Iterator<GetRecordsResponse> newIterator = responses.iterator();
            while (newIterator.hasNext()) {
                GetRecordsResponse current = newIterator.next();
                if (StringUtils.equals(nextIterator, current.nextShardIterator())) {
                    if (!newIterator.hasNext()) {
                        iterator = responses.iterator();
                    } else {
                        iterator = newIterator;
                    }
                    break;
                }
            }
        }

        @Override
        public GetRecordsResponse answer(InvocationOnMock invocation) throws Throwable {
            GetRecordsResponse response = iterator.next();
            if (!iterator.hasNext()) {
                iterator = responses.iterator();
            }
            return response;
        }
    }

    @After
    public void shutdown() {
        getRecordsCache.shutdown();

@@ -340,4 +429,5 @@
    private SdkBytes createByteBufferWithSize(int size) {
        return SdkBytes.fromByteArray(new byte[size]);
    }

}
@@ -0,0 +1,79 @@
/*
 * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Amazon Software License (the "License").
 * You may not use this file except in compliance with the License.
 * A copy of the License is located at
 *
 *     http://aws.amazon.com/asl/
 *
 * or in the "license" file accompanying this file. This file is distributed
 * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
 * express or implied. See the License for the specific language governing
 * permissions and limitations under the License.
 */

package software.amazon.kinesis.utils;

import lombok.Data;
import org.hamcrest.Description;
import org.hamcrest.Matcher;
import org.hamcrest.TypeSafeDiagnosingMatcher;
import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput;

import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;

import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.nullValue;

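// Hamcrest matcher that compares ProcessRecordsInput field by field; fields that are null on the
// template must also be null on the matched input.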
public class ProcessRecordsInputMatcher extends TypeSafeDiagnosingMatcher<ProcessRecordsInput> {

    private final ProcessRecordsInput template;
    private final Map<String, MatcherData> matchers = new HashMap<>();

    public static ProcessRecordsInputMatcher eqProcessRecordsInput(ProcessRecordsInput expected) {
        return new ProcessRecordsInputMatcher(expected);
    }

    public ProcessRecordsInputMatcher(ProcessRecordsInput template) {
        matchers.put("cacheEntryTime",
                nullOrEquals(template.cacheEntryTime(), ProcessRecordsInput::cacheEntryTime));
        matchers.put("checkpointer", nullOrEquals(template.checkpointer(), ProcessRecordsInput::checkpointer));
        matchers.put("isAtShardEnd", nullOrEquals(template.isAtShardEnd(), ProcessRecordsInput::isAtShardEnd));
        matchers.put("millisBehindLatest",
                nullOrEquals(template.millisBehindLatest(), ProcessRecordsInput::millisBehindLatest));
        matchers.put("records", nullOrEquals(template.records(), ProcessRecordsInput::records));

        this.template = template;
    }

    private static MatcherData nullOrEquals(Object item, Function<ProcessRecordsInput, ?> accessor) {
        if (item == null) {
            return new MatcherData(nullValue(), accessor);
        }
        return new MatcherData(equalTo(item), accessor);
    }

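    // Matches only when every field matcher matches; on a mismatch, the first offending field is
    // named and described so the test output points at the differing property.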
    @Override
    protected boolean matchesSafely(ProcessRecordsInput item, Description mismatchDescription) {
        return !matchers.entrySet().stream()
                .filter(e -> !e.getValue().matcher.matches(e.getValue().accessor.apply(item))).anyMatch(e -> {
                    mismatchDescription.appendText(e.getKey()).appendText(" ");
                    e.getValue().matcher.describeMismatch(e.getValue().accessor.apply(item), mismatchDescription);
                    return true;
                });
    }

    @Override
    public void describeTo(Description description) {
        matchers.forEach((k, v) -> description.appendText(k).appendText(" ").appendDescriptionOf(v.matcher));
    }

    @Data
    private static class MatcherData {
        private final Matcher<?> matcher;
        private final Function<ProcessRecordsInput, ?> accessor;
    }
}