Adding test cases for halt jvm code. Made the configuration objects for timeout optional.

This commit is contained in:
Sahil Palvia 2017-07-19 12:40:05 -07:00
parent 644a55bccb
commit e80201b047
6 changed files with 109 additions and 26 deletions

View file

@ -23,7 +23,9 @@ import com.amazonaws.regions.RegionUtils;
import com.amazonaws.services.kinesis.metrics.impl.MetricsHelper; import com.amazonaws.services.kinesis.metrics.impl.MetricsHelper;
import com.amazonaws.services.kinesis.metrics.interfaces.IMetricsScope; import com.amazonaws.services.kinesis.metrics.interfaces.IMetricsScope;
import com.amazonaws.services.kinesis.metrics.interfaces.MetricsLevel; import com.amazonaws.services.kinesis.metrics.interfaces.MetricsLevel;
import com.google.common.base.Optional;
import com.google.common.collect.ImmutableSet; import com.google.common.collect.ImmutableSet;
import lombok.Getter;
/** /**
* Configuration for the Amazon Kinesis Client Library. * Configuration for the Amazon Kinesis Client Library.
@ -203,8 +205,10 @@ public class KinesisClientLibConfiguration {
// This is useful for optimizing deployments to large fleets working on a stable stream. // This is useful for optimizing deployments to large fleets working on a stable stream.
private boolean skipShardSyncAtWorkerInitializationIfLeasesExist; private boolean skipShardSyncAtWorkerInitializationIfLeasesExist;
private ShardPrioritization shardPrioritization; private ShardPrioritization shardPrioritization;
private boolean timeoutEnabled; @Getter
private int timeoutInSeconds; private Optional<Boolean> timeoutEnabled = Optional.absent();
@Getter
private Optional<Integer> timeoutInSeconds = Optional.absent();
/** /**
* Constructor. * Constructor.
@ -1082,27 +1086,14 @@ public class KinesisClientLibConfiguration {
* @param timeoutEnabled Enable or disbale MultiLangProtocol to wait for the records to be processed * @param timeoutEnabled Enable or disbale MultiLangProtocol to wait for the records to be processed
*/ */
public void withTimeoutEnabled(final boolean timeoutEnabled) { public void withTimeoutEnabled(final boolean timeoutEnabled) {
this.timeoutEnabled = timeoutEnabled; this.timeoutEnabled = Optional.of(timeoutEnabled);
}
/**
* @return If timeout is enabled for MultiLangProtocol to wait for records to be processed
*/
public boolean isTimeoutEnabled() {
return timeoutEnabled;
} }
/** /**
* @param timeoutInSeconds The timeout in seconds to wait for the MultiLangProtocol to wait for * @param timeoutInSeconds The timeout in seconds to wait for the MultiLangProtocol to wait for
*/ */
public void withTimeoutInSeconds(final int timeoutInSeconds) { public void withTimeoutInSeconds(final int timeoutInSeconds) {
this.timeoutInSeconds = timeoutInSeconds; this.timeoutInSeconds = Optional.of(timeoutInSeconds);
} }
/**
* @return Time for MultiLangProtocol to wait to get response, before throwing an exception.
*/
public int getTimeoutInSeconds() {
return timeoutInSeconds;
}
} }

View file

