diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ShardConsumerSubscriberTest.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ShardConsumerSubscriberTest.java index 95064f2c..4d0f01ee 100644 --- a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ShardConsumerSubscriberTest.java +++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ShardConsumerSubscriberTest.java @@ -19,6 +19,7 @@ import static org.hamcrest.CoreMatchers.nullValue; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThat; import static org.mockito.Matchers.any; +import static org.mockito.Matchers.argThat; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; @@ -125,7 +126,7 @@ public class ShardConsumerSubscriberTest { processedNotifier.wait(5000); } - verify(shardConsumer).handleInput(any(ProcessRecordsInput.class), any(Subscription.class)); + verify(shardConsumer).handleInput(argThat(eqProcessRecordsInput(processRecordsInput)), any(Subscription.class)); } @Test @@ -139,7 +140,7 @@ public class ShardConsumerSubscriberTest { processedNotifier.wait(5000); } - verify(shardConsumer, times(100)).handleInput(any(ProcessRecordsInput.class), any(Subscription.class)); + verify(shardConsumer, times(100)).handleInput(argThat(eqProcessRecordsInput(processRecordsInput)), any(Subscription.class)); } @Test @@ -174,7 +175,7 @@ public class ShardConsumerSubscriberTest { assertThat(subscriber.getAndResetDispatchFailure(), equalTo(testException)); assertThat(subscriber.getAndResetDispatchFailure(), nullValue()); - verify(shardConsumer, times(20)).handleInput(any(ProcessRecordsInput.class), any(Subscription.class)); + verify(shardConsumer, times(20)).handleInput(argThat(eqProcessRecordsInput(processRecordsInput)), any(Subscription.class)); } @@ -199,7 +200,7 @@ public class ShardConsumerSubscriberTest { Thread.sleep(10); } - verify(shardConsumer, times(10)).handleInput(any(ProcessRecordsInput.class), any(Subscription.class)); + verify(shardConsumer, times(10)).handleInput(argThat(eqProcessRecordsInput(processRecordsInput)), any(Subscription.class)); assertThat(subscriber.retrievalFailure(), equalTo(expected)); } @@ -235,7 +236,7 @@ public class ShardConsumerSubscriberTest { } assertThat(recordsPublisher.restartedFrom, equalTo(edgeRecord)); - verify(shardConsumer, times(20)).handleInput(any(ProcessRecordsInput.class), any(Subscription.class)); + verify(shardConsumer, times(20)).handleInput(argThat(eqProcessRecordsInput(processRecordsInput)), any(Subscription.class)); } @Test @@ -299,7 +300,7 @@ public class ShardConsumerSubscriberTest { processedNotifier.wait(5000); } - verify(shardConsumer, times(100)).handleInput(any(ProcessRecordsInput.class), any(Subscription.class)); + verify(shardConsumer, times(100)).handleInput(argThat(eqProcessRecordsInput(processRecordsInput)), any(Subscription.class)); assertThat(received.size(), equalTo(recordsPublisher.responses.size())); Stream.iterate(0, i -> i + 1).limit(received.size()).forEach(i -> assertThat(received.get(i), @@ -338,7 +339,7 @@ public class ShardConsumerSubscriberTest { // Verifying that there are no interactions with shardConsumer mock indicating no records were sent back and // subscription has not started correctly. - verify(shardConsumer, never()).handleInput(any(ProcessRecordsInput.class), any(Subscription.class)); + verify(shardConsumer, never()).handleInput(argThat(eqProcessRecordsInput(processRecordsInput)), any(Subscription.class)); Stream.iterate(2, i -> i + 1).limit(98).forEach(this::addUniqueItem); @@ -398,7 +399,7 @@ public class ShardConsumerSubscriberTest { // Verifying that there are no interactions with shardConsumer mock indicating no records were sent back and // subscription has not started correctly. - verify(shardConsumer, never()).handleInput(any(ProcessRecordsInput.class), any(Subscription.class)); + verify(shardConsumer, never()).handleInput(argThat(eqProcessRecordsInput(processRecordsInput)), any(Subscription.class)); Stream.iterate(2, i -> i + 1).limit(98).forEach(this::addUniqueItem); @@ -414,7 +415,7 @@ public class ShardConsumerSubscriberTest { } // Verify that shardConsumer mock was called 100 times and all 100 input records are processed. - verify(shardConsumer, times(100)).handleInput(any(ProcessRecordsInput.class), any(Subscription.class)); + verify(shardConsumer, times(100)).handleInput(argThat(eqProcessRecordsInput(processRecordsInput)), any(Subscription.class)); // Verify that received records in the subscriber are equal to the ones sent by the record publisher. assertThat(received.size(), equalTo(recordsPublisher.responses.size()));