diff --git a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShardInfo.java b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShardInfo.java index 9890d02f..f04a86ba 100644 --- a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShardInfo.java +++ b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShardInfo.java @@ -19,13 +19,17 @@ import java.util.Collections; import java.util.LinkedList; import java.util.List; +import org.apache.commons.lang.builder.EqualsBuilder; + import com.amazonaws.services.kinesis.clientlibrary.types.ExtendedSequenceNumber; +import org.apache.commons.lang.builder.HashCodeBuilder; /** * Used to pass shard related info among different classes and as a key to the map of shard consumers. */ class ShardInfo { + private final String shardId; private final String concurrencyToken; // Sorted list of parent shardIds. @@ -33,10 +37,16 @@ class ShardInfo { private final ExtendedSequenceNumber checkpoint; /** - * @param shardId Kinesis shardId - * @param concurrencyToken Used to differentiate between lost and reclaimed leases - * @param parentShardIds Parent shards of the shard identified by Kinesis shardId - * @param checkpoint the latest checkpoint from lease + * Creates a new ShardInfo object. The checkpoint is not part of the equality, but is used for debugging output. + * + * @param shardId + * Kinesis shardId + * @param concurrencyToken + * Used to differentiate between lost and reclaimed leases + * @param parentShardIds + * Parent shards of the shard identified by Kinesis shardId + * @param checkpoint + * the latest checkpoint from lease */ public ShardInfo(String shardId, String concurrencyToken, @@ -87,20 +97,12 @@ class ShardInfo { */ @Override public int hashCode() { - final int prime = 31; - int result = 1; - result = prime * result + ((concurrencyToken == null) ? 0 : concurrencyToken.hashCode()); - result = prime * result + ((parentShardIds == null) ? 0 : parentShardIds.hashCode()); - result = prime * result + ((shardId == null) ? 0 : shardId.hashCode()); - result = prime * result + ((checkpoint == null) ? 0 : checkpoint.hashCode()); - return result; + return new HashCodeBuilder().append(concurrencyToken).append(parentShardIds).append(shardId).toHashCode(); } /** * {@inheritDoc} */ - // CHECKSTYLE:OFF CyclomaticComplexity - // CHECKSTYLE:OFF NPathComplexity /** * This method assumes parentShardIds is ordered. The Worker.cleanupShardConsumers() method relies on this method * returning true for ShardInfo objects which may have been instantiated with parentShardIds in a different order @@ -121,39 +123,11 @@ class ShardInfo { return false; } ShardInfo other = (ShardInfo) obj; - if (concurrencyToken == null) { - if (other.concurrencyToken != null) { - return false; - } - } else if (!concurrencyToken.equals(other.concurrencyToken)) { - return false; - } - if (parentShardIds == null) { - if (other.parentShardIds != null) { - return false; - } - } else if (!parentShardIds.equals(other.parentShardIds)) { - return false; - } - if (shardId == null) { - if (other.shardId != null) { - return false; - } - } else if (!shardId.equals(other.shardId)) { - return false; - } - if (checkpoint == null) { - if (other.checkpoint != null) { - return false; - } - } else if (!checkpoint.equals(other.checkpoint)) { - return false; - } - return true; + return new EqualsBuilder().append(concurrencyToken, other.concurrencyToken) + .append(parentShardIds, other.parentShardIds).append(shardId, other.shardId).isEquals(); + } - // CHECKSTYLE:ON CyclomaticComplexity - // CHECKSTYLE:ON NPathComplexity @Override public String toString() { @@ -161,41 +135,6 @@ class ShardInfo { + parentShardIds + ", checkpoint=" + checkpoint + "]"; } - /** - * Builder class for ShardInfo. - */ - public static class Builder { - private String shardId; - private String concurrencyToken; - private List parentShardIds = Collections.emptyList(); - private ExtendedSequenceNumber checkpoint = ExtendedSequenceNumber.LATEST; - public Builder() { - } - - public Builder withShardId(String shardId) { - this.shardId = shardId; - return this; - } - - public Builder withConcurrencyToken(String concurrencyToken) { - this.concurrencyToken = concurrencyToken; - return this; - } - - public Builder withParentShards(List parentShardIds) { - this.parentShardIds = parentShardIds; - return this; - } - - public Builder withCheckpoint(ExtendedSequenceNumber checkpoint) { - this.checkpoint = checkpoint; - return this; - } - - public ShardInfo build() { - return new ShardInfo(shardId, concurrencyToken, parentShardIds, checkpoint); - } - } } diff --git a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/Worker.java b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/Worker.java index 3da1f2cd..32efa442 100644 --- a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/Worker.java +++ b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/Worker.java @@ -38,7 +38,6 @@ import com.amazonaws.services.kinesis.AmazonKinesisClient; import com.amazonaws.services.kinesis.clientlibrary.interfaces.ICheckpoint; import com.amazonaws.services.kinesis.clientlibrary.interfaces.v2.IRecordProcessor; import com.amazonaws.services.kinesis.clientlibrary.interfaces.v2.IRecordProcessorFactory; -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.Worker.Builder; import com.amazonaws.services.kinesis.clientlibrary.proxies.KinesisProxyFactory; import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason; import com.amazonaws.services.kinesis.leases.exceptions.LeasingException; @@ -47,6 +46,7 @@ import com.amazonaws.services.kinesis.metrics.impl.CWMetricsFactory; import com.amazonaws.services.kinesis.metrics.impl.NullMetricsFactory; import com.amazonaws.services.kinesis.metrics.interfaces.IMetricsFactory; import com.amazonaws.services.kinesis.metrics.interfaces.MetricsLevel; +import com.google.common.annotations.VisibleForTesting; /** * Worker is the high level class that Kinesis applications use to start @@ -342,46 +342,49 @@ public class Worker implements Runnable { } while (!shouldShutdown()) { - try { - boolean foundCompletedShard = false; - Set assignedShards = new HashSet(); - for (ShardInfo shardInfo : getShardInfoForAssignments()) { - ShardConsumer shardConsumer = createOrGetShardConsumer(shardInfo, recordProcessorFactory); - if (shardConsumer.isShutdown() - && shardConsumer.getShutdownReason().equals(ShutdownReason.TERMINATE)) { - foundCompletedShard = true; - } else { - shardConsumer.consumeShard(); - } - assignedShards.add(shardInfo); - } - - if (foundCompletedShard) { - controlServer.syncShardAndLeaseInfo(null); - } - - // clean up shard consumers for unassigned shards - cleanupShardConsumers(assignedShards); - - wlog.info("Sleeping ..."); - Thread.sleep(idleTimeInMilliseconds); - } catch (Exception e) { - LOG.error(String.format("Worker.run caught exception, sleeping for %s milli seconds!", - String.valueOf(idleTimeInMilliseconds)), - e); - try { - Thread.sleep(idleTimeInMilliseconds); - } catch (InterruptedException ex) { - LOG.info("Worker: sleep interrupted after catching exception ", ex); - } - } - wlog.resetInfoLogging(); + runProcessLoop(); } finalShutdown(); LOG.info("Worker loop is complete. Exiting from worker."); } + @VisibleForTesting + void runProcessLoop() { + try { + boolean foundCompletedShard = false; + Set assignedShards = new HashSet<>(); + for (ShardInfo shardInfo : getShardInfoForAssignments()) { + ShardConsumer shardConsumer = createOrGetShardConsumer(shardInfo, recordProcessorFactory); + if (shardConsumer.isShutdown() && shardConsumer.getShutdownReason().equals(ShutdownReason.TERMINATE)) { + foundCompletedShard = true; + } else { + shardConsumer.consumeShard(); + } + assignedShards.add(shardInfo); + } + + if (foundCompletedShard) { + controlServer.syncShardAndLeaseInfo(null); + } + + // clean up shard consumers for unassigned shards + cleanupShardConsumers(assignedShards); + + wlog.info("Sleeping ..."); + Thread.sleep(idleTimeInMilliseconds); + } catch (Exception e) { + LOG.error(String.format("Worker.run caught exception, sleeping for %s milli seconds!", + String.valueOf(idleTimeInMilliseconds)), e); + try { + Thread.sleep(idleTimeInMilliseconds); + } catch (InterruptedException ex) { + LOG.info("Worker: sleep interrupted after catching exception ", ex); + } + } + wlog.resetInfoLogging(); + } + private void initialize() { boolean isDone = false; Exception lastException = null; @@ -552,25 +555,22 @@ public class Worker implements Runnable { // completely processed (shutdown reason terminate). if ((consumer == null) || (consumer.isShutdown() && consumer.getShutdownReason().equals(ShutdownReason.ZOMBIE))) { - IRecordProcessor recordProcessor = factory.createProcessor(); - - consumer = - new ShardConsumer(shardInfo, - streamConfig, - checkpointTracker, - recordProcessor, - leaseCoordinator.getLeaseManager(), - parentShardPollIntervalMillis, - cleanupLeasesUponShardCompletion, - executorService, - metricsFactory, - taskBackoffTimeMillis); + consumer = buildConsumer(shardInfo, factory); shardInfoShardConsumerMap.put(shardInfo, consumer); wlog.infoForce("Created new shardConsumer for : " + shardInfo); } return consumer; } + protected ShardConsumer buildConsumer(ShardInfo shardInfo, IRecordProcessorFactory factory) { + IRecordProcessor recordProcessor = factory.createProcessor(); + + return new ShardConsumer(shardInfo, streamConfig, checkpointTracker, recordProcessor, + leaseCoordinator.getLeaseManager(), parentShardPollIntervalMillis, cleanupLeasesUponShardCompletion, + executorService, metricsFactory, taskBackoffTimeMillis); + + } + /** * Logger for suppressing too much INFO logging. To avoid too much logging * information Worker will output logging at INFO level for a single pass diff --git a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ParentsFirstShardPrioritizationUnitTest.java b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ParentsFirstShardPrioritizationUnitTest.java index 35b56b32..7ba0753d 100644 --- a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ParentsFirstShardPrioritizationUnitTest.java +++ b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ParentsFirstShardPrioritizationUnitTest.java @@ -11,6 +11,8 @@ import java.util.Random; import org.junit.Test; +import com.amazonaws.services.kinesis.clientlibrary.types.ExtendedSequenceNumber; + public class ParentsFirstShardPrioritizationUnitTest { @Test(expected = IllegalArgumentException.class) @@ -144,17 +146,54 @@ public class ParentsFirstShardPrioritizationUnitTest { return "shardId-" + shardNumber; } + /** + * Builder class for ShardInfo. + */ + static class ShardInfoBuilder { + private String shardId; + private String concurrencyToken; + private List parentShardIds = Collections.emptyList(); + private ExtendedSequenceNumber checkpoint = ExtendedSequenceNumber.LATEST; + + ShardInfoBuilder() { + } + + ShardInfoBuilder withShardId(String shardId) { + this.shardId = shardId; + return this; + } + + ShardInfoBuilder withConcurrencyToken(String concurrencyToken) { + this.concurrencyToken = concurrencyToken; + return this; + } + + ShardInfoBuilder withParentShards(List parentShardIds) { + this.parentShardIds = parentShardIds; + return this; + } + + ShardInfoBuilder withCheckpoint(ExtendedSequenceNumber checkpoint) { + this.checkpoint = checkpoint; + return this; + } + + ShardInfo build() { + return new ShardInfo(shardId, concurrencyToken, parentShardIds, checkpoint); + } + } + private static ShardInfo shardInfo(String shardId, List parentShardIds) { // copy into new list just in case ShardInfo will stop doing it List newParentShardIds = new ArrayList<>(parentShardIds); - return new ShardInfo.Builder() + return new ShardInfoBuilder() .withShardId(shardId) .withParentShards(newParentShardIds) .build(); } private static ShardInfo shardInfo(String shardId, String... parentShardIds) { - return new ShardInfo.Builder() + return new ShardInfoBuilder() .withShardId(shardId) .withParentShards(Arrays.asList(parentShardIds)) .build(); diff --git a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShardInfoTest.java b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShardInfoTest.java index a2434d69..511b5a1b 100644 --- a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShardInfoTest.java +++ b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShardInfoTest.java @@ -14,6 +14,10 @@ */ package com.amazonaws.services.kinesis.clientlibrary.lib.worker; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; +import static org.junit.Assert.assertThat; + import java.util.ArrayList; import java.util.HashSet; import java.util.List; @@ -93,11 +97,20 @@ public class ShardInfoTest { } @Test - public void testPacboyShardInfoEqualsForCheckpoint() { - ShardInfo diffShardInfo = new ShardInfo(SHARD_ID, CONCURRENCY_TOKEN, parentShardIds, ExtendedSequenceNumber.SHARD_END); - Assert.assertFalse("Equal should return false with different checkpoint", diffShardInfo.equals(testShardInfo)); - diffShardInfo = new ShardInfo(SHARD_ID, CONCURRENCY_TOKEN, parentShardIds, null); - Assert.assertFalse("Equal should return false with null checkpoint", diffShardInfo.equals(testShardInfo)); + public void testShardInfoCheckpointEqualsHashCode() { + ShardInfo baseInfo = new ShardInfo(SHARD_ID, CONCURRENCY_TOKEN, parentShardIds, + ExtendedSequenceNumber.TRIM_HORIZON); + ShardInfo differentCheckpoint = new ShardInfo(SHARD_ID, CONCURRENCY_TOKEN, parentShardIds, + new ExtendedSequenceNumber("1234")); + ShardInfo nullCheckpoint = new ShardInfo(SHARD_ID, CONCURRENCY_TOKEN, parentShardIds, null); + + assertThat("Checkpoint should not be included in equality.", baseInfo.equals(differentCheckpoint), is(true)); + assertThat("Checkpoint should not be included in equality.", baseInfo.equals(nullCheckpoint), is(true)); + + assertThat("Checkpoint should not be included in hash code.", baseInfo.hashCode(), + equalTo(differentCheckpoint.hashCode())); + assertThat("Checkpoint should not be included in hash code.", baseInfo.hashCode(), + equalTo(nullCheckpoint.hashCode())); } @Test diff --git a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/WorkerTest.java b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/WorkerTest.java index a68c229b..0747f83d 100644 --- a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/WorkerTest.java +++ b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/WorkerTest.java @@ -16,10 +16,16 @@ package com.amazonaws.services.kinesis.clientlibrary.lib.worker; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; +import static org.mockito.Matchers.same; import static org.mockito.Mockito.atLeast; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -50,7 +56,10 @@ import org.junit.Assert; import org.junit.Rule; import org.junit.Test; import org.junit.rules.Timeout; +import org.junit.runner.RunWith; +import org.mockito.Mock; import org.mockito.invocation.InvocationOnMock; +import org.mockito.runners.MockitoJUnitRunner; import org.mockito.stubbing.Answer; import com.amazonaws.services.dynamodbv2.AmazonDynamoDB; @@ -85,6 +94,7 @@ import com.amazonaws.services.kinesis.model.Shard; /** * Unit tests of Worker. */ +@RunWith(MockitoJUnitRunner.class) public class WorkerTest { private static final Log LOG = LogFactory.getLog(WorkerTest.class); @@ -107,6 +117,30 @@ public class WorkerTest { InitialPositionInStreamExtended.newInitialPosition(InitialPositionInStream.TRIM_HORIZON); private final ShardPrioritization shardPrioritization = new NoOpShardPrioritization(); + private static final String KINESIS_SHARD_ID_FORMAT = "kinesis-0-0-%d"; + private static final String CONCURRENCY_TOKEN_FORMAT = "testToken-%d"; + + @Mock + private KinesisClientLibLeaseCoordinator leaseCoordinator; + @Mock + private ILeaseManager leaseManager; + @Mock + private com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorFactory v1RecordProcessorFactory; + @Mock + private IKinesisProxy proxy; + @Mock + private WorkerThreadPoolExecutor executorService; + @Mock + private WorkerCWMetricsFactory cwMetricsFactory; + @Mock + private IKinesisProxy kinesisProxy; + @Mock + private IRecordProcessorFactory v2RecordProcessorFactory; + @Mock + private IRecordProcessor v2RecordProcessor; + @Mock + private ShardConsumer shardConsumer; + // CHECKSTYLE:IGNORE AnonInnerLengthCheck FOR NEXT 50 LINES private static final com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorFactory SAMPLE_RECORD_PROCESSOR_FACTORY = new com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorFactory() { @@ -145,6 +179,7 @@ public class WorkerTest { private static final IRecordProcessorFactory SAMPLE_RECORD_PROCESSOR_FACTORY_V2 = new V1ToV2RecordProcessorFactoryAdapter(SAMPLE_RECORD_PROCESSOR_FACTORY); + /** * Test method for {@link com.amazonaws.services.kinesis.clientlibrary.lib.worker.Worker#getApplicationName()}. */ @@ -153,9 +188,7 @@ public class WorkerTest { final String stageName = "testStageName"; final KinesisClientLibConfiguration clientConfig = new KinesisClientLibConfiguration(stageName, null, null, null); - Worker worker = - new Worker(mock(com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorFactory.class), - clientConfig); + Worker worker = new Worker(v1RecordProcessorFactory, clientConfig); Assert.assertEquals(stageName, worker.getApplicationName()); } @@ -177,9 +210,6 @@ public class WorkerTest { final String dummyKinesisShardId = "kinesis-0-0"; ExecutorService execService = null; - KinesisClientLibLeaseCoordinator leaseCoordinator = mock(KinesisClientLibLeaseCoordinator.class); - @SuppressWarnings("unchecked") - ILeaseManager leaseManager = mock(ILeaseManager.class); when(leaseCoordinator.getLeaseManager()).thenReturn(leaseManager); Worker worker = @@ -208,6 +238,63 @@ public class WorkerTest { Assert.assertNotSame(consumer3, consumer); } + @Test + public void testWorkerLoopWithCheckpoint() { + final String stageName = "testStageName"; + IRecordProcessorFactory streamletFactory = SAMPLE_RECORD_PROCESSOR_FACTORY_V2; + IKinesisProxy proxy = null; + ICheckpoint checkpoint = null; + int maxRecords = 1; + int idleTimeInMilliseconds = 1000; + StreamConfig streamConfig = new StreamConfig(proxy, maxRecords, idleTimeInMilliseconds, + callProcessRecordsForEmptyRecordList, skipCheckpointValidationValue, INITIAL_POSITION_LATEST); + + ExecutorService execService = null; + + when(leaseCoordinator.getLeaseManager()).thenReturn(leaseManager); + + List initialState = createShardInfoList(ExtendedSequenceNumber.TRIM_HORIZON); + List firstCheckpoint = createShardInfoList(new ExtendedSequenceNumber("1000")); + List secondCheckpoint = createShardInfoList(new ExtendedSequenceNumber("2000")); + + when(leaseCoordinator.getCurrentAssignments()).thenReturn(initialState).thenReturn(firstCheckpoint) + .thenReturn(secondCheckpoint); + + Worker worker = new Worker(stageName, streamletFactory, streamConfig, INITIAL_POSITION_LATEST, + parentShardPollIntervalMillis, shardSyncIntervalMillis, cleanupLeasesUponShardCompletion, checkpoint, + leaseCoordinator, execService, nullMetricsFactory, taskBackoffTimeMillis, failoverTimeMillis, + shardPrioritization); + + Worker workerSpy = spy(worker); + + doReturn(shardConsumer).when(workerSpy).buildConsumer(eq(initialState.get(0)), any(IRecordProcessorFactory.class)); + workerSpy.runProcessLoop(); + workerSpy.runProcessLoop(); + workerSpy.runProcessLoop(); + + verify(workerSpy).buildConsumer(same(initialState.get(0)), any(IRecordProcessorFactory.class)); + verify(workerSpy, never()).buildConsumer(same(firstCheckpoint.get(0)), any(IRecordProcessorFactory.class)); + verify(workerSpy, never()).buildConsumer(same(secondCheckpoint.get(0)), any(IRecordProcessorFactory.class)); + + } + + private List createShardInfoList(ExtendedSequenceNumber... sequenceNumbers) { + List result = new ArrayList<>(sequenceNumbers.length); + assertThat(sequenceNumbers.length, greaterThanOrEqualTo(1)); + for (int i = 0; i < sequenceNumbers.length; ++i) { + result.add(new ShardInfo(adjustedShardId(i), adjustedConcurrencyToken(i), null, sequenceNumbers[i])); + } + return result; + } + + private String adjustedShardId(int index) { + return String.format(KINESIS_SHARD_ID_FORMAT, index); + } + + private String adjustedConcurrencyToken(int index) { + return String.format(CONCURRENCY_TOKEN_FORMAT, index); + } + @Test public final void testCleanupShardConsumers() { final String stageName = "testStageName"; @@ -226,10 +313,6 @@ public class WorkerTest { final String dummyKinesisShardId = "kinesis-0-0"; final String anotherDummyKinesisShardId = "kinesis-0-1"; ExecutorService execService = null; - - KinesisClientLibLeaseCoordinator leaseCoordinator = mock(KinesisClientLibLeaseCoordinator.class); - @SuppressWarnings("unchecked") - ILeaseManager leaseManager = mock(ILeaseManager.class); when(leaseCoordinator.getLeaseManager()).thenReturn(leaseManager); Worker worker = @@ -272,7 +355,6 @@ public class WorkerTest { public final void testInitializationFailureWithRetries() { String stageName = "testInitializationWorker"; IRecordProcessorFactory recordProcessorFactory = new TestStreamletFactory(null, null); - IKinesisProxy proxy = mock(IKinesisProxy.class); int count = 0; when(proxy.getShardList()).thenThrow(new RuntimeException(Integer.toString(count++))); int maxRecords = 2; @@ -282,9 +364,6 @@ public class WorkerTest { maxRecords, idleTimeInMilliseconds, callProcessRecordsForEmptyRecordList, skipCheckpointValidationValue, INITIAL_POSITION_LATEST); - KinesisClientLibLeaseCoordinator leaseCoordinator = mock(KinesisClientLibLeaseCoordinator.class); - @SuppressWarnings("unchecked") - ILeaseManager leaseManager = mock(ILeaseManager.class); when(leaseCoordinator.getLeaseManager()).thenReturn(leaseManager); ExecutorService execService = Executors.newSingleThreadExecutor(); long shardPollInterval = 0L; @@ -374,8 +453,7 @@ public class WorkerTest { @Test public final void testWorkerShutsDownOwnedResources() throws Exception { - final WorkerThreadPoolExecutor executorService = mock(WorkerThreadPoolExecutor.class); - final WorkerCWMetricsFactory cwMetricsFactory = mock(WorkerCWMetricsFactory.class); + final long failoverTimeMillis = 20L; // Make sure that worker thread is run before invoking shutdown. @@ -393,8 +471,7 @@ public class WorkerTest { callProcessRecordsForEmptyRecordList, failoverTimeMillis, 10, - mock(IKinesisProxy.class), - mock(IRecordProcessorFactory.class), + kinesisProxy, v2RecordProcessorFactory, executorService, cwMetricsFactory); @@ -411,10 +488,10 @@ public class WorkerTest { @Test public final void testWorkerDoesNotShutdownClientResources() throws Exception { - final ExecutorService executorService = mock(ThreadPoolExecutor.class); - final CWMetricsFactory cwMetricsFactory = mock(CWMetricsFactory.class); final long failoverTimeMillis = 20L; + final ExecutorService executorService = mock(ThreadPoolExecutor.class); + final CWMetricsFactory cwMetricsFactory = mock(CWMetricsFactory.class); // Make sure that worker thread is run before invoking shutdown. final CountDownLatch workerStarted = new CountDownLatch(1); doAnswer(new Answer() { @@ -430,8 +507,7 @@ public class WorkerTest { callProcessRecordsForEmptyRecordList, failoverTimeMillis, 10, - mock(IKinesisProxy.class), - mock(IRecordProcessorFactory.class), + kinesisProxy, v2RecordProcessorFactory, executorService, cwMetricsFactory); @@ -468,9 +544,8 @@ public class WorkerTest { // Make test case as efficient as possible. final CountDownLatch processRecordsLatch = new CountDownLatch(1); - IRecordProcessorFactory recordProcessorFactory = mock(IRecordProcessorFactory.class); - IRecordProcessor recordProcessor = mock(IRecordProcessor.class); - when(recordProcessorFactory.createProcessor()).thenReturn(recordProcessor); + + when(v2RecordProcessorFactory.createProcessor()).thenReturn(v2RecordProcessor); doAnswer(new Answer () { @Override @@ -479,7 +554,7 @@ public class WorkerTest { processRecordsLatch.countDown(); return null; } - }).when(recordProcessor).processRecords(any(ProcessRecordsInput.class)); + }).when(v2RecordProcessor).processRecords(any(ProcessRecordsInput.class)); WorkerThread workerThread = runWorker(shardList, initialLeases, @@ -487,7 +562,7 @@ public class WorkerTest { failoverTimeMillis, numberOfRecordsPerShard, fileBasedProxy, - recordProcessorFactory, + v2RecordProcessorFactory, executorService, nullMetricsFactory); @@ -495,16 +570,16 @@ public class WorkerTest { processRecordsLatch.await(); // Make sure record processor is initialized and processing records. - verify(recordProcessorFactory, times(1)).createProcessor(); - verify(recordProcessor, times(1)).initialize(any(InitializationInput.class)); - verify(recordProcessor, atLeast(1)).processRecords(any(ProcessRecordsInput.class)); - verify(recordProcessor, times(0)).shutdown(any(ShutdownInput.class)); + verify(v2RecordProcessorFactory, times(1)).createProcessor(); + verify(v2RecordProcessor, times(1)).initialize(any(InitializationInput.class)); + verify(v2RecordProcessor, atLeast(1)).processRecords(any(ProcessRecordsInput.class)); + verify(v2RecordProcessor, times(0)).shutdown(any(ShutdownInput.class)); workerThread.getWorker().shutdown(); workerThread.join(); Assert.assertTrue(workerThread.getState() == State.TERMINATED); - verify(recordProcessor, times(1)).shutdown(any(ShutdownInput.class)); + verify(v2RecordProcessor, times(1)).shutdown(any(ShutdownInput.class)); } /** @@ -538,9 +613,7 @@ public class WorkerTest { // Make test case as efficient as possible. final CountDownLatch processRecordsLatch = new CountDownLatch(1); final AtomicBoolean recordProcessorInterrupted = new AtomicBoolean(false); - IRecordProcessorFactory recordProcessorFactory = mock(IRecordProcessorFactory.class); - IRecordProcessor recordProcessor = mock(IRecordProcessor.class); - when(recordProcessorFactory.createProcessor()).thenReturn(recordProcessor); + when(v2RecordProcessorFactory.createProcessor()).thenReturn(v2RecordProcessor); final Semaphore actionBlocker = new Semaphore(1); final Semaphore shutdownBlocker = new Semaphore(1); @@ -572,7 +645,7 @@ public class WorkerTest { return null; } - }).when(recordProcessor).processRecords(any(ProcessRecordsInput.class)); + }).when(v2RecordProcessor).processRecords(any(ProcessRecordsInput.class)); WorkerThread workerThread = runWorker(shardList, initialLeases, @@ -580,7 +653,7 @@ public class WorkerTest { failoverTimeMillis, numberOfRecordsPerShard, fileBasedProxy, - recordProcessorFactory, + v2RecordProcessorFactory, executorService, nullMetricsFactory); @@ -588,17 +661,17 @@ public class WorkerTest { processRecordsLatch.await(); // Make sure record processor is initialized and processing records. - verify(recordProcessorFactory, times(1)).createProcessor(); - verify(recordProcessor, times(1)).initialize(any(InitializationInput.class)); - verify(recordProcessor, atLeast(1)).processRecords(any(ProcessRecordsInput.class)); - verify(recordProcessor, times(0)).shutdown(any(ShutdownInput.class)); + verify(v2RecordProcessorFactory, times(1)).createProcessor(); + verify(v2RecordProcessor, times(1)).initialize(any(InitializationInput.class)); + verify(v2RecordProcessor, atLeast(1)).processRecords(any(ProcessRecordsInput.class)); + verify(v2RecordProcessor, times(0)).shutdown(any(ShutdownInput.class)); workerThread.getWorker().shutdown(); workerThread.join(); Assert.assertTrue(workerThread.getState() == State.TERMINATED); // Shutdown should not be called in this case because record processor is blocked. - verify(recordProcessor, times(0)).shutdown(any(ShutdownInput.class)); + verify(v2RecordProcessor, times(0)).shutdown(any(ShutdownInput.class)); // // Release the worker thread