From 6ea3c0f8ddef8ad0f99b1a9d3e787dc051e87ff8 Mon Sep 17 00:00:00 2001 From: "Pfifer, Justin" Date: Wed, 17 Aug 2016 10:56:27 -0700 Subject: [PATCH] Don't Use Checkpoint for Equality in ShardInfo Don't include checkpoint in hashCode/equality for ShardInfo, since it changes. Checkpointing would cause the Worker to recreate the ShardConsumer. Add unit tests that verify the equality constraints. Remove the equality test for checkpoint from ShardInfoTests. Nothing appears to rely on the checkpoint being part of ShardInfo. Fix WorkerTest broken in overzealous simplification. --- .../clientlibrary/lib/worker/ShardInfo.java | 97 ++--------- .../clientlibrary/lib/worker/Worker.java | 96 +++++------ ...rentsFirstShardPrioritizationUnitTest.java | 43 ++++- .../lib/worker/ShardInfoTest.java | 23 ++- .../clientlibrary/lib/worker/WorkerTest.java | 157 +++++++++++++----- 5 files changed, 240 insertions(+), 176 deletions(-) diff --git a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShardInfo.java b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShardInfo.java index 9890d02f..f04a86ba 100644 --- a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShardInfo.java +++ b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShardInfo.java @@ -19,13 +19,17 @@ import java.util.Collections; import java.util.LinkedList; import java.util.List; +import org.apache.commons.lang.builder.EqualsBuilder; + import com.amazonaws.services.kinesis.clientlibrary.types.ExtendedSequenceNumber; +import org.apache.commons.lang.builder.HashCodeBuilder; /** * Used to pass shard related info among different classes and as a key to the map of shard consumers. */ class ShardInfo { + private final String shardId; private final String concurrencyToken; // Sorted list of parent shardIds. @@ -33,10 +37,16 @@ class ShardInfo { private final ExtendedSequenceNumber checkpoint; /** - * @param shardId Kinesis shardId - * @param concurrencyToken Used to differentiate between lost and reclaimed leases - * @param parentShardIds Parent shards of the shard identified by Kinesis shardId - * @param checkpoint the latest checkpoint from lease + * Creates a new ShardInfo object. The checkpoint is not part of the equality, but is used for debugging output. + * + * @param shardId + * Kinesis shardId + * @param concurrencyToken + * Used to differentiate between lost and reclaimed leases + * @param parentShardIds + * Parent shards of the shard identified by Kinesis shardId + * @param checkpoint + * the latest checkpoint from lease */ public ShardInfo(String shardId, String concurrencyToken, @@ -87,20 +97,12 @@ class ShardInfo { */ @Override public int hashCode() { - final int prime = 31; - int result = 1; - result = prime * result + ((concurrencyToken == null) ? 0 : concurrencyToken.hashCode()); - result = prime * result + ((parentShardIds == null) ? 0 : parentShardIds.hashCode()); - result = prime * result + ((shardId == null) ? 0 : shardId.hashCode()); - result = prime * result + ((checkpoint == null) ? 0 : checkpoint.hashCode()); - return result; + return new HashCodeBuilder().append(concurrencyToken).append(parentShardIds).append(shardId).toHashCode(); } /** * {@inheritDoc} */ - // CHECKSTYLE:OFF CyclomaticComplexity - // CHECKSTYLE:OFF NPathComplexity /** * This method assumes parentShardIds is ordered. The Worker.cleanupShardConsumers() method relies on this method * returning true for ShardInfo objects which may have been instantiated with parentShardIds in a different order @@ -121,39 +123,11 @@ class ShardInfo { return false; } ShardInfo other = (ShardInfo) obj; - if (concurrencyToken == null) { - if (other.concurrencyToken != null) { - return false; - } - } else if (!concurrencyToken.equals(other.concurrencyToken)) { - return false; - } - if (parentShardIds == null) { - if (other.parentShardIds != null) { - return false; - } - } else if (!parentShardIds.equals(other.parentShardIds)) { - return false; - } - if (shardId == null) { - if (other.shardId != null) { - return false; - } - } else if (!shardId.equals(other.shardId)) { - return false; - } - if (checkpoint == null) { - if (other.checkpoint != null) { - return false; - } - } else if (!checkpoint.equals(other.checkpoint)) { - return false; - } - return true; + return new EqualsBuilder().append(concurrencyToken, other.concurrencyToken) + .append(parentShardIds, other.parentShardIds).append(shardId, other.shardId).isEquals(); + } - // CHECKSTYLE:ON CyclomaticComplexity - // CHECKSTYLE:ON NPathComplexity @Override public String toString() { @@ -161,41 +135,6 @@ class ShardInfo { + parentShardIds + ", checkpoint=" + checkpoint + "]"; } - /** - * Builder class for ShardInfo. - */ - public static class Builder { - private String shardId; - private String concurrencyToken; - private List parentShardIds = Collections.emptyList(); - private ExtendedSequenceNumber checkpoint = ExtendedSequenceNumber.LATEST; - public Builder() { - } - - public Builder withShardId(String shardId) { - this.shardId = shardId; - return this; - } - - public Builder withConcurrencyToken(String concurrencyToken) { - this.concurrencyToken = concurrencyToken; - return this; - } - - public Builder withParentShards(List parentShardIds) { - this.parentShardIds = parentShardIds; - return this; - } - - public Builder withCheckpoint(ExtendedSequenceNumber checkpoint) { - this.checkpoint = checkpoint; - return this; - } - - public ShardInfo build() { - return new ShardInfo(shardId, concurrencyToken, parentShardIds, checkpoint); - } - } } diff --git a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/Worker.java b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/Worker.java index 3da1f2cd..32efa442 100644 --- a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/Worker.java +++ b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/Worker.java @@ -38,7 +38,6 @@ import com.amazonaws.services.kinesis.AmazonKinesisClient; import com.amazonaws.services.kinesis.clientlibrary.interfaces.ICheckpoint; import com.amazonaws.services.kinesis.clientlibrary.interfaces.v2.IRecordProcessor; import com.amazonaws.services.kinesis.clientlibrary.interfaces.v2.IRecordProcessorFactory; -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.Worker.Builder; import com.amazonaws.services.kinesis.clientlibrary.proxies.KinesisProxyFactory; import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason; import com.amazonaws.services.kinesis.leases.exceptions.LeasingException; @@ -47,6 +46,7 @@ import com.amazonaws.services.kinesis.metrics.impl.CWMetricsFactory; import com.amazonaws.services.kinesis.metrics.impl.NullMetricsFactory; import com.amazonaws.services.kinesis.metrics.interfaces.IMetricsFactory; import com.amazonaws.services.kinesis.metrics.interfaces.MetricsLevel; +import com.google.common.annotations.VisibleForTesting; /** * Worker is the high level class that Kinesis applications use to start @@ -342,46 +342,49 @@ public class Worker implements Runnable { } while (!shouldShutdown()) { - try { - boolean foundCompletedShard = false; - Set assignedShards = new HashSet(); - for (ShardInfo shardInfo : getShardInfoForAssignments()) { - ShardConsumer shardConsumer = createOrGetShardConsumer(shardInfo, recordProcessorFactory); - if (shardConsumer.isShutdown() - && shardConsumer.getShutdownReason().equals(ShutdownReason.TERMINATE)) { - foundCompletedShard = true; - } else { - shardConsumer.consumeShard(); - } - assignedShards.add(shardInfo); - } - - if (foundCompletedShard) { - controlServer.syncShardAndLeaseInfo(null); - } - - // clean up shard consumers for unassigned shards - cleanupShardConsumers(assignedShards); - - wlog.info("Sleeping ..."); - Thread.sleep(idleTimeInMilliseconds); - } catch (Exception e) { - LOG.error(String.format("Worker.run caught exception, sleeping for %s milli seconds!", - String.valueOf(idleTimeInMilliseconds)), - e); - try { - Thread.sleep(idleTimeInMilliseconds); - } catch (InterruptedException ex) { - LOG.info("Worker: sleep interrupted after catching exception ", ex); - } - } - wlog.resetInfoLogging(); + runProcessLoop(); } finalShutdown(); LOG.info("Worker loop is complete. Exiting from worker."); } + @VisibleForTesting + void runProcessLoop() { + try { + boolean foundCompletedShard = false; + Set assignedShards = new HashSet<>(); + for (ShardInfo shardInfo : getShardInfoForAssignments()) { + ShardConsumer shardConsumer = createOrGetShardConsumer(shardInfo, recordProcessorFactory); + if (shardConsumer.isShutdown() && shardConsumer.getShutdownReason().equals(ShutdownReason.TERMINATE)) { + foundCompletedShard = true; + } else { + shardConsumer.consumeShard(); + } + assignedShards.add(shardInfo); + } + + if (foundCompletedShard) { + controlServer.syncShardAndLeaseInfo(null); + } + + // clean up shard consumers for unassigned shards + cleanupShardConsumers(assignedShards); + + wlog.info("Sleeping ..."); + Thread.sleep(idleTimeInMilliseconds); + } catch (Exception e) { + LOG.error(String.format("Worker.run caught exception, sleeping for %s milli seconds!", + String.valueOf(idleTimeInMilliseconds)), e); + try { + Thread.sleep(idleTimeInMilliseconds); + } catch (InterruptedException ex) { + LOG.info("Worker: sleep interrupted after catching exception ", ex); + } + } + wlog.resetInfoLogging(); + } + private void initialize() { boolean isDone = false; Exception lastException = null; @@ -552,25 +555,22 @@ public class Worker implements Runnable { // completely processed (shutdown reason terminate). if ((consumer == null) || (consumer.isShutdown() && consumer.getShutdownReason().equals(ShutdownReason.ZOMBIE))) { - IRecordProcessor recordProcessor = factory.createProcessor(); - - consumer = - new ShardConsumer(shardInfo, - streamConfig, - checkpointTracker, - recordProcessor, - leaseCoordinator.getLeaseManager(), - parentShardPollIntervalMillis, - cleanupLeasesUponShardCompletion, - executorService, - metricsFactory, - taskBackoffTimeMillis); + consumer = buildConsumer(shardInfo, factory); shardInfoShardConsumerMap.put(shardInfo, consumer); wlog.infoForce("Created new shardConsumer for : " + shardInfo); } return consumer; } + protected ShardConsumer buildConsumer(ShardInfo shardInfo, IRecordProcessorFactory factory) { + IRecordProcessor recordProcessor = factory.createProcessor(); + + return new ShardConsumer(shardInfo, streamConfig, checkpointTracker, recordProcessor, + leaseCoordinator.getLeaseManager(), parentShardPollIntervalMillis, cleanupLeasesUponShardCompletion, + executorService, metricsFactory, taskBackoffTimeMillis); + + } + /** * Logger for suppressing too much INFO logging. To avoid too much logging * information Worker will output logging at INFO level for a single pass diff --git a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ParentsFirstShardPrioritizationUnitTest.java b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ParentsFirstShardPrioritizationUnitTest.java index 35b56b32..7ba0753d 100644 --- a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ParentsFirstShardPrioritizationUnitTest.java +++ b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ParentsFirstShardPrioritizationUnitTest.java @@ -11,6 +11,8 @@ import java.util.Random; import org.junit.Test; +import com.amazonaws.services.kinesis.clientlibrary.types.ExtendedSequenceNumber; + public class ParentsFirstShardPrioritizationUnitTest { @Test(expected = IllegalArgumentException.class) @@ -144,17 +146,54 @@ public class ParentsFirstShardPrioritizationUnitTest { return "shardId-" + shardNumber; } + /** + * Builder class for ShardInfo. + */ + static class ShardInfoBuilder { + private String shardId; + private String concurrencyToken; + private List parentShardIds = Collections.emptyList(); + private ExtendedSequenceNumber checkpoint = ExtendedSequenceNumber.LATEST; + + ShardInfoBuilder() { + } + + ShardInfoBuilder withShardId(String shardId) { + this.shardId = shardId; + return this; + } + + ShardInfoBuilder withConcurrencyToken(String concurrencyToken) { + this.concurrencyToken = concurrencyToken; + return this; + } + + ShardInfoBuilder withParentShards(List parentShardIds) { + this.parentShardIds = parentShardIds; + return this; + } + + ShardInfoBuilder withCheckpoint(ExtendedSequenceNumber checkpoint) { + this.checkpoint = checkpoint; + return this; + } + + ShardInfo build() { + return new ShardInfo(shardId, concurrencyToken, parentShardIds, checkpoint); + } + } + private static ShardInfo shardInfo(String shardId, List parentShardIds) { // copy into new list just in case ShardInfo will stop doing it List newParentShardIds = new ArrayList<>(parentShardIds); - return new ShardInfo.Builder() + return new ShardInfoBuilder() .withShardId(shardId) .withParentShards(newParentShardIds) .build(); } private static ShardInfo shardInfo(String shardId, String... parentShardIds) { - return new ShardInfo.Builder() + return new ShardInfoBuilder() .withShardId(shardId) .withParentShards(Arrays.asList(parentShardIds)) .build(); diff --git a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShardInfoTest.java b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShardInfoTest.java index a2434d69..511b5a1b 100644 --- a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShardInfoTest.java +++ b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShardInfoTest.java @@ -14,6 +14,10 @@ */ package com.amazonaws.services.kinesis.clientlibrary.lib.worker; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; +import static org.junit.Assert.assertThat; + import java.util.ArrayList; import java.util.HashSet; import java.util.List; @@ -93,11 +97,20 @@ public class ShardInfoTest { } @Test - public void testPacboyShardInfoEqualsForCheckpoint() { - ShardInfo diffShardInfo = new ShardInfo(SHARD_ID, CONCURRENCY_TOKEN, parentShardIds, ExtendedSequenceNumber.SHARD_END); - Assert.assertFalse("Equal should return false with different checkpoint", diffShardInfo.equals(testShardInfo)); - diffShardInfo = new ShardInfo(SHARD_ID, CONCURRENCY_TOKEN, parentShardIds, null); - Assert.assertFalse("Equal should return false with null checkpoint", diffShardInfo.equals(testShardInfo)); + public void testShardInfoCheckpointEqualsHashCode() { + ShardInfo baseInfo = new ShardInfo(SHARD_ID, CONCURRENCY_TOKEN, parentShardIds, + ExtendedSequenceNumber.TRIM_HORIZON); + ShardInfo differentCheckpoint = new ShardInfo(SHARD_ID, CONCURRENCY_TOKEN, parentShardIds, + new ExtendedSequenceNumber("1234")); + ShardInfo nullCheckpoint = new ShardInfo(SHARD_ID, CONCURRENCY_TOKEN, parentShardIds, null); + + assertThat("Checkpoint should not be included in equality.", baseInfo.equals(differentCheckpoint), is(true)); + assertThat("Checkpoint should not be included in equality.", baseInfo.equals(nullCheckpoint), is(true)); + + assertThat("Checkpoint should not be included in hash code.", baseInfo.hashCode(), + equalTo(differentCheckpoint.hashCode())); + assertThat("Checkpoint should not be included in hash code.", baseInfo.hashCode(), + equalTo(nullCheckpoint.hashCode())); } @Test diff --git a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/WorkerTest.java b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/WorkerTest.java index a68c229b..0747f83d 100644 --- a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/WorkerTest.java +++ b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/WorkerTest.java @@ -16,10 +16,16 @@ package com.amazonaws.services.kinesis.clientlibrary.lib.worker; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; +import static org.mockito.Matchers.same; import static org.mockito.Mockito.atLeast; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -50,7 +56,10 @@ import org.junit.Assert; import org.junit.Rule; import org.junit.Test; import org.junit.rules.Timeout; +import org.junit.runner.RunWith; +import org.mockito.Mock; import org.mockito.invocation.InvocationOnMock; +import org.mockito.runners.MockitoJUnitRunner; import org.mockito.stubbing.Answer; import com.amazonaws.services.dynamodbv2.AmazonDynamoDB; @@ -85,6 +94,7 @@ import com.amazonaws.services.kinesis.model.Shard; /** * Unit tests of Worker. */ +@RunWith(MockitoJUnitRunner.class) public class WorkerTest { private static final Log LOG = LogFactory.getLog(WorkerTest.class); @@ -107,6 +117,30 @@ public class WorkerTest { InitialPositionInStreamExtended.newInitialPosition(InitialPositionInStream.TRIM_HORIZON); private final ShardPrioritization shardPrioritization = new NoOpShardPrioritization(); + private static final String KINESIS_SHARD_ID_FORMAT = "kinesis-0-0-%d"; + private static final String CONCURRENCY_TOKEN_FORMAT = "testToken-%d"; + + @Mock + private KinesisClientLibLeaseCoordinator leaseCoordinator; + @Mock + private ILeaseManager leaseManager; + @Mock + private com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorFactory v1RecordProcessorFactory; + @Mock + private IKinesisProxy proxy; + @Mock + private WorkerThreadPoolExecutor executorService; + @Mock + private WorkerCWMetricsFactory cwMetricsFactory; + @Mock + private IKinesisProxy kinesisProxy; + @Mock + private IRecordProcessorFactory v2RecordProcessorFactory; + @Mock + private IRecordProcessor v2RecordProcessor; + @Mock + private ShardConsumer shardConsumer; + // CHECKSTYLE:IGNORE AnonInnerLengthCheck FOR NEXT 50 LINES private static final com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorFactory SAMPLE_RECORD_PROCESSOR_FACTORY = new com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorFactory() { @@ -145,6 +179,7 @@ public class WorkerTest { private static final IRecordProcessorFactory SAMPLE_RECORD_PROCESSOR_FACTORY_V2 = new V1ToV2RecordProcessorFactoryAdapter(SAMPLE_RECORD_PROCESSOR_FACTORY); + /** * Test method for {@link com.amazonaws.services.kinesis.clientlibrary.lib.worker.Worker#getApplicationName()}. */ @@ -153,9 +188,7 @@ public class WorkerTest { final String stageName = "testStageName"; final KinesisClientLibConfiguration clientConfig = new KinesisClientLibConfiguration(stageName, null, null, null); - Worker worker = - new Worker(mock(com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorFactory.class), - clientConfig); + Worker worker = new Worker(v1RecordProcessorFactory, clientConfig); Assert.assertEquals(stageName, worker.getApplicationName()); } @@ -177,9 +210,6 @@ public class WorkerTest { final String dummyKinesisShardId = "kinesis-0-0"; ExecutorService execService = null; - KinesisClientLibLeaseCoordinator leaseCoordinator = mock(KinesisClientLibLeaseCoordinator.class); - @SuppressWarnings("unchecked") - ILeaseManager leaseManager = mock(ILeaseManager.class); when(leaseCoordinator.getLeaseManager()).thenReturn(leaseManager); Worker worker = @@ -208,6 +238,63 @@ public class WorkerTest { Assert.assertNotSame(consumer3, consumer); } + @Test + public void testWorkerLoopWithCheckpoint() { + final String stageName = "testStageName"; + IRecordProcessorFactory streamletFactory = SAMPLE_RECORD_PROCESSOR_FACTORY_V2; + IKinesisProxy proxy = null; + ICheckpoint checkpoint = null; + int maxRecords = 1; + int idleTimeInMilliseconds = 1000; + StreamConfig streamConfig = new StreamConfig(proxy, maxRecords, idleTimeInMilliseconds, + callProcessRecordsForEmptyRecordList, skipCheckpointValidationValue, INITIAL_POSITION_LATEST); + + ExecutorService execService = null; + + when(leaseCoordinator.getLeaseManager()).thenReturn(leaseManager); + + List initialState = createShardInfoList(ExtendedSequenceNumber.TRIM_HORIZON); + List firstCheckpoint = createShardInfoList(new ExtendedSequenceNumber("1000")); + List secondCheckpoint = createShardInfoList(new ExtendedSequenceNumber("2000")); + + when(leaseCoordinator.getCurrentAssignments()).thenReturn(initialState).thenReturn(firstCheckpoint) + .thenReturn(secondCheckpoint); + + Worker worker = new Worker(stageName, streamletFactory, streamConfig, INITIAL_POSITION_LATEST, + parentShardPollIntervalMillis, shardSyncIntervalMillis, cleanupLeasesUponShardCompletion, checkpoint, + leaseCoordinator, execService, nullMetricsFactory, taskBackoffTimeMillis, failoverTimeMillis, + shardPrioritization); + + Worker workerSpy = spy(worker); + + doReturn(shardConsumer).when(workerSpy).buildConsumer(eq(initialState.get(0)), any(IRecordProcessorFactory.class)); + workerSpy.runProcessLoop(); + workerSpy.runProcessLoop(); + workerSpy.runProcessLoop(); + + verify(workerSpy).buildConsumer(same(initialState.get(0)), any(IRecordProcessorFactory.class)); + verify(workerSpy, never()).buildConsumer(same(firstCheckpoint.get(0)), any(IRecordProcessorFactory.class)); + verify(workerSpy, never()).buildConsumer(same(secondCheckpoint.get(0)), any(IRecordProcessorFactory.class)); + + } + + private List createShardInfoList(ExtendedSequenceNumber... sequenceNumbers) { + List result = new ArrayList<>(sequenceNumbers.length); + assertThat(sequenceNumbers.length, greaterThanOrEqualTo(1)); + for (int i = 0; i < sequenceNumbers.length; ++i) { + result.add(new ShardInfo(adjustedShardId(i), adjustedConcurrencyToken(i), null, sequenceNumbers[i])); + } + return result; + } + + private String adjustedShardId(int index) { + return String.format(KINESIS_SHARD_ID_FORMAT, index); + } + + private String adjustedConcurrencyToken(int index) { + return String.format(CONCURRENCY_TOKEN_FORMAT, index); + } + @Test public final void testCleanupShardConsumers() { final String stageName = "testStageName"; @@ -226,10 +313,6 @@ public class WorkerTest { final String dummyKinesisShardId = "kinesis-0-0"; final String anotherDummyKinesisShardId = "kinesis-0-1"; ExecutorService execService = null; - - KinesisClientLibLeaseCoordinator leaseCoordinator = mock(KinesisClientLibLeaseCoordinator.class); - @SuppressWarnings("unchecked") - ILeaseManager leaseManager = mock(ILeaseManager.class); when(leaseCoordinator.getLeaseManager()).thenReturn(leaseManager); Worker worker = @@ -272,7 +355,6 @@ public class WorkerTest { public final void testInitializationFailureWithRetries() { String stageName = "testInitializationWorker"; IRecordProcessorFactory recordProcessorFactory = new TestStreamletFactory(null, null); - IKinesisProxy proxy = mock(IKinesisProxy.class); int count = 0; when(proxy.getShardList()).thenThrow(new RuntimeException(Integer.toString(count++))); int maxRecords = 2; @@ -282,9 +364,6 @@ public class WorkerTest { maxRecords, idleTimeInMilliseconds, callProcessRecordsForEmptyRecordList, skipCheckpointValidationValue, INITIAL_POSITION_LATEST); - KinesisClientLibLeaseCoordinator leaseCoordinator = mock(KinesisClientLibLeaseCoordinator.class); - @SuppressWarnings("unchecked") - ILeaseManager leaseManager = mock(ILeaseManager.class); when(leaseCoordinator.getLeaseManager()).thenReturn(leaseManager); ExecutorService execService = Executors.newSingleThreadExecutor(); long shardPollInterval = 0L; @@ -374,8 +453,7 @@ public class WorkerTest { @Test public final void testWorkerShutsDownOwnedResources() throws Exception { - final WorkerThreadPoolExecutor executorService = mock(WorkerThreadPoolExecutor.class); - final WorkerCWMetricsFactory cwMetricsFactory = mock(WorkerCWMetricsFactory.class); + final long failoverTimeMillis = 20L; // Make sure that worker thread is run before invoking shutdown. @@ -393,8 +471,7 @@ public class WorkerTest { callProcessRecordsForEmptyRecordList, failoverTimeMillis, 10, - mock(IKinesisProxy.class), - mock(IRecordProcessorFactory.class), + kinesisProxy, v2RecordProcessorFactory, executorService, cwMetricsFactory); @@ -411,10 +488,10 @@ public class WorkerTest { @Test public final void testWorkerDoesNotShutdownClientResources() throws Exception { - final ExecutorService executorService = mock(ThreadPoolExecutor.class); - final CWMetricsFactory cwMetricsFactory = mock(CWMetricsFactory.class); final long failoverTimeMillis = 20L; + final ExecutorService executorService = mock(ThreadPoolExecutor.class); + final CWMetricsFactory cwMetricsFactory = mock(CWMetricsFactory.class); // Make sure that worker thread is run before invoking shutdown. final CountDownLatch workerStarted = new CountDownLatch(1); doAnswer(new Answer() { @@ -430,8 +507,7 @@ public class WorkerTest { callProcessRecordsForEmptyRecordList, failoverTimeMillis, 10, - mock(IKinesisProxy.class), - mock(IRecordProcessorFactory.class), + kinesisProxy, v2RecordProcessorFactory, executorService, cwMetricsFactory); @@ -468,9 +544,8 @@ public class WorkerTest { // Make test case as efficient as possible. final CountDownLatch processRecordsLatch = new CountDownLatch(1); - IRecordProcessorFactory recordProcessorFactory = mock(IRecordProcessorFactory.class); - IRecordProcessor recordProcessor = mock(IRecordProcessor.class); - when(recordProcessorFactory.createProcessor()).thenReturn(recordProcessor); + + when(v2RecordProcessorFactory.createProcessor()).thenReturn(v2RecordProcessor); doAnswer(new Answer () { @Override @@ -479,7 +554,7 @@ public class WorkerTest { processRecordsLatch.countDown(); return null; } - }).when(recordProcessor).processRecords(any(ProcessRecordsInput.class)); + }).when(v2RecordProcessor).processRecords(any(ProcessRecordsInput.class)); WorkerThread workerThread = runWorker(shardList, initialLeases, @@ -487,7 +562,7 @@ public class WorkerTest { failoverTimeMillis, numberOfRecordsPerShard, fileBasedProxy, - recordProcessorFactory, + v2RecordProcessorFactory, executorService, nullMetricsFactory); @@ -495,16 +570,16 @@ public class WorkerTest { processRecordsLatch.await(); // Make sure record processor is initialized and processing records. - verify(recordProcessorFactory, times(1)).createProcessor(); - verify(recordProcessor, times(1)).initialize(any(InitializationInput.class)); - verify(recordProcessor, atLeast(1)).processRecords(any(ProcessRecordsInput.class)); - verify(recordProcessor, times(0)).shutdown(any(ShutdownInput.class)); + verify(v2RecordProcessorFactory, times(1)).createProcessor(); + verify(v2RecordProcessor, times(1)).initialize(any(InitializationInput.class)); + verify(v2RecordProcessor, atLeast(1)).processRecords(any(ProcessRecordsInput.class)); + verify(v2RecordProcessor, times(0)).shutdown(any(ShutdownInput.class)); workerThread.getWorker().shutdown(); workerThread.join(); Assert.assertTrue(workerThread.getState() == State.TERMINATED); - verify(recordProcessor, times(1)).shutdown(any(ShutdownInput.class)); + verify(v2RecordProcessor, times(1)).shutdown(any(ShutdownInput.class)); } /** @@ -538,9 +613,7 @@ public class WorkerTest { // Make test case as efficient as possible. final CountDownLatch processRecordsLatch = new CountDownLatch(1); final AtomicBoolean recordProcessorInterrupted = new AtomicBoolean(false); - IRecordProcessorFactory recordProcessorFactory = mock(IRecordProcessorFactory.class); - IRecordProcessor recordProcessor = mock(IRecordProcessor.class); - when(recordProcessorFactory.createProcessor()).thenReturn(recordProcessor); + when(v2RecordProcessorFactory.createProcessor()).thenReturn(v2RecordProcessor); final Semaphore actionBlocker = new Semaphore(1); final Semaphore shutdownBlocker = new Semaphore(1); @@ -572,7 +645,7 @@ public class WorkerTest { return null; } - }).when(recordProcessor).processRecords(any(ProcessRecordsInput.class)); + }).when(v2RecordProcessor).processRecords(any(ProcessRecordsInput.class)); WorkerThread workerThread = runWorker(shardList, initialLeases, @@ -580,7 +653,7 @@ public class WorkerTest { failoverTimeMillis, numberOfRecordsPerShard, fileBasedProxy, - recordProcessorFactory, + v2RecordProcessorFactory, executorService, nullMetricsFactory); @@ -588,17 +661,17 @@ public class WorkerTest { processRecordsLatch.await(); // Make sure record processor is initialized and processing records. - verify(recordProcessorFactory, times(1)).createProcessor(); - verify(recordProcessor, times(1)).initialize(any(InitializationInput.class)); - verify(recordProcessor, atLeast(1)).processRecords(any(ProcessRecordsInput.class)); - verify(recordProcessor, times(0)).shutdown(any(ShutdownInput.class)); + verify(v2RecordProcessorFactory, times(1)).createProcessor(); + verify(v2RecordProcessor, times(1)).initialize(any(InitializationInput.class)); + verify(v2RecordProcessor, atLeast(1)).processRecords(any(ProcessRecordsInput.class)); + verify(v2RecordProcessor, times(0)).shutdown(any(ShutdownInput.class)); workerThread.getWorker().shutdown(); workerThread.join(); Assert.assertTrue(workerThread.getState() == State.TERMINATED); // Shutdown should not be called in this case because record processor is blocked. - verify(recordProcessor, times(0)).shutdown(any(ShutdownInput.class)); + verify(v2RecordProcessor, times(0)).shutdown(any(ShutdownInput.class)); // // Release the worker thread