@ -162,14 +162,17 @@ class MultiLangProtocol {
* the original process records request * the original process records request
* @return Whether or not this operation succeeded. * @return Whether or not this operation succeeded.
*/ */
private boolean waitForStatusMessage(String action, IRecordProcessorCheckpointer checkpointer) { boolean waitForStatusMessage(String action, IRecordProcessorCheckpointer checkpointer) {
StatusMessage statusMessage = null; StatusMessage statusMessage = null;
while (statusMessage == null) { while (statusMessage == null) {
Future<Message> future = this.messageReader.getNextMessageFromSTDOUT(); Future<Message> future = this.messageReader.getNextMessageFromSTDOUT();
try { try {
Message message; Message message;
if (configuration.isTimeoutEnabled()) { if (configuration.getTimeoutEnabled().isPresent() && configuration.getTimeoutEnabled().get()) {
message = future.get(configuration.getTimeoutInSeconds(), TimeUnit.SECONDS); if (!configuration.getTimeoutInSeconds().isPresent()) {
throw new IllegalArgumentException("timeoutInSeconds property should be set if timeoutEnabled is true");
}
message = future.get(configuration.getTimeoutInSeconds().get(), TimeUnit.SECONDS);
} else { } else {
message = future.get(); message = future.get();
} }
@ -195,12 +198,22 @@ class MultiLangProtocol {
action, action,
initializationInput.getShardId()), initializationInput.getShardId()),
e); e);
Runtime.getRuntime().halt(1); haltJvm(1);
} }
} }
return this.validateStatusMessage(statusMessage, action); return this.validateStatusMessage(statusMessage, action);
} }
/**
* This method is used to halt the JVM. Use this method with utmost caution, since this method will kill the JVM
* without calling the Shutdown hooks.
*
* @param exitStatus The exit status with which the JVM is to be halted.
*/
protected void haltJvm(int exitStatus) {
Runtime.getRuntime().halt(exitStatus);
}
/** /**
* Utility for confirming that the status message is for the provided action. * Utility for confirming that the status message is for the provided action.
* *

View file

@ -0,0 +1,29 @@
package com.amazonaws.services.kinesis.multilang;
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.KinesisClientLibConfiguration;
import com.amazonaws.services.kinesis.clientlibrary.types.InitializationInput;
/**
*
*/
public class MultiLangProtocolForTests extends MultiLangProtocol {
/**
* Constructor.
*
* @param messageReader A message reader.
* @param messageWriter A message writer.
* @param initializationInput
* @param configuration
*/
MultiLangProtocolForTests(final MessageReader messageReader,
final MessageWriter messageWriter,
final InitializationInput initializationInput,
final KinesisClientLibConfiguration configuration) {
super(messageReader, messageWriter, initializationInput, configuration);
}
@Override
protected void haltJvm(final int exitStatus) {
throw new RuntimeException("Halt called");
}
}

View file

@ -15,11 +15,17 @@
package com.amazonaws.services.kinesis.multilang; package com.amazonaws.services.kinesis.multilang;
import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.equalTo;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertThat; import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any; import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyInt;
import static org.mockito.Matchers.anyLong; import static org.mockito.Matchers.anyLong;
import static org.mockito.Matchers.anyString; import static org.mockito.Matchers.anyString;
import static org.mockito.Matchers.argThat; import static org.mockito.Matchers.argThat;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
@ -30,8 +36,11 @@ import java.util.Iterator;
import java.util.List; import java.util.List;
import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future; import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.KinesisClientLibConfiguration; import com.amazonaws.services.kinesis.clientlibrary.lib.worker.KinesisClientLibConfiguration;
import com.google.common.base.Optional;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.mockito.Mockito; import org.mockito.Mockito;
@ -58,7 +67,7 @@ import com.google.common.util.concurrent.SettableFuture;
public class MultiLangProtocolTest { public class MultiLangProtocolTest {
private static final List<Record> EMPTY_RECORD_LIST = Collections.emptyList(); private static final List<Record> EMPTY_RECORD_LIST = Collections.emptyList();
private MultiLangProtocol protocol; private MultiLangProtocolForTests protocol;
private MessageWriter messageWriter; private MessageWriter messageWriter;
private MessageReader messageReader; private MessageReader messageReader;
private String shardId; private String shardId;
@ -73,9 +82,11 @@ public class MultiLangProtocolTest {
messageWriter = Mockito.mock(MessageWriter.class); messageWriter = Mockito.mock(MessageWriter.class);
messageReader = Mockito.mock(MessageReader.class); messageReader = Mockito.mock(MessageReader.class);
configuration = Mockito.mock(KinesisClientLibConfiguration.class); configuration = Mockito.mock(KinesisClientLibConfiguration.class);
protocol = new MultiLangProtocol(messageReader, messageWriter, new InitializationInput().withShardId(shardId), protocol = new MultiLangProtocolForTests(messageReader, messageWriter, new InitializationInput().withShardId(shardId),
configuration); configuration);
checkpointer = Mockito.mock(IRecordProcessorCheckpointer.class); checkpointer = Mockito.mock(IRecordProcessorCheckpointer.class);
when(configuration.getTimeoutEnabled()).thenReturn(Optional.<Boolean>absent());
} }
private <T> Future<T> buildFuture(T value) { private <T> Future<T> buildFuture(T value) {
@ -187,4 +198,40 @@ public class MultiLangProtocolTest {
})); }));
assertThat(protocol.processRecords(new ProcessRecordsInput().withRecords(EMPTY_RECORD_LIST).withCheckpointer(checkpointer)), equalTo(false)); assertThat(protocol.processRecords(new ProcessRecordsInput().withRecords(EMPTY_RECORD_LIST).withCheckpointer(checkpointer)), equalTo(false));
} }
@Test(expected = RuntimeException.class)
public void waitForStatusMessageTimeoutTest() throws InterruptedException, TimeoutException, ExecutionException {
when(messageWriter.writeProcessRecordsMessage(any(ProcessRecordsInput.class))).thenReturn(buildFuture(true));
Future<Message> future = Mockito.mock(Future.class);
when(messageReader.getNextMessageFromSTDOUT()).thenReturn(future);
when(configuration.getTimeoutEnabled()).thenReturn(Optional.of(true));
when(configuration.getTimeoutInSeconds()).thenReturn(Optional.of(5));
when(future.get(anyInt(), eq(TimeUnit.SECONDS))).thenThrow(TimeoutException.class);
protocol = new MultiLangProtocolForTests(messageReader,
messageWriter,
new InitializationInput().withShardId(shardId),
configuration);
protocol.processRecords(new ProcessRecordsInput().withRecords(EMPTY_RECORD_LIST));
}
@Test(expected = IllegalArgumentException.class)
public void waitForStatusMessageTimeoutErrorTest() {
when(messageWriter.writeProcessRecordsMessage(any(ProcessRecordsInput.class))).thenReturn(buildFuture(true));
when(messageReader.getNextMessageFromSTDOUT()).thenReturn(buildFuture(new StatusMessage("processRecords"), Message.class));
when(configuration.getTimeoutEnabled()).thenReturn(Optional.of(true));
when(configuration.getTimeoutInSeconds()).thenReturn(Optional.<Integer>absent());
protocol.processRecords(new ProcessRecordsInput().withRecords(EMPTY_RECORD_LIST));
}
@Test
public void waitForStatusMessageSuccessTest() {
when(messageWriter.writeProcessRecordsMessage(any(ProcessRecordsInput.class))).thenReturn(buildFuture(true));
when(messageReader.getNextMessageFromSTDOUT()).thenReturn(buildFuture(new StatusMessage("processRecords"), Message.class));
when(configuration.getTimeoutEnabled()).thenReturn(Optional.of(true));
when(configuration.getTimeoutInSeconds()).thenReturn(Optional.of(5));
assertTrue(protocol.processRecords(new ProcessRecordsInput().withRecords(EMPTY_RECORD_LIST)));
}
} }

