Merge pull request #97 from pfifer/shard-info-equals

Don't Use Checkpoint for Equality in ShardInfo
This commit is contained in:
Justin Pfifer 2016-08-17 14:59:43 -07:00 committed by GitHub
commit a02022fb0f
5 changed files with 240 additions and 176 deletions

View file

@ -19,13 +19,17 @@ import java.util.Collections;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
import org.apache.commons.lang.builder.EqualsBuilder;
import com.amazonaws.services.kinesis.clientlibrary.types.ExtendedSequenceNumber; import com.amazonaws.services.kinesis.clientlibrary.types.ExtendedSequenceNumber;
import org.apache.commons.lang.builder.HashCodeBuilder;
/** /**
* Used to pass shard related info among different classes and as a key to the map of shard consumers. * Used to pass shard related info among different classes and as a key to the map of shard consumers.
*/ */
class ShardInfo { class ShardInfo {
private final String shardId; private final String shardId;
private final String concurrencyToken; private final String concurrencyToken;
// Sorted list of parent shardIds. // Sorted list of parent shardIds.
@ -33,10 +37,16 @@ class ShardInfo {
private final ExtendedSequenceNumber checkpoint; private final ExtendedSequenceNumber checkpoint;
/** /**
* @param shardId Kinesis shardId * Creates a new ShardInfo object. The checkpoint is not part of the equality, but is used for debugging output.
* @param concurrencyToken Used to differentiate between lost and reclaimed leases *
* @param parentShardIds Parent shards of the shard identified by Kinesis shardId * @param shardId
* @param checkpoint the latest checkpoint from lease * Kinesis shardId
* @param concurrencyToken
* Used to differentiate between lost and reclaimed leases
* @param parentShardIds
* Parent shards of the shard identified by Kinesis shardId
* @param checkpoint
* the latest checkpoint from lease
*/ */
public ShardInfo(String shardId, public ShardInfo(String shardId,
String concurrencyToken, String concurrencyToken,
@ -87,20 +97,12 @@ class ShardInfo {
*/ */
@Override @Override
public int hashCode() { public int hashCode() {
final int prime = 31; return new HashCodeBuilder().append(concurrencyToken).append(parentShardIds).append(shardId).toHashCode();
int result = 1;
result = prime * result + ((concurrencyToken == null) ? 0 : concurrencyToken.hashCode());
result = prime * result + ((parentShardIds == null) ? 0 : parentShardIds.hashCode());
result = prime * result + ((shardId == null) ? 0 : shardId.hashCode());
result = prime * result + ((checkpoint == null) ? 0 : checkpoint.hashCode());
return result;
} }
/** /**
* {@inheritDoc} * {@inheritDoc}
*/ */
// CHECKSTYLE:OFF CyclomaticComplexity
// CHECKSTYLE:OFF NPathComplexity
/** /**
* This method assumes parentShardIds is ordered. The Worker.cleanupShardConsumers() method relies on this method * This method assumes parentShardIds is ordered. The Worker.cleanupShardConsumers() method relies on this method
* returning true for ShardInfo objects which may have been instantiated with parentShardIds in a different order * returning true for ShardInfo objects which may have been instantiated with parentShardIds in a different order
@ -121,39 +123,11 @@ class ShardInfo {
return false; return false;
} }
ShardInfo other = (ShardInfo) obj; ShardInfo other = (ShardInfo) obj;
if (concurrencyToken == null) { return new EqualsBuilder().append(concurrencyToken, other.concurrencyToken)
if (other.concurrencyToken != null) { .append(parentShardIds, other.parentShardIds).append(shardId, other.shardId).isEquals();
return false;
}
} else if (!concurrencyToken.equals(other.concurrencyToken)) {
return false;
}
if (parentShardIds == null) {
if (other.parentShardIds != null) {
return false;
}
} else if (!parentShardIds.equals(other.parentShardIds)) {
return false;
}
if (shardId == null) {
if (other.shardId != null) {
return false;
}
} else if (!shardId.equals(other.shardId)) {
return false;
}
if (checkpoint == null) {
if (other.checkpoint != null) {
return false;
}
} else if (!checkpoint.equals(other.checkpoint)) {
return false;
}
return true;
} }
// CHECKSTYLE:ON CyclomaticComplexity
// CHECKSTYLE:ON NPathComplexity
@Override @Override
public String toString() { public String toString() {
@ -161,41 +135,6 @@ class ShardInfo {
+ parentShardIds + ", checkpoint=" + checkpoint + "]"; + parentShardIds + ", checkpoint=" + checkpoint + "]";
} }
/**
* Builder class for ShardInfo.
*/
public static class Builder {
private String shardId;
private String concurrencyToken;
private List<String> parentShardIds = Collections.emptyList();
private ExtendedSequenceNumber checkpoint = ExtendedSequenceNumber.LATEST;
public Builder() {
}
public Builder withShardId(String shardId) {
this.shardId = shardId;
return this;
}
public Builder withConcurrencyToken(String concurrencyToken) {
this.concurrencyToken = concurrencyToken;
return this;
}
public Builder withParentShards(List<String> parentShardIds) {
this.parentShardIds = parentShardIds;
return this;
}
public Builder withCheckpoint(ExtendedSequenceNumber checkpoint) {
this.checkpoint = checkpoint;
return this;
}
public ShardInfo build() {
return new ShardInfo(shardId, concurrencyToken, parentShardIds, checkpoint);
}
}
} }

View file

@ -38,7 +38,6 @@ import com.amazonaws.services.kinesis.AmazonKinesisClient;
import com.amazonaws.services.kinesis.clientlibrary.interfaces.ICheckpoint; import com.amazonaws.services.kinesis.clientlibrary.interfaces.ICheckpoint;
import com.amazonaws.services.kinesis.clientlibrary.interfaces.v2.IRecordProcessor; import com.amazonaws.services.kinesis.clientlibrary.interfaces.v2.IRecordProcessor;
import com.amazonaws.services.kinesis.clientlibrary.interfaces.v2.IRecordProcessorFactory; import com.amazonaws.services.kinesis.clientlibrary.interfaces.v2.IRecordProcessorFactory;
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.Worker.Builder;
import com.amazonaws.services.kinesis.clientlibrary.proxies.KinesisProxyFactory; import com.amazonaws.services.kinesis.clientlibrary.proxies.KinesisProxyFactory;
import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason; import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason;
import com.amazonaws.services.kinesis.leases.exceptions.LeasingException; import com.amazonaws.services.kinesis.leases.exceptions.LeasingException;
@ -47,6 +46,7 @@ import com.amazonaws.services.kinesis.metrics.impl.CWMetricsFactory;
import com.amazonaws.services.kinesis.metrics.impl.NullMetricsFactory; import com.amazonaws.services.kinesis.metrics.impl.NullMetricsFactory;
import com.amazonaws.services.kinesis.metrics.interfaces.IMetricsFactory; import com.amazonaws.services.kinesis.metrics.interfaces.IMetricsFactory;
import com.amazonaws.services.kinesis.metrics.interfaces.MetricsLevel; import com.amazonaws.services.kinesis.metrics.interfaces.MetricsLevel;
import com.google.common.annotations.VisibleForTesting;
/** /**
* Worker is the high level class that Kinesis applications use to start * Worker is the high level class that Kinesis applications use to start
@ -342,46 +342,49 @@ public class Worker implements Runnable {
} }
while (!shouldShutdown()) { while (!shouldShutdown()) {
try { runProcessLoop();
boolean foundCompletedShard = false;
Set<ShardInfo> assignedShards = new HashSet<ShardInfo>();
for (ShardInfo shardInfo : getShardInfoForAssignments()) {
ShardConsumer shardConsumer = createOrGetShardConsumer(shardInfo, recordProcessorFactory);
if (shardConsumer.isShutdown()
&& shardConsumer.getShutdownReason().equals(ShutdownReason.TERMINATE)) {
foundCompletedShard = true;
} else {
shardConsumer.consumeShard();
}
assignedShards.add(shardInfo);
}
if (foundCompletedShard) {
controlServer.syncShardAndLeaseInfo(null);
}
// clean up shard consumers for unassigned shards
cleanupShardConsumers(assignedShards);
wlog.info("Sleeping ...");
Thread.sleep(idleTimeInMilliseconds);
} catch (Exception e) {
LOG.error(String.format("Worker.run caught exception, sleeping for %s milli seconds!",
String.valueOf(idleTimeInMilliseconds)),
e);
try {
Thread.sleep(idleTimeInMilliseconds);
} catch (InterruptedException ex) {
LOG.info("Worker: sleep interrupted after catching exception ", ex);
}
}
wlog.resetInfoLogging();
} }
finalShutdown(); finalShutdown();
LOG.info("Worker loop is complete. Exiting from worker."); LOG.info("Worker loop is complete. Exiting from worker.");
} }
@VisibleForTesting
void runProcessLoop() {
try {
boolean foundCompletedShard = false;
Set<ShardInfo> assignedShards = new HashSet<>();
for (ShardInfo shardInfo : getShardInfoForAssignments()) {
ShardConsumer shardConsumer = createOrGetShardConsumer(shardInfo, recordProcessorFactory);
if (shardConsumer.isShutdown() && shardConsumer.getShutdownReason().equals(ShutdownReason.TERMINATE)) {
foundCompletedShard = true;
} else {
shardConsumer.consumeShard();
}
assignedShards.add(shardInfo);
}
if (foundCompletedShard) {
controlServer.syncShardAndLeaseInfo(null);
}
// clean up shard consumers for unassigned shards
cleanupShardConsumers(assignedShards);
wlog.info("Sleeping ...");
Thread.sleep(idleTimeInMilliseconds);
} catch (Exception e) {
LOG.error(String.format("Worker.run caught exception, sleeping for %s milli seconds!",
String.valueOf(idleTimeInMilliseconds)), e);
try {
Thread.sleep(idleTimeInMilliseconds);
} catch (InterruptedException ex) {
LOG.info("Worker: sleep interrupted after catching exception ", ex);
}
}
wlog.resetInfoLogging();
}
private void initialize() { private void initialize() {
boolean isDone = false; boolean isDone = false;
Exception lastException = null; Exception lastException = null;
@ -552,25 +555,22 @@ public class Worker implements Runnable {
// completely processed (shutdown reason terminate). // completely processed (shutdown reason terminate).
if ((consumer == null) if ((consumer == null)
|| (consumer.isShutdown() && consumer.getShutdownReason().equals(ShutdownReason.ZOMBIE))) { || (consumer.isShutdown() && consumer.getShutdownReason().equals(ShutdownReason.ZOMBIE))) {
IRecordProcessor recordProcessor = factory.createProcessor(); consumer = buildConsumer(shardInfo, factory);
consumer =
new ShardConsumer(shardInfo,
streamConfig,
checkpointTracker,
recordProcessor,
leaseCoordinator.getLeaseManager(),
parentShardPollIntervalMillis,
cleanupLeasesUponShardCompletion,
executorService,
metricsFactory,
taskBackoffTimeMillis);
shardInfoShardConsumerMap.put(shardInfo, consumer); shardInfoShardConsumerMap.put(shardInfo, consumer);
wlog.infoForce("Created new shardConsumer for : " + shardInfo); wlog.infoForce("Created new shardConsumer for : " + shardInfo);
} }
return consumer; return consumer;
} }
protected ShardConsumer buildConsumer(ShardInfo shardInfo, IRecordProcessorFactory factory) {
IRecordProcessor recordProcessor = factory.createProcessor();
return new ShardConsumer(shardInfo, streamConfig, checkpointTracker, recordProcessor,
leaseCoordinator.getLeaseManager(), parentShardPollIntervalMillis, cleanupLeasesUponShardCompletion,
executorService, metricsFactory, taskBackoffTimeMillis);
}
/** /**
* Logger for suppressing too much INFO logging. To avoid too much logging * Logger for suppressing too much INFO logging. To avoid too much logging
* information Worker will output logging at INFO level for a single pass * information Worker will output logging at INFO level for a single pass

View file

@ -11,6 +11,8 @@ import java.util.Random;
import org.junit.Test; import org.junit.Test;
import com.amazonaws.services.kinesis.clientlibrary.types.ExtendedSequenceNumber;
public class ParentsFirstShardPrioritizationUnitTest { public class ParentsFirstShardPrioritizationUnitTest {
@Test(expected = IllegalArgumentException.class) @Test(expected = IllegalArgumentException.class)
@ -144,17 +146,54 @@ public class ParentsFirstShardPrioritizationUnitTest {
return "shardId-" + shardNumber; return "shardId-" + shardNumber;
} }
/**
* Builder class for ShardInfo.
*/
static class ShardInfoBuilder {
private String shardId;
private String concurrencyToken;
private List<String> parentShardIds = Collections.emptyList();
private ExtendedSequenceNumber checkpoint = ExtendedSequenceNumber.LATEST;
ShardInfoBuilder() {
}
ShardInfoBuilder withShardId(String shardId) {
this.shardId = shardId;
return this;
}
ShardInfoBuilder withConcurrencyToken(String concurrencyToken) {
this.concurrencyToken = concurrencyToken;
return this;
}
ShardInfoBuilder withParentShards(List<String> parentShardIds) {
this.parentShardIds = parentShardIds;
return this;
}
ShardInfoBuilder withCheckpoint(ExtendedSequenceNumber checkpoint) {
this.checkpoint = checkpoint;
return this;
}
ShardInfo build() {
return new ShardInfo(shardId, concurrencyToken, parentShardIds, checkpoint);
}
}
private static ShardInfo shardInfo(String shardId, List<String> parentShardIds) { private static ShardInfo shardInfo(String shardId, List<String> parentShardIds) {
// copy into new list just in case ShardInfo will stop doing it // copy into new list just in case ShardInfo will stop doing it
List<String> newParentShardIds = new ArrayList<>(parentShardIds); List<String> newParentShardIds = new ArrayList<>(parentShardIds);
return new ShardInfo.Builder() return new ShardInfoBuilder()
.withShardId(shardId) .withShardId(shardId)
.withParentShards(newParentShardIds) .withParentShards(newParentShardIds)
.build(); .build();
} }
private static ShardInfo shardInfo(String shardId, String... parentShardIds) { private static ShardInfo shardInfo(String shardId, String... parentShardIds) {
return new ShardInfo.Builder() return new ShardInfoBuilder()
.withShardId(shardId) .withShardId(shardId)
.withParentShards(Arrays.asList(parentShardIds)) .withParentShards(Arrays.asList(parentShardIds))
.build(); .build();

View file

@ -14,6 +14,10 @@
*/ */
package com.amazonaws.services.kinesis.clientlibrary.lib.worker; package com.amazonaws.services.kinesis.clientlibrary.lib.worker;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.is;
import static org.junit.Assert.assertThat;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashSet; import java.util.HashSet;
import java.util.List; import java.util.List;
@ -93,11 +97,20 @@ public class ShardInfoTest {
} }
@Test @Test
public void testPacboyShardInfoEqualsForCheckpoint() { public void testShardInfoCheckpointEqualsHashCode() {
ShardInfo diffShardInfo = new ShardInfo(SHARD_ID, CONCURRENCY_TOKEN, parentShardIds, ExtendedSequenceNumber.SHARD_END); ShardInfo baseInfo = new ShardInfo(SHARD_ID, CONCURRENCY_TOKEN, parentShardIds,
Assert.assertFalse("Equal should return false with different checkpoint", diffShardInfo.equals(testShardInfo)); ExtendedSequenceNumber.TRIM_HORIZON);
diffShardInfo = new ShardInfo(SHARD_ID, CONCURRENCY_TOKEN, parentShardIds, null); ShardInfo differentCheckpoint = new ShardInfo(SHARD_ID, CONCURRENCY_TOKEN, parentShardIds,
Assert.assertFalse("Equal should return false with null checkpoint", diffShardInfo.equals(testShardInfo)); new ExtendedSequenceNumber("1234"));
ShardInfo nullCheckpoint = new ShardInfo(SHARD_ID, CONCURRENCY_TOKEN, parentShardIds, null);
assertThat("Checkpoint should not be included in equality.", baseInfo.equals(differentCheckpoint), is(true));
assertThat("Checkpoint should not be included in equality.", baseInfo.equals(nullCheckpoint), is(true));
assertThat("Checkpoint should not be included in hash code.", baseInfo.hashCode(),
equalTo(differentCheckpoint.hashCode()));
assertThat("Checkpoint should not be included in hash code.", baseInfo.hashCode(),
equalTo(nullCheckpoint.hashCode()));
} }
@Test @Test

View file

@ -16,10 +16,16 @@ package com.amazonaws.services.kinesis.clientlibrary.lib.worker;
import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.mockito.Matchers.any; import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Matchers.same;
import static org.mockito.Mockito.atLeast; import static org.mockito.Mockito.atLeast;
import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
@ -50,7 +56,10 @@ import org.junit.Assert;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.Timeout; import org.junit.rules.Timeout;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.invocation.InvocationOnMock; import org.mockito.invocation.InvocationOnMock;
import org.mockito.runners.MockitoJUnitRunner;
import org.mockito.stubbing.Answer; import org.mockito.stubbing.Answer;
import com.amazonaws.services.dynamodbv2.AmazonDynamoDB; import com.amazonaws.services.dynamodbv2.AmazonDynamoDB;
@ -85,6 +94,7 @@ import com.amazonaws.services.kinesis.model.Shard;
/** /**
* Unit tests of Worker. * Unit tests of Worker.
*/ */
@RunWith(MockitoJUnitRunner.class)
public class WorkerTest { public class WorkerTest {
private static final Log LOG = LogFactory.getLog(WorkerTest.class); private static final Log LOG = LogFactory.getLog(WorkerTest.class);
@ -107,6 +117,30 @@ public class WorkerTest {
InitialPositionInStreamExtended.newInitialPosition(InitialPositionInStream.TRIM_HORIZON); InitialPositionInStreamExtended.newInitialPosition(InitialPositionInStream.TRIM_HORIZON);
private final ShardPrioritization shardPrioritization = new NoOpShardPrioritization(); private final ShardPrioritization shardPrioritization = new NoOpShardPrioritization();
private static final String KINESIS_SHARD_ID_FORMAT = "kinesis-0-0-%d";
private static final String CONCURRENCY_TOKEN_FORMAT = "testToken-%d";
@Mock
private KinesisClientLibLeaseCoordinator leaseCoordinator;
@Mock
private ILeaseManager<KinesisClientLease> leaseManager;
@Mock
private com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorFactory v1RecordProcessorFactory;
@Mock
private IKinesisProxy proxy;
@Mock
private WorkerThreadPoolExecutor executorService;
@Mock
private WorkerCWMetricsFactory cwMetricsFactory;
@Mock
private IKinesisProxy kinesisProxy;
@Mock
private IRecordProcessorFactory v2RecordProcessorFactory;
@Mock
private IRecordProcessor v2RecordProcessor;
@Mock
private ShardConsumer shardConsumer;
// CHECKSTYLE:IGNORE AnonInnerLengthCheck FOR NEXT 50 LINES // CHECKSTYLE:IGNORE AnonInnerLengthCheck FOR NEXT 50 LINES
private static final com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorFactory SAMPLE_RECORD_PROCESSOR_FACTORY = private static final com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorFactory SAMPLE_RECORD_PROCESSOR_FACTORY =
new com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorFactory() { new com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorFactory() {
@ -145,6 +179,7 @@ public class WorkerTest {
private static final IRecordProcessorFactory SAMPLE_RECORD_PROCESSOR_FACTORY_V2 = private static final IRecordProcessorFactory SAMPLE_RECORD_PROCESSOR_FACTORY_V2 =
new V1ToV2RecordProcessorFactoryAdapter(SAMPLE_RECORD_PROCESSOR_FACTORY); new V1ToV2RecordProcessorFactoryAdapter(SAMPLE_RECORD_PROCESSOR_FACTORY);
/** /**
* Test method for {@link com.amazonaws.services.kinesis.clientlibrary.lib.worker.Worker#getApplicationName()}. * Test method for {@link com.amazonaws.services.kinesis.clientlibrary.lib.worker.Worker#getApplicationName()}.
*/ */
@ -153,9 +188,7 @@ public class WorkerTest {
final String stageName = "testStageName"; final String stageName = "testStageName";
final KinesisClientLibConfiguration clientConfig = final KinesisClientLibConfiguration clientConfig =
new KinesisClientLibConfiguration(stageName, null, null, null); new KinesisClientLibConfiguration(stageName, null, null, null);
Worker worker = Worker worker = new Worker(v1RecordProcessorFactory, clientConfig);
new Worker(mock(com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorFactory.class),
clientConfig);
Assert.assertEquals(stageName, worker.getApplicationName()); Assert.assertEquals(stageName, worker.getApplicationName());
} }
@ -177,9 +210,6 @@ public class WorkerTest {
final String dummyKinesisShardId = "kinesis-0-0"; final String dummyKinesisShardId = "kinesis-0-0";
ExecutorService execService = null; ExecutorService execService = null;
KinesisClientLibLeaseCoordinator leaseCoordinator = mock(KinesisClientLibLeaseCoordinator.class);
@SuppressWarnings("unchecked")
ILeaseManager<KinesisClientLease> leaseManager = mock(ILeaseManager.class);
when(leaseCoordinator.getLeaseManager()).thenReturn(leaseManager); when(leaseCoordinator.getLeaseManager()).thenReturn(leaseManager);
Worker worker = Worker worker =
@ -208,6 +238,63 @@ public class WorkerTest {
Assert.assertNotSame(consumer3, consumer); Assert.assertNotSame(consumer3, consumer);
} }
@Test
public void testWorkerLoopWithCheckpoint() {
final String stageName = "testStageName";
IRecordProcessorFactory streamletFactory = SAMPLE_RECORD_PROCESSOR_FACTORY_V2;
IKinesisProxy proxy = null;
ICheckpoint checkpoint = null;
int maxRecords = 1;
int idleTimeInMilliseconds = 1000;
StreamConfig streamConfig = new StreamConfig(proxy, maxRecords, idleTimeInMilliseconds,
callProcessRecordsForEmptyRecordList, skipCheckpointValidationValue, INITIAL_POSITION_LATEST);
ExecutorService execService = null;
when(leaseCoordinator.getLeaseManager()).thenReturn(leaseManager);
List<ShardInfo> initialState = createShardInfoList(ExtendedSequenceNumber.TRIM_HORIZON);
List<ShardInfo> firstCheckpoint = createShardInfoList(new ExtendedSequenceNumber("1000"));
List<ShardInfo> secondCheckpoint = createShardInfoList(new ExtendedSequenceNumber("2000"));
when(leaseCoordinator.getCurrentAssignments()).thenReturn(initialState).thenReturn(firstCheckpoint)
.thenReturn(secondCheckpoint);
Worker worker = new Worker(stageName, streamletFactory, streamConfig, INITIAL_POSITION_LATEST,
parentShardPollIntervalMillis, shardSyncIntervalMillis, cleanupLeasesUponShardCompletion, checkpoint,
leaseCoordinator, execService, nullMetricsFactory, taskBackoffTimeMillis, failoverTimeMillis,
shardPrioritization);
Worker workerSpy = spy(worker);
doReturn(shardConsumer).when(workerSpy).buildConsumer(eq(initialState.get(0)), any(IRecordProcessorFactory.class));
workerSpy.runProcessLoop();
workerSpy.runProcessLoop();
workerSpy.runProcessLoop();
verify(workerSpy).buildConsumer(same(initialState.get(0)), any(IRecordProcessorFactory.class));
verify(workerSpy, never()).buildConsumer(same(firstCheckpoint.get(0)), any(IRecordProcessorFactory.class));
verify(workerSpy, never()).buildConsumer(same(secondCheckpoint.get(0)), any(IRecordProcessorFactory.class));
}
private List<ShardInfo> createShardInfoList(ExtendedSequenceNumber... sequenceNumbers) {
List<ShardInfo> result = new ArrayList<>(sequenceNumbers.length);
assertThat(sequenceNumbers.length, greaterThanOrEqualTo(1));
for (int i = 0; i < sequenceNumbers.length; ++i) {
result.add(new ShardInfo(adjustedShardId(i), adjustedConcurrencyToken(i), null, sequenceNumbers[i]));
}
return result;
}
private String adjustedShardId(int index) {
return String.format(KINESIS_SHARD_ID_FORMAT, index);
}
private String adjustedConcurrencyToken(int index) {
return String.format(CONCURRENCY_TOKEN_FORMAT, index);
}
@Test @Test
public final void testCleanupShardConsumers() { public final void testCleanupShardConsumers() {
final String stageName = "testStageName"; final String stageName = "testStageName";
@ -226,10 +313,6 @@ public class WorkerTest {
final String dummyKinesisShardId = "kinesis-0-0"; final String dummyKinesisShardId = "kinesis-0-0";
final String anotherDummyKinesisShardId = "kinesis-0-1"; final String anotherDummyKinesisShardId = "kinesis-0-1";
ExecutorService execService = null; ExecutorService execService = null;
KinesisClientLibLeaseCoordinator leaseCoordinator = mock(KinesisClientLibLeaseCoordinator.class);
@SuppressWarnings("unchecked")
ILeaseManager<KinesisClientLease> leaseManager = mock(ILeaseManager.class);
when(leaseCoordinator.getLeaseManager()).thenReturn(leaseManager); when(leaseCoordinator.getLeaseManager()).thenReturn(leaseManager);
Worker worker = Worker worker =
@ -272,7 +355,6 @@ public class WorkerTest {
public final void testInitializationFailureWithRetries() { public final void testInitializationFailureWithRetries() {
String stageName = "testInitializationWorker"; String stageName = "testInitializationWorker";
IRecordProcessorFactory recordProcessorFactory = new TestStreamletFactory(null, null); IRecordProcessorFactory recordProcessorFactory = new TestStreamletFactory(null, null);
IKinesisProxy proxy = mock(IKinesisProxy.class);
int count = 0; int count = 0;
when(proxy.getShardList()).thenThrow(new RuntimeException(Integer.toString(count++))); when(proxy.getShardList()).thenThrow(new RuntimeException(Integer.toString(count++)));
int maxRecords = 2; int maxRecords = 2;
@ -282,9 +364,6 @@ public class WorkerTest {
maxRecords, maxRecords,
idleTimeInMilliseconds, idleTimeInMilliseconds,
callProcessRecordsForEmptyRecordList, skipCheckpointValidationValue, INITIAL_POSITION_LATEST); callProcessRecordsForEmptyRecordList, skipCheckpointValidationValue, INITIAL_POSITION_LATEST);
KinesisClientLibLeaseCoordinator leaseCoordinator = mock(KinesisClientLibLeaseCoordinator.class);
@SuppressWarnings("unchecked")
ILeaseManager<KinesisClientLease> leaseManager = mock(ILeaseManager.class);
when(leaseCoordinator.getLeaseManager()).thenReturn(leaseManager); when(leaseCoordinator.getLeaseManager()).thenReturn(leaseManager);
ExecutorService execService = Executors.newSingleThreadExecutor(); ExecutorService execService = Executors.newSingleThreadExecutor();
long shardPollInterval = 0L; long shardPollInterval = 0L;
@ -374,8 +453,7 @@ public class WorkerTest {
@Test @Test
public final void testWorkerShutsDownOwnedResources() throws Exception { public final void testWorkerShutsDownOwnedResources() throws Exception {
final WorkerThreadPoolExecutor executorService = mock(WorkerThreadPoolExecutor.class);
final WorkerCWMetricsFactory cwMetricsFactory = mock(WorkerCWMetricsFactory.class);
final long failoverTimeMillis = 20L; final long failoverTimeMillis = 20L;
// Make sure that worker thread is run before invoking shutdown. // Make sure that worker thread is run before invoking shutdown.
@ -393,8 +471,7 @@ public class WorkerTest {
callProcessRecordsForEmptyRecordList, callProcessRecordsForEmptyRecordList,
failoverTimeMillis, failoverTimeMillis,
10, 10,
mock(IKinesisProxy.class), kinesisProxy, v2RecordProcessorFactory,
mock(IRecordProcessorFactory.class),
executorService, executorService,
cwMetricsFactory); cwMetricsFactory);
@ -411,10 +488,10 @@ public class WorkerTest {
@Test @Test
public final void testWorkerDoesNotShutdownClientResources() throws Exception { public final void testWorkerDoesNotShutdownClientResources() throws Exception {
final ExecutorService executorService = mock(ThreadPoolExecutor.class);
final CWMetricsFactory cwMetricsFactory = mock(CWMetricsFactory.class);
final long failoverTimeMillis = 20L; final long failoverTimeMillis = 20L;
final ExecutorService executorService = mock(ThreadPoolExecutor.class);
final CWMetricsFactory cwMetricsFactory = mock(CWMetricsFactory.class);
// Make sure that worker thread is run before invoking shutdown. // Make sure that worker thread is run before invoking shutdown.
final CountDownLatch workerStarted = new CountDownLatch(1); final CountDownLatch workerStarted = new CountDownLatch(1);
doAnswer(new Answer<Boolean>() { doAnswer(new Answer<Boolean>() {
@ -430,8 +507,7 @@ public class WorkerTest {
callProcessRecordsForEmptyRecordList, callProcessRecordsForEmptyRecordList,
failoverTimeMillis, failoverTimeMillis,
10, 10,
mock(IKinesisProxy.class), kinesisProxy, v2RecordProcessorFactory,
mock(IRecordProcessorFactory.class),
executorService, executorService,
cwMetricsFactory); cwMetricsFactory);
@ -468,9 +544,8 @@ public class WorkerTest {
// Make test case as efficient as possible. // Make test case as efficient as possible.
final CountDownLatch processRecordsLatch = new CountDownLatch(1); final CountDownLatch processRecordsLatch = new CountDownLatch(1);
IRecordProcessorFactory recordProcessorFactory = mock(IRecordProcessorFactory.class);
IRecordProcessor recordProcessor = mock(IRecordProcessor.class); when(v2RecordProcessorFactory.createProcessor()).thenReturn(v2RecordProcessor);
when(recordProcessorFactory.createProcessor()).thenReturn(recordProcessor);
doAnswer(new Answer<Object> () { doAnswer(new Answer<Object> () {
@Override @Override
@ -479,7 +554,7 @@ public class WorkerTest {
processRecordsLatch.countDown(); processRecordsLatch.countDown();
return null; return null;
} }
}).when(recordProcessor).processRecords(any(ProcessRecordsInput.class)); }).when(v2RecordProcessor).processRecords(any(ProcessRecordsInput.class));
WorkerThread workerThread = runWorker(shardList, WorkerThread workerThread = runWorker(shardList,
initialLeases, initialLeases,
@ -487,7 +562,7 @@ public class WorkerTest {
failoverTimeMillis, failoverTimeMillis,
numberOfRecordsPerShard, numberOfRecordsPerShard,
fileBasedProxy, fileBasedProxy,
recordProcessorFactory, v2RecordProcessorFactory,
executorService, executorService,
nullMetricsFactory); nullMetricsFactory);
@ -495,16 +570,16 @@ public class WorkerTest {
processRecordsLatch.await(); processRecordsLatch.await();
// Make sure record processor is initialized and processing records. // Make sure record processor is initialized and processing records.
verify(recordProcessorFactory, times(1)).createProcessor(); verify(v2RecordProcessorFactory, times(1)).createProcessor();
verify(recordProcessor, times(1)).initialize(any(InitializationInput.class)); verify(v2RecordProcessor, times(1)).initialize(any(InitializationInput.class));
verify(recordProcessor, atLeast(1)).processRecords(any(ProcessRecordsInput.class)); verify(v2RecordProcessor, atLeast(1)).processRecords(any(ProcessRecordsInput.class));
verify(recordProcessor, times(0)).shutdown(any(ShutdownInput.class)); verify(v2RecordProcessor, times(0)).shutdown(any(ShutdownInput.class));
workerThread.getWorker().shutdown(); workerThread.getWorker().shutdown();
workerThread.join(); workerThread.join();
Assert.assertTrue(workerThread.getState() == State.TERMINATED); Assert.assertTrue(workerThread.getState() == State.TERMINATED);
verify(recordProcessor, times(1)).shutdown(any(ShutdownInput.class)); verify(v2RecordProcessor, times(1)).shutdown(any(ShutdownInput.class));
} }
/** /**
@ -538,9 +613,7 @@ public class WorkerTest {
// Make test case as efficient as possible. // Make test case as efficient as possible.
final CountDownLatch processRecordsLatch = new CountDownLatch(1); final CountDownLatch processRecordsLatch = new CountDownLatch(1);
final AtomicBoolean recordProcessorInterrupted = new AtomicBoolean(false); final AtomicBoolean recordProcessorInterrupted = new AtomicBoolean(false);
IRecordProcessorFactory recordProcessorFactory = mock(IRecordProcessorFactory.class); when(v2RecordProcessorFactory.createProcessor()).thenReturn(v2RecordProcessor);
IRecordProcessor recordProcessor = mock(IRecordProcessor.class);
when(recordProcessorFactory.createProcessor()).thenReturn(recordProcessor);
final Semaphore actionBlocker = new Semaphore(1); final Semaphore actionBlocker = new Semaphore(1);
final Semaphore shutdownBlocker = new Semaphore(1); final Semaphore shutdownBlocker = new Semaphore(1);
@ -572,7 +645,7 @@ public class WorkerTest {
return null; return null;
} }
}).when(recordProcessor).processRecords(any(ProcessRecordsInput.class)); }).when(v2RecordProcessor).processRecords(any(ProcessRecordsInput.class));
WorkerThread workerThread = runWorker(shardList, WorkerThread workerThread = runWorker(shardList,
initialLeases, initialLeases,
@ -580,7 +653,7 @@ public class WorkerTest {
failoverTimeMillis, failoverTimeMillis,
numberOfRecordsPerShard, numberOfRecordsPerShard,
fileBasedProxy, fileBasedProxy,
recordProcessorFactory, v2RecordProcessorFactory,
executorService, executorService,
nullMetricsFactory); nullMetricsFactory);
@ -588,17 +661,17 @@ public class WorkerTest {
processRecordsLatch.await(); processRecordsLatch.await();
// Make sure record processor is initialized and processing records. // Make sure record processor is initialized and processing records.
verify(recordProcessorFactory, times(1)).createProcessor(); verify(v2RecordProcessorFactory, times(1)).createProcessor();
verify(recordProcessor, times(1)).initialize(any(InitializationInput.class)); verify(v2RecordProcessor, times(1)).initialize(any(InitializationInput.class));
verify(recordProcessor, atLeast(1)).processRecords(any(ProcessRecordsInput.class)); verify(v2RecordProcessor, atLeast(1)).processRecords(any(ProcessRecordsInput.class));
verify(recordProcessor, times(0)).shutdown(any(ShutdownInput.class)); verify(v2RecordProcessor, times(0)).shutdown(any(ShutdownInput.class));
workerThread.getWorker().shutdown(); workerThread.getWorker().shutdown();
workerThread.join(); workerThread.join();
Assert.assertTrue(workerThread.getState() == State.TERMINATED); Assert.assertTrue(workerThread.getState() == State.TERMINATED);
// Shutdown should not be called in this case because record processor is blocked. // Shutdown should not be called in this case because record processor is blocked.
verify(recordProcessor, times(0)).shutdown(any(ShutdownInput.class)); verify(v2RecordProcessor, times(0)).shutdown(any(ShutdownInput.class));
// //
// Release the worker thread // Release the worker thread