diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisherTest.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisherTest.java index 0e82126b..ac367a99 100644 --- a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisherTest.java +++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisherTest.java @@ -1,7 +1,6 @@ package software.amazon.kinesis.retrieval.fanout; import static org.hamcrest.CoreMatchers.equalTo; -import static org.hamcrest.CoreMatchers.notNullValue; import static org.junit.Assert.assertThat; import static org.junit.Assert.fail; import static org.mockito.Matchers.any; @@ -10,7 +9,6 @@ import static org.mockito.Mockito.never; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; import java.nio.ByteBuffer; import java.time.Instant; @@ -201,20 +199,15 @@ public class FanOutRecordsPublisherTest { ArgumentCaptor flowCaptor = ArgumentCaptor .forClass(FanOutRecordsPublisher.RecordFlow.class); - - when(kinesisClient.subscribeToShard(any(SubscribeToShardRequest.class), flowCaptor.capture())) - .thenThrow(new RuntimeException(ResourceNotFoundException.builder().build())); + ArgumentCaptor inputCaptor = ArgumentCaptor.forClass(ProcessRecordsInput.class); Subscriber subscriber = spy(new Subscriber() { @Override public void onSubscribe(final Subscription subscription) { - assertThat(subscription, notNullValue()); } @Override public void onNext(final ProcessRecordsInput processRecordsInput) { - assertThat(processRecordsInput.isAtShardEnd(), equalTo(true)); - assertThat(processRecordsInput.records().isEmpty(), equalTo(true)); } @Override @@ -229,9 +222,17 @@ public class FanOutRecordsPublisherTest { source.subscribe(subscriber); verify(kinesisClient).subscribeToShard(any(SubscribeToShardRequest.class), flowCaptor.capture()); - flowCaptor.getValue().onEventStream(publisher); - verify(subscriber).onComplete(); + FanOutRecordsPublisher.RecordFlow recordFlow = flowCaptor.getValue(); + recordFlow.exceptionOccurred(new RuntimeException(ResourceNotFoundException.builder().build())); + + verify(subscriber).onSubscribe(any()); verify(subscriber, never()).onError(any()); + verify(subscriber).onNext(inputCaptor.capture()); + verify(subscriber).onComplete(); + + ProcessRecordsInput input = inputCaptor.getValue(); + assertThat(input.isAtShardEnd(), equalTo(true)); + assertThat(input.records().isEmpty(), equalTo(true)); } private Record makeRecord(int sequenceNumber) {