diff --git a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/KinesisClientLibConfiguration.java b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/KinesisClientLibConfiguration.java index 8cc105e1..1a2c371e 100644 --- a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/KinesisClientLibConfiguration.java +++ b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/KinesisClientLibConfiguration.java @@ -15,6 +15,7 @@ package com.amazonaws.services.kinesis.clientlibrary.lib.worker; import java.util.Date; +import java.util.Optional; import java.util.Set; import org.apache.commons.lang.Validate; @@ -25,7 +26,6 @@ import com.amazonaws.regions.RegionUtils; import com.amazonaws.services.kinesis.metrics.impl.MetricsHelper; import com.amazonaws.services.kinesis.metrics.interfaces.IMetricsScope; import com.amazonaws.services.kinesis.metrics.interfaces.MetricsLevel; -import com.google.common.base.Optional; import com.google.common.collect.ImmutableSet; import lombok.Getter; @@ -215,7 +215,7 @@ public class KinesisClientLibConfiguration { private ShardPrioritization shardPrioritization; @Getter - private Optional timeoutInSeconds = Optional.absent(); + private Optional timeoutInSeconds = Optional.empty(); @Getter private int maxLeaseRenewalThreads = DEFAULT_MAX_LEASE_RENEWAL_THREADS; diff --git a/src/main/java/com/amazonaws/services/kinesis/multilang/MultiLangProtocol.java b/src/main/java/com/amazonaws/services/kinesis/multilang/MultiLangProtocol.java index 99338968..dfac215f 100644 --- a/src/main/java/com/amazonaws/services/kinesis/multilang/MultiLangProtocol.java +++ b/src/main/java/com/amazonaws/services/kinesis/multilang/MultiLangProtocol.java @@ -14,11 +14,6 @@ */ package com.amazonaws.services.kinesis.multilang; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.Future; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; - import com.amazonaws.services.kinesis.clientlibrary.exceptions.InvalidStateException; import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer; import com.amazonaws.services.kinesis.clientlibrary.lib.worker.KinesisClientLibConfiguration; @@ -32,9 +27,14 @@ import com.amazonaws.services.kinesis.multilang.messages.ProcessRecordsMessage; import com.amazonaws.services.kinesis.multilang.messages.ShutdownMessage; import com.amazonaws.services.kinesis.multilang.messages.ShutdownRequestedMessage; import com.amazonaws.services.kinesis.multilang.messages.StatusMessage; - import lombok.extern.apachecommons.CommonsLog; +import java.util.Optional; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + /** * An implementation of the multi language protocol. */ @@ -48,7 +48,7 @@ class MultiLangProtocol { /** * Constructor. - * + * * @param messageReader * A message reader. * @param messageWriter @@ -67,7 +67,7 @@ class MultiLangProtocol { /** * Writes an {@link InitializeMessage} to the child process's STDIN and waits for the child process to respond with * a {@link StatusMessage} on its STDOUT. - * + * * @return Whether or not this operation succeeded. */ boolean initialize() { @@ -82,7 +82,7 @@ class MultiLangProtocol { /** * Writes a {@link ProcessRecordsMessage} to the child process's STDIN and waits for the child process to respond * with a {@link StatusMessage} on its STDOUT. - * + * * @param processRecordsInput * The records, and associated metadata, to process. * @return Whether or not this operation succeeded. @@ -95,7 +95,7 @@ class MultiLangProtocol { /** * Writes a {@link ShutdownMessage} to the child process's STDIN and waits for the child process to respond with a * {@link StatusMessage} on its STDOUT. - * + * * @param checkpointer A checkpointer. * @param reason Why this processor is being shutdown. * @return Whether or not this operation succeeded. @@ -124,7 +124,7 @@ class MultiLangProtocol { * all communications with the child process regarding checkpointing were successful. Note that whether or not the * checkpointing itself was successful is not the concern of this method. This method simply cares whether it was * able to successfully communicate the results of its attempts to checkpoint. - * + * * @param action * What action is being waited on. * @param checkpointer @@ -155,7 +155,7 @@ class MultiLangProtocol { /** * Waits for status message and verifies it against the expectation - * + * * @param action * What action is being waited on. * @param checkpointer @@ -166,41 +166,72 @@ class MultiLangProtocol { StatusMessage statusMessage = null; while (statusMessage == null) { Future future = this.messageReader.getNextMessageFromSTDOUT(); - try { - Message message; - if (configuration.getTimeoutInSeconds().isPresent() && configuration.getTimeoutInSeconds().get() > 0) { - message = future.get(configuration.getTimeoutInSeconds().get(), TimeUnit.SECONDS); - } else { - message = future.get(); - } - // Note that instanceof doubles as a check against a value being null - if (message instanceof CheckpointMessage) { - boolean checkpointWriteSucceeded = checkpoint((CheckpointMessage) message, checkpointer).get(); - if (!checkpointWriteSucceeded) { - return false; - } - } else if (message instanceof StatusMessage) { - statusMessage = (StatusMessage) message; - } - } catch (InterruptedException e) { - log.error(String.format("Interrupted while waiting for %s message for shard %s", action, - initializationInput.getShardId())); + Optional message = configuration.getTimeoutInSeconds().map(second -> + futureMethod(() -> future.get(second, TimeUnit.SECONDS), action)).orElse(futureMethod(future::get, action)); + + if (!message.isPresent()) { return false; - } catch (ExecutionException e) { - log.error(String.format("Failed to get status message for %s action for shard %s", action, - initializationInput.getShardId()), e); - return false; - } catch (TimeoutException e) { - log.error(String.format("Timedout to get status message for %s action for shard %s. Terminating...", - action, - initializationInput.getShardId()), - e); - haltJvm(1); } + + Optional booleanStatusMessage = message.flatMap(m -> { + if (m instanceof CheckpointMessage) { + return Optional.of(futureMethod(() -> checkpoint((CheckpointMessage) m, checkpointer).get())); + } + return Optional.empty(); + }); + + Message m = message.get(); + + if (booleanStatusMessage.isPresent() && !booleanStatusMessage.get()) { + return false; + } else if (!booleanStatusMessage.isPresent() && m instanceof StatusMessage) { + statusMessage = (StatusMessage) m; + } + // Note that instanceof doubles as a check against a value being null } return this.validateStatusMessage(statusMessage, action); } + private interface FutureMethod { + Message get() throws InterruptedException, TimeoutException, ExecutionException; + } + + private Optional futureMethod(FutureMethod fm, String action) { + try { + return Optional.of(fm.get()); + } catch (InterruptedException e) { + log.error(String.format("Interrupted while waiting for %s message for shard %s", action, + initializationInput.getShardId()), e); + } catch (ExecutionException e) { + log.error(String.format("Failed to get status message for %s action for shard %s", action, + initializationInput.getShardId()), e); + } catch (TimeoutException e) { + log.error(String.format("Timedout to get status message for %s action for shard %s. Terminating...", + action, + initializationInput.getShardId()), + e); + haltJvm(1); + } + return Optional.empty(); + } + + private interface CheckpointFutureMethod { + Boolean get() throws InterruptedException, ExecutionException; + } + + private Boolean futureMethod(CheckpointFutureMethod cfm) { + try { + return cfm.get(); + } catch (InterruptedException e) { + log.error(String.format("Interrupted while waiting for Checkpointing message for shard %s", + initializationInput.getShardId()), e); + } catch (ExecutionException e) { + log.error(String.format("Failed to get status message for Checkpointing action for shard %s", + initializationInput.getShardId()), e); + } + return false; + } + /** * This method is used to halt the JVM. Use this method with utmost caution, since this method will kill the JVM * without calling the Shutdown hooks. @@ -213,7 +244,7 @@ class MultiLangProtocol { /** * Utility for confirming that the status message is for the provided action. - * + * * @param statusMessage The status of the child process. * @param action The action that was being waited on. * @return Whether or not this operation succeeded. @@ -231,7 +262,7 @@ class MultiLangProtocol { * provided {@link CheckpointMessage}. If no sequence number is provided, i.e. the sequence number is null, then * this method will call {@link IRecordProcessorCheckpointer#checkpoint()}. The method returns a future representing * the attempt to write the result of this checkpoint attempt to the child process. - * + * * @param checkpointMessage A checkpoint message. * @param checkpointer A checkpointer. * @return Whether or not this operation succeeded. diff --git a/src/test/java/com/amazonaws/services/kinesis/multilang/MultiLangProtocolTest.java b/src/test/java/com/amazonaws/services/kinesis/multilang/MultiLangProtocolTest.java index 4b74e728..3f35b8fa 100644 --- a/src/test/java/com/amazonaws/services/kinesis/multilang/MultiLangProtocolTest.java +++ b/src/test/java/com/amazonaws/services/kinesis/multilang/MultiLangProtocolTest.java @@ -28,7 +28,6 @@ import com.amazonaws.services.kinesis.multilang.messages.CheckpointMessage; import com.amazonaws.services.kinesis.multilang.messages.Message; import com.amazonaws.services.kinesis.multilang.messages.ProcessRecordsMessage; import com.amazonaws.services.kinesis.multilang.messages.StatusMessage; -import com.google.common.base.Optional; import com.google.common.util.concurrent.SettableFuture; import org.junit.Before; import org.junit.Test; @@ -43,6 +42,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.Iterator; import java.util.List; +import java.util.Optional; import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; @@ -85,7 +85,7 @@ public class MultiLangProtocolTest { protocol = new MultiLangProtocolForTesting(messageReader, messageWriter, new InitializationInput().withShardId(shardId), configuration); - when(configuration.getTimeoutInSeconds()).thenReturn(Optional.absent()); + when(configuration.getTimeoutInSeconds()).thenReturn(Optional.empty()); } private Future buildFuture(T value) { @@ -179,7 +179,10 @@ public class MultiLangProtocolTest { this.add(new StatusMessage("processRecords")); } })); - assertThat(protocol.processRecords(new ProcessRecordsInput().withRecords(EMPTY_RECORD_LIST).withCheckpointer(checkpointer)), equalTo(true)); + + boolean result = protocol.processRecords(new ProcessRecordsInput().withRecords(EMPTY_RECORD_LIST).withCheckpointer(checkpointer)); + + assertThat(result, equalTo(true)); verify(checkpointer, timeout(1)).checkpoint(); verify(checkpointer, timeout(1)).checkpoint("123", 0L); diff --git a/src/test/java/com/amazonaws/services/kinesis/multilang/StreamingRecordProcessorTest.java b/src/test/java/com/amazonaws/services/kinesis/multilang/StreamingRecordProcessorTest.java index 6cb17863..d27b9480 100644 --- a/src/test/java/com/amazonaws/services/kinesis/multilang/StreamingRecordProcessorTest.java +++ b/src/test/java/com/amazonaws/services/kinesis/multilang/StreamingRecordProcessorTest.java @@ -14,43 +14,13 @@ */ package com.amazonaws.services.kinesis.multilang; -import static org.mockito.Matchers.any; -import static org.mockito.Matchers.anyLong; -import static org.mockito.Matchers.anyString; -import static org.mockito.Matchers.argThat; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import java.io.IOException; -import java.io.InputStream; -import java.io.OutputStream; -import java.util.ArrayList; -import java.util.List; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.concurrent.Future; - -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.KinesisClientLibConfiguration; -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.ShutdownReason; -import com.google.common.base.Optional; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.mockito.Mock; -import org.mockito.Mockito; -import org.mockito.invocation.InvocationOnMock; -import org.mockito.runners.MockitoJUnitRunner; -import org.mockito.stubbing.Answer; - import com.amazonaws.services.kinesis.clientlibrary.exceptions.InvalidStateException; import com.amazonaws.services.kinesis.clientlibrary.exceptions.KinesisClientLibDependencyException; import com.amazonaws.services.kinesis.clientlibrary.exceptions.ShutdownException; import com.amazonaws.services.kinesis.clientlibrary.exceptions.ThrottlingException; import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer; +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.KinesisClientLibConfiguration; +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.ShutdownReason; import com.amazonaws.services.kinesis.clientlibrary.types.InitializationInput; import com.amazonaws.services.kinesis.clientlibrary.types.ProcessRecordsInput; import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownInput; @@ -61,6 +31,35 @@ import com.amazonaws.services.kinesis.multilang.messages.ProcessRecordsMessage; import com.amazonaws.services.kinesis.multilang.messages.ShutdownMessage; import com.amazonaws.services.kinesis.multilang.messages.StatusMessage; import com.fasterxml.jackson.databind.ObjectMapper; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.runners.MockitoJUnitRunner; +import org.mockito.stubbing.Answer; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; + +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyLong; +import static org.mockito.Matchers.anyString; +import static org.mockito.Matchers.argThat; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; @RunWith(MockitoJUnitRunner.class) public class StreamingRecordProcessorTest { @@ -126,7 +125,7 @@ public class StreamingRecordProcessorTest { messageWriter = Mockito.mock(MessageWriter.class); messageReader = Mockito.mock(MessageReader.class); errorReader = Mockito.mock(DrainChildSTDERRTask.class); - when(configuration.getTimeoutInSeconds()).thenReturn(Optional.absent()); + when(configuration.getTimeoutInSeconds()).thenReturn(Optional.empty()); recordProcessor = new MultiLangRecordProcessor(new ProcessBuilder(), executor, new ObjectMapper(), messageWriter,