Adding default getrecordexecutor.

This commit is contained in:
Sahil Palvia 2017-08-29 14:41:36 -07:00
parent 1ec0b656c9
commit 5d1d38b5a1
5 changed files with 52 additions and 13 deletions

View file

@ -0,0 +1,19 @@
package com.amazonaws.services.kinesis.clientlibrary.lib.worker;
import com.amazonaws.services.kinesis.model.GetRecordsResult;
import lombok.Data;
import lombok.NonNull;
/**
*
*/
@Data
public class DefaultGetRecordsExecutor implements GetRecordsExecutor {
@NonNull
private final KinesisDataFetcher dataFetcher;
@Override
public GetRecordsResult getRecords(final int maxRecords) {
return dataFetcher.getRecords(maxRecords);
}
}

View file

@ -0,0 +1,10 @@
package com.amazonaws.services.kinesis.clientlibrary.lib.worker;
import com.amazonaws.services.kinesis.model.GetRecordsResult;
/**
*
*/
public interface GetRecordsExecutor {
GetRecordsResult getRecords(int maxRecords);
}

View file

@ -62,6 +62,8 @@ class ProcessTask implements ITask {
private final Shard shard; private final Shard shard;
private final ThrottlingReporter throttlingReporter; private final ThrottlingReporter throttlingReporter;
private final GetRecordsExecutor getRecordsExecutor;
/** /**
* @param shardInfo * @param shardInfo
* contains information about the shard * contains information about the shard
@ -81,7 +83,7 @@ class ProcessTask implements ITask {
long backoffTimeMillis, boolean skipShardSyncAtWorkerInitializationIfLeasesExist) { long backoffTimeMillis, boolean skipShardSyncAtWorkerInitializationIfLeasesExist) {
this(shardInfo, streamConfig, recordProcessor, recordProcessorCheckpointer, dataFetcher, backoffTimeMillis, this(shardInfo, streamConfig, recordProcessor, recordProcessorCheckpointer, dataFetcher, backoffTimeMillis,
skipShardSyncAtWorkerInitializationIfLeasesExist, skipShardSyncAtWorkerInitializationIfLeasesExist,
new ThrottlingReporter(MAX_CONSECUTIVE_THROTTLES, shardInfo.getShardId())); new ThrottlingReporter(MAX_CONSECUTIVE_THROTTLES, shardInfo.getShardId()), new DefaultGetRecordsExecutor(dataFetcher));
} }
/** /**
@ -103,7 +105,7 @@ class ProcessTask implements ITask {
public ProcessTask(ShardInfo shardInfo, StreamConfig streamConfig, IRecordProcessor recordProcessor, public ProcessTask(ShardInfo shardInfo, StreamConfig streamConfig, IRecordProcessor recordProcessor,
RecordProcessorCheckpointer recordProcessorCheckpointer, KinesisDataFetcher dataFetcher, RecordProcessorCheckpointer recordProcessorCheckpointer, KinesisDataFetcher dataFetcher,
long backoffTimeMillis, boolean skipShardSyncAtWorkerInitializationIfLeasesExist, long backoffTimeMillis, boolean skipShardSyncAtWorkerInitializationIfLeasesExist,
ThrottlingReporter throttlingReporter) { ThrottlingReporter throttlingReporter, GetRecordsExecutor getRecordsExecutor) {
super(); super();
this.shardInfo = shardInfo; this.shardInfo = shardInfo;
this.recordProcessor = recordProcessor; this.recordProcessor = recordProcessor;
@ -113,6 +115,7 @@ class ProcessTask implements ITask {
this.backoffTimeMillis = backoffTimeMillis; this.backoffTimeMillis = backoffTimeMillis;
this.throttlingReporter = throttlingReporter; this.throttlingReporter = throttlingReporter;
IKinesisProxy kinesisProxy = this.streamConfig.getStreamProxy(); IKinesisProxy kinesisProxy = this.streamConfig.getStreamProxy();
this.getRecordsExecutor = getRecordsExecutor;
// If skipShardSyncAtWorkerInitializationIfLeasesExist is set, we will not get the shard for // If skipShardSyncAtWorkerInitializationIfLeasesExist is set, we will not get the shard for
// this ProcessTask. In this case, duplicate KPL user records in the event of resharding will // this ProcessTask. In this case, duplicate KPL user records in the event of resharding will
// not be dropped during deaggregation of Amazon Kinesis records. This is only applicable if // not be dropped during deaggregation of Amazon Kinesis records. This is only applicable if
@ -368,7 +371,7 @@ class ProcessTask implements ITask {
* @return list of data records from Kinesis * @return list of data records from Kinesis
*/ */
private GetRecordsResult getRecordsResultAndRecordMillisBehindLatest() { private GetRecordsResult getRecordsResultAndRecordMillisBehindLatest() {
final GetRecordsResult getRecordsResult = dataFetcher.getRecords(streamConfig.getMaxRecords()); final GetRecordsResult getRecordsResult = getRecordsExecutor.getRecords(streamConfig.getMaxRecords());
if (getRecordsResult == null) { if (getRecordsResult == null) {
// Stream no longer exists // Stream no longer exists

View file

@ -117,6 +117,7 @@ public class KinesisDataFetcherTest {
ICheckpoint checkpoint = mock(ICheckpoint.class); ICheckpoint checkpoint = mock(ICheckpoint.class);
KinesisDataFetcher fetcher = new KinesisDataFetcher(kinesis, SHARD_INFO); KinesisDataFetcher fetcher = new KinesisDataFetcher(kinesis, SHARD_INFO);
GetRecordsExecutor getRecordsExecutor = new DefaultGetRecordsExecutor(fetcher);
String iteratorA = "foo"; String iteratorA = "foo";
String iteratorB = "bar"; String iteratorB = "bar";
@ -138,10 +139,10 @@ public class KinesisDataFetcherTest {
fetcher.initialize(seqA, null); fetcher.initialize(seqA, null);
fetcher.advanceIteratorTo(seqA, null); fetcher.advanceIteratorTo(seqA, null);
Assert.assertEquals(recordsA, fetcher.getRecords(MAX_RECORDS).getRecords()); Assert.assertEquals(recordsA, getRecordsExecutor.getRecords(MAX_RECORDS).getRecords());
fetcher.advanceIteratorTo(seqB, null); fetcher.advanceIteratorTo(seqB, null);
Assert.assertEquals(recordsB, fetcher.getRecords(MAX_RECORDS).getRecords()); Assert.assertEquals(recordsB, getRecordsExecutor.getRecords(MAX_RECORDS).getRecords());
} }
@Test @Test
@ -181,8 +182,9 @@ public class KinesisDataFetcherTest {
// Create data fectcher and initialize it with latest type checkpoint // Create data fectcher and initialize it with latest type checkpoint
KinesisDataFetcher dataFetcher = new KinesisDataFetcher(mockProxy, SHARD_INFO); KinesisDataFetcher dataFetcher = new KinesisDataFetcher(mockProxy, SHARD_INFO);
dataFetcher.initialize(SentinelCheckpoint.LATEST.toString(), INITIAL_POSITION_LATEST); dataFetcher.initialize(SentinelCheckpoint.LATEST.toString(), INITIAL_POSITION_LATEST);
GetRecordsExecutor getRecordsExecutor = new DefaultGetRecordsExecutor(dataFetcher);
// Call getRecords of dataFetcher which will throw an exception // Call getRecords of dataFetcher which will throw an exception
dataFetcher.getRecords(maxRecords); getRecordsExecutor.getRecords(maxRecords);
// Test shard has reached the end // Test shard has reached the end
Assert.assertTrue("Shard should reach the end", dataFetcher.isShardEndReached()); Assert.assertTrue("Shard should reach the end", dataFetcher.isShardEndReached());
@ -206,8 +208,9 @@ public class KinesisDataFetcherTest {
when(checkpoint.getCheckpoint(SHARD_ID)).thenReturn(new ExtendedSequenceNumber(seqNo)); when(checkpoint.getCheckpoint(SHARD_ID)).thenReturn(new ExtendedSequenceNumber(seqNo));
KinesisDataFetcher fetcher = new KinesisDataFetcher(kinesis, SHARD_INFO); KinesisDataFetcher fetcher = new KinesisDataFetcher(kinesis, SHARD_INFO);
GetRecordsExecutor getRecordsExecutor = new DefaultGetRecordsExecutor(fetcher);
fetcher.initialize(seqNo, initialPositionInStream); fetcher.initialize(seqNo, initialPositionInStream);
List<Record> actualRecords = fetcher.getRecords(MAX_RECORDS).getRecords(); List<Record> actualRecords = getRecordsExecutor.getRecords(MAX_RECORDS).getRecords();
Assert.assertEquals(expectedRecords, actualRecords); Assert.assertEquals(expectedRecords, actualRecords);
} }

View file

@ -19,7 +19,7 @@ import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame; import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.anyInt; import static org.mockito.Matchers.anyInt;
import static org.mockito.Matchers.anyString; import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
@ -76,6 +76,8 @@ public class ProcessTaskTest {
private @Mock RecordProcessorCheckpointer mockCheckpointer; private @Mock RecordProcessorCheckpointer mockCheckpointer;
@Mock @Mock
private ThrottlingReporter throttlingReporter; private ThrottlingReporter throttlingReporter;
@Mock
private GetRecordsExecutor mockGetRecordsExecutor;
private List<Record> processedRecords; private List<Record> processedRecords;
private ExtendedSequenceNumber newLargestPermittedCheckpointValue; private ExtendedSequenceNumber newLargestPermittedCheckpointValue;
@ -94,19 +96,20 @@ public class ProcessTaskTest {
final ShardInfo shardInfo = new ShardInfo(shardId, null, null, null); final ShardInfo shardInfo = new ShardInfo(shardId, null, null, null);
processTask = new ProcessTask( processTask = new ProcessTask(
shardInfo, config, mockRecordProcessor, mockCheckpointer, mockDataFetcher, taskBackoffTimeMillis, shardInfo, config, mockRecordProcessor, mockCheckpointer, mockDataFetcher, taskBackoffTimeMillis,
KinesisClientLibConfiguration.DEFAULT_SKIP_SHARD_SYNC_AT_STARTUP_IF_LEASES_EXIST, throttlingReporter); KinesisClientLibConfiguration.DEFAULT_SKIP_SHARD_SYNC_AT_STARTUP_IF_LEASES_EXIST, throttlingReporter, mockGetRecordsExecutor);
} }
@Test @Test
public void testProcessTaskWithProvisionedThroughputExceededException() { public void testProcessTaskWithProvisionedThroughputExceededException() {
// Set data fetcher to throw exception // Set data fetcher to throw exception
doReturn(false).when(mockDataFetcher).isShardEndReached(); doReturn(false).when(mockDataFetcher).isShardEndReached();
doThrow(new ProvisionedThroughputExceededException("Test Exception")).when(mockDataFetcher) doThrow(new ProvisionedThroughputExceededException("Test Exception")).when(mockGetRecordsExecutor)
.getRecords(maxRecords); .getRecords(maxRecords);
TaskResult result = processTask.call(); TaskResult result = processTask.call();
verify(throttlingReporter).throttled(); verify(throttlingReporter).throttled();
verify(throttlingReporter, never()).success(); verify(throttlingReporter, never()).success();
verify(mockGetRecordsExecutor).getRecords(eq(maxRecords));
assertTrue("Result should contain ProvisionedThroughputExceededException", assertTrue("Result should contain ProvisionedThroughputExceededException",
result.getException() instanceof ProvisionedThroughputExceededException); result.getException() instanceof ProvisionedThroughputExceededException);
} }
@ -114,9 +117,10 @@ public class ProcessTaskTest {
@Test @Test
public void testProcessTaskWithNonExistentStream() { public void testProcessTaskWithNonExistentStream() {
// Data fetcher returns a null Result when the stream does not exist // Data fetcher returns a null Result when the stream does not exist
doReturn(null).when(mockDataFetcher).getRecords(maxRecords); doReturn(null).when(mockGetRecordsExecutor).getRecords(maxRecords);
TaskResult result = processTask.call(); TaskResult result = processTask.call();
verify(mockGetRecordsExecutor).getRecords(eq(maxRecords));
assertNull("Task should not throw an exception", result.getException()); assertNull("Task should not throw an exception", result.getException());
} }
@ -300,14 +304,14 @@ public class ProcessTaskTest {
private void testWithRecords(List<Record> records, private void testWithRecords(List<Record> records,
ExtendedSequenceNumber lastCheckpointValue, ExtendedSequenceNumber lastCheckpointValue,
ExtendedSequenceNumber largestPermittedCheckpointValue) { ExtendedSequenceNumber largestPermittedCheckpointValue) {
when(mockDataFetcher.getRecords(anyInt())).thenReturn( when(mockGetRecordsExecutor.getRecords(anyInt())).thenReturn(
new GetRecordsResult().withRecords(records)); new GetRecordsResult().withRecords(records));
when(mockCheckpointer.getLastCheckpointValue()).thenReturn(lastCheckpointValue); when(mockCheckpointer.getLastCheckpointValue()).thenReturn(lastCheckpointValue);
when(mockCheckpointer.getLargestPermittedCheckpointValue()).thenReturn(largestPermittedCheckpointValue); when(mockCheckpointer.getLargestPermittedCheckpointValue()).thenReturn(largestPermittedCheckpointValue);
processTask.call(); processTask.call();
verify(throttlingReporter).success(); verify(throttlingReporter).success();
verify(throttlingReporter, never()).throttled(); verify(throttlingReporter, never()).throttled();
verify(mockGetRecordsExecutor).getRecords(anyInt());
ArgumentCaptor<ProcessRecordsInput> priCaptor = ArgumentCaptor.forClass(ProcessRecordsInput.class); ArgumentCaptor<ProcessRecordsInput> priCaptor = ArgumentCaptor.forClass(ProcessRecordsInput.class);
verify(mockRecordProcessor).processRecords(priCaptor.capture()); verify(mockRecordProcessor).processRecords(priCaptor.capture());
processedRecords = priCaptor.getValue().getRecords(); processedRecords = priCaptor.getValue().getRecords();