Merge pull request #97 from pfifer/shard-info-equals

Don't Use Checkpoint for Equality in ShardInfo
This commit is contained in:
Justin Pfifer 2016-08-17 14:59:43 -07:00 committed by GitHub
commit a02022fb0f
5 changed files with 240 additions and 176 deletions

View file

@ -19,13 +19,17 @@ import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import org.apache.commons.lang.builder.EqualsBuilder;
import com.amazonaws.services.kinesis.clientlibrary.types.ExtendedSequenceNumber;
import org.apache.commons.lang.builder.HashCodeBuilder;
/**
* Used to pass shard related info among different classes and as a key to the map of shard consumers.
*/
class ShardInfo {
private final String shardId;
private final String concurrencyToken;
// Sorted list of parent shardIds.
@ -33,10 +37,16 @@ class ShardInfo {
private final ExtendedSequenceNumber checkpoint;
/**
* @param shardId Kinesis shardId
* @param concurrencyToken Used to differentiate between lost and reclaimed leases
* @param parentShardIds Parent shards of the shard identified by Kinesis shardId
* @param checkpoint the latest checkpoint from lease
* Creates a new ShardInfo object. The checkpoint is not part of the equality, but is used for debugging output.
*
* @param shardId
* Kinesis shardId
* @param concurrencyToken
* Used to differentiate between lost and reclaimed leases
* @param parentShardIds
* Parent shards of the shard identified by Kinesis shardId
* @param checkpoint
* the latest checkpoint from lease
*/
public ShardInfo(String shardId,
String concurrencyToken,
@ -87,20 +97,12 @@ class ShardInfo {
*/
@Override
public int hashCode() {
final int prime = 31;
int result = 1;
result = prime * result + ((concurrencyToken == null) ? 0 : concurrencyToken.hashCode());
result = prime * result + ((parentShardIds == null) ? 0 : parentShardIds.hashCode());
result = prime * result + ((shardId == null) ? 0 : shardId.hashCode());
result = prime * result + ((checkpoint == null) ? 0 : checkpoint.hashCode());
return result;
return new HashCodeBuilder().append(concurrencyToken).append(parentShardIds).append(shardId).toHashCode();
}
/**
* {@inheritDoc}
*/
// CHECKSTYLE:OFF CyclomaticComplexity
// CHECKSTYLE:OFF NPathComplexity
/**
* This method assumes parentShardIds is ordered. The Worker.cleanupShardConsumers() method relies on this method
* returning true for ShardInfo objects which may have been instantiated with parentShardIds in a different order
@ -121,39 +123,11 @@ class ShardInfo {
return false;
}
ShardInfo other = (ShardInfo) obj;
if (concurrencyToken == null) {
if (other.concurrencyToken != null) {
return false;
}
} else if (!concurrencyToken.equals(other.concurrencyToken)) {
return false;
}
if (parentShardIds == null) {
if (other.parentShardIds != null) {
return false;
}
} else if (!parentShardIds.equals(other.parentShardIds)) {
return false;
}
if (shardId == null) {
if (other.shardId != null) {
return false;
}
} else if (!shardId.equals(other.shardId)) {
return false;
}
if (checkpoint == null) {
if (other.checkpoint != null) {
return false;
}
} else if (!checkpoint.equals(other.checkpoint)) {
return false;
}
return true;
return new EqualsBuilder().append(concurrencyToken, other.concurrencyToken)
.append(parentShardIds, other.parentShardIds).append(shardId, other.shardId).isEquals();
}
// CHECKSTYLE:ON CyclomaticComplexity
// CHECKSTYLE:ON NPathComplexity
@Override
public String toString() {
@ -161,41 +135,6 @@ class ShardInfo {
+ parentShardIds + ", checkpoint=" + checkpoint + "]";
}
/**
* Builder class for ShardInfo.
*/
public static class Builder {
private String shardId;
private String concurrencyToken;
private List<String> parentShardIds = Collections.emptyList();
private ExtendedSequenceNumber checkpoint = ExtendedSequenceNumber.LATEST;
public Builder() {
}
public Builder withShardId(String shardId) {
this.shardId = shardId;
return this;
}
public Builder withConcurrencyToken(String concurrencyToken) {
this.concurrencyToken = concurrencyToken;
return this;
}
public Builder withParentShards(List<String> parentShardIds) {
this.parentShardIds = parentShardIds;
return this;
}
public Builder withCheckpoint(ExtendedSequenceNumber checkpoint) {
this.checkpoint = checkpoint;
return this;
}
public ShardInfo build() {
return new ShardInfo(shardId, concurrencyToken, parentShardIds, checkpoint);
}
}
}

View file

@ -38,7 +38,6 @@ import com.amazonaws.services.kinesis.AmazonKinesisClient;
import com.amazonaws.services.kinesis.clientlibrary.interfaces.ICheckpoint;
import com.amazonaws.services.kinesis.clientlibrary.interfaces.v2.IRecordProcessor;
import com.amazonaws.services.kinesis.clientlibrary.interfaces.v2.IRecordProcessorFactory;
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.Worker.Builder;
import com.amazonaws.services.kinesis.clientlibrary.proxies.KinesisProxyFactory;
import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason;
import com.amazonaws.services.kinesis.leases.exceptions.LeasingException;
@ -47,6 +46,7 @@ import com.amazonaws.services.kinesis.metrics.impl.CWMetricsFactory;
import com.amazonaws.services.kinesis.metrics.impl.NullMetricsFactory;
import com.amazonaws.services.kinesis.metrics.interfaces.IMetricsFactory;
import com.amazonaws.services.kinesis.metrics.interfaces.MetricsLevel;
import com.google.common.annotations.VisibleForTesting;
/**
* Worker is the high level class that Kinesis applications use to start
@ -342,13 +342,21 @@ public class Worker implements Runnable {
}
while (!shouldShutdown()) {
runProcessLoop();
}
finalShutdown();
LOG.info("Worker loop is complete. Exiting from worker.");
}
@VisibleForTesting
void runProcessLoop() {
try {
boolean foundCompletedShard = false;
Set<ShardInfo> assignedShards = new HashSet<ShardInfo>();
Set<ShardInfo> assignedShards = new HashSet<>();
for (ShardInfo shardInfo : getShardInfoForAssignments()) {
ShardConsumer shardConsumer = createOrGetShardConsumer(shardInfo, recordProcessorFactory);
if (shardConsumer.isShutdown()
&& shardConsumer.getShutdownReason().equals(ShutdownReason.TERMINATE)) {
if (shardConsumer.isShutdown() && shardConsumer.getShutdownReason().equals(ShutdownReason.TERMINATE)) {
foundCompletedShard = true;
} else {
shardConsumer.consumeShard();
@ -367,8 +375,7 @@ public class Worker implements Runnable {
Thread.sleep(idleTimeInMilliseconds);
} catch (Exception e) {
LOG.error(String.format("Worker.run caught exception, sleeping for %s milli seconds!",
String.valueOf(idleTimeInMilliseconds)),
e);
String.valueOf(idleTimeInMilliseconds)), e);
try {
Thread.sleep(idleTimeInMilliseconds);
} catch (InterruptedException ex) {
@ -378,10 +385,6 @@ public class Worker implements Runnable {
wlog.resetInfoLogging();
}
finalShutdown();
LOG.info("Worker loop is complete. Exiting from worker.");
}
private void initialize() {
boolean isDone = false;
Exception lastException = null;
@ -552,25 +555,22 @@ public class Worker implements Runnable {
// completely processed (shutdown reason terminate).
if ((consumer == null)
|| (consumer.isShutdown() && consumer.getShutdownReason().equals(ShutdownReason.ZOMBIE))) {
IRecordProcessor recordProcessor = factory.createProcessor();
consumer =
new ShardConsumer(shardInfo,
streamConfig,
checkpointTracker,
recordProcessor,
leaseCoordinator.getLeaseManager(),
parentShardPollIntervalMillis,
cleanupLeasesUponShardCompletion,
executorService,
metricsFactory,
taskBackoffTimeMillis);
consumer = buildConsumer(shardInfo, factory);
shardInfoShardConsumerMap.put(shardInfo, consumer);
wlog.infoForce("Created new shardConsumer for : " + shardInfo);
}
return consumer;
}
protected ShardConsumer buildConsumer(ShardInfo shardInfo, IRecordProcessorFactory factory) {
IRecordProcessor recordProcessor = factory.createProcessor();
return new ShardConsumer(shardInfo, streamConfig, checkpointTracker, recordProcessor,
leaseCoordinator.getLeaseManager(), parentShardPollIntervalMillis, cleanupLeasesUponShardCompletion,
executorService, metricsFactory, taskBackoffTimeMillis);
}
/**
* Logger for suppressing too much INFO logging. To avoid too much logging
* information Worker will output logging at INFO level for a single pass

View file

@ -11,6 +11,8 @@ import java.util.Random;
import org.junit.Test;
import com.amazonaws.services.kinesis.clientlibrary.types.ExtendedSequenceNumber;
public class ParentsFirstShardPrioritizationUnitTest {
@Test(expected = IllegalArgumentException.class)
@ -144,17 +146,54 @@ public class ParentsFirstShardPrioritizationUnitTest {
return "shardId-" + shardNumber;
}
/**
* Builder class for ShardInfo.
*/
static class ShardInfoBuilder {
private String shardId;
private String concurrencyToken;
private List<String> parentShardIds = Collections.emptyList();
private ExtendedSequenceNumber checkpoint = ExtendedSequenceNumber.LATEST;
ShardInfoBuilder() {
}
ShardInfoBuilder withShardId(String shardId) {
this.shardId = shardId;
return this;
}
ShardInfoBuilder withConcurrencyToken(String concurrencyToken) {
this.concurrencyToken = concurrencyToken;
return this;
}
ShardInfoBuilder withParentShards(List<String> parentShardIds) {
this.parentShardIds = parentShardIds;
return this;
}
ShardInfoBuilder withCheckpoint(ExtendedSequenceNumber checkpoint) {
this.checkpoint = checkpoint;
return this;
}
ShardInfo build() {
return new ShardInfo(shardId, concurrencyToken, parentShardIds, checkpoint);
}
}
private static ShardInfo shardInfo(String shardId, List<String> parentShardIds) {
// copy into new list just in case ShardInfo will stop doing it
List<String> newParentShardIds = new ArrayList<>(parentShardIds);
return new ShardInfo.Builder()
return new ShardInfoBuilder()
.withShardId(shardId)
.withParentShards(newParentShardIds)
.build();
}
private static ShardInfo shardInfo(String shardId, String... parentShardIds) {
return new ShardInfo.Builder()
return new ShardInfoBuilder()
.withShardId(shardId)
.withParentShards(Arrays.asList(parentShardIds))
.build();

View file

@ -14,6 +14,10 @@
*/
package com.amazonaws.services.kinesis.clientlibrary.lib.worker;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.is;
import static org.junit.Assert.assertThat;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
@ -93,11 +97,20 @@ public class ShardInfoTest {
}
@Test
public void testPacboyShardInfoEqualsForCheckpoint() {
ShardInfo diffShardInfo = new ShardInfo(SHARD_ID, CONCURRENCY_TOKEN, parentShardIds, ExtendedSequenceNumber.SHARD_END);
Assert.assertFalse("Equal should return false with different checkpoint", diffShardInfo.equals(testShardInfo));
diffShardInfo = new ShardInfo(SHARD_ID, CONCURRENCY_TOKEN, parentShardIds, null);
Assert.assertFalse("Equal should return false with null checkpoint", diffShardInfo.equals(testShardInfo));
public void testShardInfoCheckpointEqualsHashCode() {
ShardInfo baseInfo = new ShardInfo(SHARD_ID, CONCURRENCY_TOKEN, parentShardIds,
ExtendedSequenceNumber.TRIM_HORIZON);
ShardInfo differentCheckpoint = new ShardInfo(SHARD_ID, CONCURRENCY_TOKEN, parentShardIds,
new ExtendedSequenceNumber("1234"));
ShardInfo nullCheckpoint = new ShardInfo(SHARD_ID, CONCURRENCY_TOKEN, parentShardIds, null);
assertThat("Checkpoint should not be included in equality.", baseInfo.equals(differentCheckpoint), is(true));
assertThat("Checkpoint should not be included in equality.", baseInfo.equals(nullCheckpoint), is(true));
assertThat("Checkpoint should not be included in hash code.", baseInfo.hashCode(),
equalTo(differentCheckpoint.hashCode()));
assertThat("Checkpoint should not be included in hash code.", baseInfo.hashCode(),
equalTo(nullCheckpoint.hashCode()));
}
@Test

View file

@ -16,10 +16,16 @@ package com.amazonaws.services.kinesis.clientlibrary.lib.worker;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Matchers.same;
import static org.mockito.Mockito.atLeast;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@ -50,7 +56,10 @@ import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.Timeout;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.runners.MockitoJUnitRunner;
import org.mockito.stubbing.Answer;
import com.amazonaws.services.dynamodbv2.AmazonDynamoDB;
@ -85,6 +94,7 @@ import com.amazonaws.services.kinesis.model.Shard;
/**
* Unit tests of Worker.
*/
@RunWith(MockitoJUnitRunner.class)
public class WorkerTest {
private static final Log LOG = LogFactory.getLog(WorkerTest.class);
@ -107,6 +117,30 @@ public class WorkerTest {
InitialPositionInStreamExtended.newInitialPosition(InitialPositionInStream.TRIM_HORIZON);
private final ShardPrioritization shardPrioritization = new NoOpShardPrioritization();
private static final String KINESIS_SHARD_ID_FORMAT = "kinesis-0-0-%d";
private static final String CONCURRENCY_TOKEN_FORMAT = "testToken-%d";
@Mock
private KinesisClientLibLeaseCoordinator leaseCoordinator;
@Mock
private ILeaseManager<KinesisClientLease> leaseManager;
@Mock
private com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorFactory v1RecordProcessorFactory;
@Mock
private IKinesisProxy proxy;
@Mock
private WorkerThreadPoolExecutor executorService;
@Mock
private WorkerCWMetricsFactory cwMetricsFactory;
@Mock
private IKinesisProxy kinesisProxy;
@Mock
private IRecordProcessorFactory v2RecordProcessorFactory;
@Mock
private IRecordProcessor v2RecordProcessor;
@Mock
private ShardConsumer shardConsumer;
// CHECKSTYLE:IGNORE AnonInnerLengthCheck FOR NEXT 50 LINES
private static final com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorFactory SAMPLE_RECORD_PROCESSOR_FACTORY =
new com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorFactory() {
@ -145,6 +179,7 @@ public class WorkerTest {
private static final IRecordProcessorFactory SAMPLE_RECORD_PROCESSOR_FACTORY_V2 =
new V1ToV2RecordProcessorFactoryAdapter(SAMPLE_RECORD_PROCESSOR_FACTORY);
/**
* Test method for {@link com.amazonaws.services.kinesis.clientlibrary.lib.worker.Worker#getApplicationName()}.
*/
@ -153,9 +188,7 @@ public class WorkerTest {
final String stageName = "testStageName";
final KinesisClientLibConfiguration clientConfig =
new KinesisClientLibConfiguration(stageName, null, null, null);
Worker worker =
new Worker(mock(com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorFactory.class),
clientConfig);
Worker worker = new Worker(v1RecordProcessorFactory, clientConfig);
Assert.assertEquals(stageName, worker.getApplicationName());
}
@ -177,9 +210,6 @@ public class WorkerTest {
final String dummyKinesisShardId = "kinesis-0-0";
ExecutorService execService = null;
KinesisClientLibLeaseCoordinator leaseCoordinator = mock(KinesisClientLibLeaseCoordinator.class);
@SuppressWarnings("unchecked")
ILeaseManager<KinesisClientLease> leaseManager = mock(ILeaseManager.class);
when(leaseCoordinator.getLeaseManager()).thenReturn(leaseManager);
Worker worker =
@ -208,6 +238,63 @@ public class WorkerTest {
Assert.assertNotSame(consumer3, consumer);
}
@Test
public void testWorkerLoopWithCheckpoint() {
final String stageName = "testStageName";
IRecordProcessorFactory streamletFactory = SAMPLE_RECORD_PROCESSOR_FACTORY_V2;
IKinesisProxy proxy = null;
ICheckpoint checkpoint = null;
int maxRecords = 1;
int idleTimeInMilliseconds = 1000;
StreamConfig streamConfig = new StreamConfig(proxy, maxRecords, idleTimeInMilliseconds,
callProcessRecordsForEmptyRecordList, skipCheckpointValidationValue, INITIAL_POSITION_LATEST);
ExecutorService execService = null;
when(leaseCoordinator.getLeaseManager()).thenReturn(leaseManager);
List<ShardInfo> initialState = createShardInfoList(ExtendedSequenceNumber.TRIM_HORIZON);
List<ShardInfo> firstCheckpoint = createShardInfoList(new ExtendedSequenceNumber("1000"));
List<ShardInfo> secondCheckpoint = createShardInfoList(new ExtendedSequenceNumber("2000"));
when(leaseCoordinator.getCurrentAssignments()).thenReturn(initialState).thenReturn(firstCheckpoint)
.thenReturn(secondCheckpoint);
Worker worker = new Worker(stageName, streamletFactory, streamConfig, INITIAL_POSITION_LATEST,
parentShardPollIntervalMillis, shardSyncIntervalMillis, cleanupLeasesUponShardCompletion, checkpoint,
leaseCoordinator, execService, nullMetricsFactory, taskBackoffTimeMillis, failoverTimeMillis,
shardPrioritization);
Worker workerSpy = spy(worker);
doReturn(shardConsumer).when(workerSpy).buildConsumer(eq(initialState.get(0)), any(IRecordProcessorFactory.class));
workerSpy.runProcessLoop();
workerSpy.runProcessLoop();
workerSpy.runProcessLoop();
verify(workerSpy).buildConsumer(same(initialState.get(0)), any(IRecordProcessorFactory.class));
verify(workerSpy, never()).buildConsumer(same(firstCheckpoint.get(0)), any(IRecordProcessorFactory.class));
verify(workerSpy, never()).buildConsumer(same(secondCheckpoint.get(0)), any(IRecordProcessorFactory.class));
}
private List<ShardInfo> createShardInfoList(ExtendedSequenceNumber... sequenceNumbers) {
List<ShardInfo> result = new ArrayList<>(sequenceNumbers.length);
assertThat(sequenceNumbers.length, greaterThanOrEqualTo(1));
for (int i = 0; i < sequenceNumbers.length; ++i) {
result.add(new ShardInfo(adjustedShardId(i), adjustedConcurrencyToken(i), null, sequenceNumbers[i]));
}
return result;
}
private String adjustedShardId(int index) {
return String.format(KINESIS_SHARD_ID_FORMAT, index);
}
private String adjustedConcurrencyToken(int index) {
return String.format(CONCURRENCY_TOKEN_FORMAT, index);
}
@Test
public final void testCleanupShardConsumers() {
final String stageName = "testStageName";
@ -226,10 +313,6 @@ public class WorkerTest {
final String dummyKinesisShardId = "kinesis-0-0";
final String anotherDummyKinesisShardId = "kinesis-0-1";
ExecutorService execService = null;
KinesisClientLibLeaseCoordinator leaseCoordinator = mock(KinesisClientLibLeaseCoordinator.class);
@SuppressWarnings("unchecked")
ILeaseManager<KinesisClientLease> leaseManager = mock(ILeaseManager.class);
when(leaseCoordinator.getLeaseManager()).thenReturn(leaseManager);
Worker worker =
@ -272,7 +355,6 @@ public class WorkerTest {
public final void testInitializationFailureWithRetries() {
String stageName = "testInitializationWorker";
IRecordProcessorFactory recordProcessorFactory = new TestStreamletFactory(null, null);
IKinesisProxy proxy = mock(IKinesisProxy.class);
int count = 0;
when(proxy.getShardList()).thenThrow(new RuntimeException(Integer.toString(count++)));
int maxRecords = 2;
@ -282,9 +364,6 @@ public class WorkerTest {
maxRecords,
idleTimeInMilliseconds,
callProcessRecordsForEmptyRecordList, skipCheckpointValidationValue, INITIAL_POSITION_LATEST);
KinesisClientLibLeaseCoordinator leaseCoordinator = mock(KinesisClientLibLeaseCoordinator.class);
@SuppressWarnings("unchecked")
ILeaseManager<KinesisClientLease> leaseManager = mock(ILeaseManager.class);
when(leaseCoordinator.getLeaseManager()).thenReturn(leaseManager);
ExecutorService execService = Executors.newSingleThreadExecutor();
long shardPollInterval = 0L;
@ -374,8 +453,7 @@ public class WorkerTest {
@Test
public final void testWorkerShutsDownOwnedResources() throws Exception {
final WorkerThreadPoolExecutor executorService = mock(WorkerThreadPoolExecutor.class);
final WorkerCWMetricsFactory cwMetricsFactory = mock(WorkerCWMetricsFactory.class);
final long failoverTimeMillis = 20L;
// Make sure that worker thread is run before invoking shutdown.
@ -393,8 +471,7 @@ public class WorkerTest {
callProcessRecordsForEmptyRecordList,
failoverTimeMillis,
10,
mock(IKinesisProxy.class),
mock(IRecordProcessorFactory.class),
kinesisProxy, v2RecordProcessorFactory,
executorService,
cwMetricsFactory);
@ -411,10 +488,10 @@ public class WorkerTest {
@Test
public final void testWorkerDoesNotShutdownClientResources() throws Exception {
final ExecutorService executorService = mock(ThreadPoolExecutor.class);
final CWMetricsFactory cwMetricsFactory = mock(CWMetricsFactory.class);
final long failoverTimeMillis = 20L;
final ExecutorService executorService = mock(ThreadPoolExecutor.class);
final CWMetricsFactory cwMetricsFactory = mock(CWMetricsFactory.class);
// Make sure that worker thread is run before invoking shutdown.
final CountDownLatch workerStarted = new CountDownLatch(1);
doAnswer(new Answer<Boolean>() {
@ -430,8 +507,7 @@ public class WorkerTest {
callProcessRecordsForEmptyRecordList,
failoverTimeMillis,
10,
mock(IKinesisProxy.class),
mock(IRecordProcessorFactory.class),
kinesisProxy, v2RecordProcessorFactory,
executorService,
cwMetricsFactory);
@ -468,9 +544,8 @@ public class WorkerTest {
// Make test case as efficient as possible.
final CountDownLatch processRecordsLatch = new CountDownLatch(1);
IRecordProcessorFactory recordProcessorFactory = mock(IRecordProcessorFactory.class);
IRecordProcessor recordProcessor = mock(IRecordProcessor.class);
when(recordProcessorFactory.createProcessor()).thenReturn(recordProcessor);
when(v2RecordProcessorFactory.createProcessor()).thenReturn(v2RecordProcessor);
doAnswer(new Answer<Object> () {
@Override
@ -479,7 +554,7 @@ public class WorkerTest {
processRecordsLatch.countDown();
return null;
}
}).when(recordProcessor).processRecords(any(ProcessRecordsInput.class));
}).when(v2RecordProcessor).processRecords(any(ProcessRecordsInput.class));
WorkerThread workerThread = runWorker(shardList,
initialLeases,
@ -487,7 +562,7 @@ public class WorkerTest {
failoverTimeMillis,
numberOfRecordsPerShard,
fileBasedProxy,
recordProcessorFactory,
v2RecordProcessorFactory,
executorService,
nullMetricsFactory);
@ -495,16 +570,16 @@ public class WorkerTest {
processRecordsLatch.await();
// Make sure record processor is initialized and processing records.
verify(recordProcessorFactory, times(1)).createProcessor();
verify(recordProcessor, times(1)).initialize(any(InitializationInput.class));
verify(recordProcessor, atLeast(1)).processRecords(any(ProcessRecordsInput.class));
verify(recordProcessor, times(0)).shutdown(any(ShutdownInput.class));
verify(v2RecordProcessorFactory, times(1)).createProcessor();
verify(v2RecordProcessor, times(1)).initialize(any(InitializationInput.class));
verify(v2RecordProcessor, atLeast(1)).processRecords(any(ProcessRecordsInput.class));
verify(v2RecordProcessor, times(0)).shutdown(any(ShutdownInput.class));
workerThread.getWorker().shutdown();
workerThread.join();
Assert.assertTrue(workerThread.getState() == State.TERMINATED);
verify(recordProcessor, times(1)).shutdown(any(ShutdownInput.class));
verify(v2RecordProcessor, times(1)).shutdown(any(ShutdownInput.class));
}
/**
@ -538,9 +613,7 @@ public class WorkerTest {
// Make test case as efficient as possible.
final CountDownLatch processRecordsLatch = new CountDownLatch(1);
final AtomicBoolean recordProcessorInterrupted = new AtomicBoolean(false);
IRecordProcessorFactory recordProcessorFactory = mock(IRecordProcessorFactory.class);
IRecordProcessor recordProcessor = mock(IRecordProcessor.class);
when(recordProcessorFactory.createProcessor()).thenReturn(recordProcessor);
when(v2RecordProcessorFactory.createProcessor()).thenReturn(v2RecordProcessor);
final Semaphore actionBlocker = new Semaphore(1);
final Semaphore shutdownBlocker = new Semaphore(1);
@ -572,7 +645,7 @@ public class WorkerTest {
return null;
}
}).when(recordProcessor).processRecords(any(ProcessRecordsInput.class));
}).when(v2RecordProcessor).processRecords(any(ProcessRecordsInput.class));
WorkerThread workerThread = runWorker(shardList,
initialLeases,
@ -580,7 +653,7 @@ public class WorkerTest {
failoverTimeMillis,
numberOfRecordsPerShard,
fileBasedProxy,
recordProcessorFactory,
v2RecordProcessorFactory,
executorService,
nullMetricsFactory);
@ -588,17 +661,17 @@ public class WorkerTest {
processRecordsLatch.await();
// Make sure record processor is initialized and processing records.
verify(recordProcessorFactory, times(1)).createProcessor();
verify(recordProcessor, times(1)).initialize(any(InitializationInput.class));
verify(recordProcessor, atLeast(1)).processRecords(any(ProcessRecordsInput.class));
verify(recordProcessor, times(0)).shutdown(any(ShutdownInput.class));
verify(v2RecordProcessorFactory, times(1)).createProcessor();
verify(v2RecordProcessor, times(1)).initialize(any(InitializationInput.class));
verify(v2RecordProcessor, atLeast(1)).processRecords(any(ProcessRecordsInput.class));
verify(v2RecordProcessor, times(0)).shutdown(any(ShutdownInput.class));
workerThread.getWorker().shutdown();
workerThread.join();
Assert.assertTrue(workerThread.getState() == State.TERMINATED);
// Shutdown should not be called in this case because record processor is blocked.
verify(recordProcessor, times(0)).shutdown(any(ShutdownInput.class));
verify(v2RecordProcessor, times(0)).shutdown(any(ShutdownInput.class));
//
// Release the worker thread