View file

@ -14,14 +14,16 @@
*/ */
package com.amazonaws.services.kinesis.multilang; package com.amazonaws.services.kinesis.multilang;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.KinesisClientLibConfiguration; import com.amazonaws.services.kinesis.clientlibrary.lib.worker.KinesisClientLibConfiguration;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
import com.amazonaws.services.kinesis.clientlibrary.interfaces.v2.IRecordProcessor; import com.amazonaws.services.kinesis.clientlibrary.interfaces.v2.IRecordProcessor;
import org.junit.runner.RunWith;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.runners.MockitoJUnitRunner;
@RunWith(MockitoJUnitRunner.class)
public class StreamingRecordProcessorFactoryTest { public class StreamingRecordProcessorFactoryTest {
@Mock @Mock

View file

@ -35,6 +35,7 @@ import java.util.concurrent.Future;
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.KinesisClientLibConfiguration; import com.amazonaws.services.kinesis.clientlibrary.lib.worker.KinesisClientLibConfiguration;
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.ShutdownReason; import com.amazonaws.services.kinesis.clientlibrary.lib.worker.ShutdownReason;
import com.google.common.base.Optional;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
@ -125,7 +126,7 @@ public class StreamingRecordProcessorTest {
messageWriter = Mockito.mock(MessageWriter.class); messageWriter = Mockito.mock(MessageWriter.class);
messageReader = Mockito.mock(MessageReader.class); messageReader = Mockito.mock(MessageReader.class);
errorReader = Mockito.mock(DrainChildSTDERRTask.class); errorReader = Mockito.mock(DrainChildSTDERRTask.class);
when(configuration.isTimeoutEnabled()).thenReturn(false); when(configuration.getTimeoutEnabled()).thenReturn(Optional.of(false));
recordProcessor = recordProcessor =
new MultiLangRecordProcessor(new ProcessBuilder(), executor, new ObjectMapper(), messageWriter, new MultiLangRecordProcessor(new ProcessBuilder(), executor, new ObjectMapper(), messageWriter,
@ -171,7 +172,7 @@ public class StreamingRecordProcessorTest {
*/ */
when(messageFuture.get()).thenAnswer(answer); when(messageFuture.get()).thenAnswer(answer);
when(messageReader.getNextMessageFromSTDOUT()).thenReturn(messageFuture); when(messageReader.getNextMessageFromSTDOUT()).thenReturn(messageFuture);
when(configuration.isTimeoutEnabled()).thenReturn(false); when(configuration.getTimeoutEnabled()).thenReturn(Optional.of(false));
List<Record> testRecords = new ArrayList<Record>(); List<Record> testRecords = new ArrayList<Record>();