Merge pull request #95 from pfifer/shard-prioritization

Allow Prioritization of Parent Shard Tasks
This commit is contained in:
Justin Pfifer 2016-08-17 08:17:04 -07:00 committed by GitHub
commit 17aa07c071
17 changed files with 536 additions and 77 deletions

View file

@ -154,6 +154,10 @@ public class KinesisClientLibConfiguration {
*/ */
public static final int DEFAULT_INITIAL_LEASE_TABLE_WRITE_CAPACITY = 10; public static final int DEFAULT_INITIAL_LEASE_TABLE_WRITE_CAPACITY = 10;
/**
* Default Shard prioritization strategy.
*/
public static final ShardPrioritization DEFAULT_SHARD_PRIORITIZATION = new NoOpShardPrioritization();
private String applicationName; private String applicationName;
private String tableName; private String tableName;
@ -187,6 +191,7 @@ public class KinesisClientLibConfiguration {
private int initialLeaseTableReadCapacity; private int initialLeaseTableReadCapacity;
private int initialLeaseTableWriteCapacity; private int initialLeaseTableWriteCapacity;
private InitialPositionInStreamExtended initialPositionInStreamExtended; private InitialPositionInStreamExtended initialPositionInStreamExtended;
private ShardPrioritization shardPrioritization;
/** /**
* Constructor. * Constructor.
@ -333,6 +338,7 @@ public class KinesisClientLibConfiguration {
this.initialLeaseTableWriteCapacity = DEFAULT_INITIAL_LEASE_TABLE_WRITE_CAPACITY; this.initialLeaseTableWriteCapacity = DEFAULT_INITIAL_LEASE_TABLE_WRITE_CAPACITY;
this.initialPositionInStreamExtended = this.initialPositionInStreamExtended =
InitialPositionInStreamExtended.newInitialPosition(initialPositionInStream); InitialPositionInStreamExtended.newInitialPosition(initialPositionInStream);
this.shardPrioritization = DEFAULT_SHARD_PRIORITIZATION;
} }
// Check if value is positive, otherwise throw an exception // Check if value is positive, otherwise throw an exception
@ -599,6 +605,13 @@ public class KinesisClientLibConfiguration {
return initialPositionInStreamExtended.getTimestamp(); return initialPositionInStreamExtended.getTimestamp();
} }
/**
* @return Shard prioritization strategy.
*/
public ShardPrioritization getShardPrioritizationStrategy() {
return shardPrioritization;
}
// CHECKSTYLE:IGNORE HiddenFieldCheck FOR NEXT 190 LINES // CHECKSTYLE:IGNORE HiddenFieldCheck FOR NEXT 190 LINES
/** /**
* @param tableName name of the lease table in DynamoDB * @param tableName name of the lease table in DynamoDB
@ -913,4 +926,16 @@ public class KinesisClientLibConfiguration {
this.initialLeaseTableWriteCapacity = initialLeaseTableWriteCapacity; this.initialLeaseTableWriteCapacity = initialLeaseTableWriteCapacity;
return this; return this;
} }
/**
* @param shardPrioritization Implementation of ShardPrioritization interface that should be used during processing.
* @return KinesisClientLibConfiguration
*/
public KinesisClientLibConfiguration withShardPrioritizationStrategy(ShardPrioritization shardPrioritization) {
if (shardPrioritization == null) {
throw new IllegalArgumentException("shardPrioritization cannot be null");
}
this.shardPrioritization = shardPrioritization;
return this;
}
} }

View file

@ -209,7 +209,8 @@ class KinesisClientLibLeaseCoordinator extends LeaseCoordinator<KinesisClientLea
new ShardInfo( new ShardInfo(
lease.getLeaseKey(), lease.getLeaseKey(),
lease.getConcurrencyToken().toString(), lease.getConcurrencyToken().toString(),
parentShardIds); parentShardIds,
lease.getCheckpoint());
assignments.add(assignment); assignments.add(assignment);
} }
} }

View file

@ -0,0 +1,21 @@
package com.amazonaws.services.kinesis.clientlibrary.lib.worker;
import java.util.List;
/**
* Shard Prioritization that returns the same original list of shards without any modifications.
*/
public class NoOpShardPrioritization implements
ShardPrioritization {
/**
* Empty constructor for NoOp Shard Prioritization.
*/
public NoOpShardPrioritization() {
}
@Override
public List<ShardInfo> prioritize(List<ShardInfo> original) {
return original;
}
}

View file

@ -0,0 +1,135 @@
package com.amazonaws.services.kinesis.clientlibrary.lib.worker;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
* Shard Prioritization that prioritizes parent shards first.
* It also limits number of shards that will be available for initialization based on their depth.
* It doesn't make a lot of sense to work on a shard that has too many unfinished parents.
*/
public class ParentsFirstShardPrioritization implements
ShardPrioritization {
private static final SortingNode PROCESSING_NODE = new SortingNode(null, Integer.MIN_VALUE);
private final int maxDepth;
/**
* Creates ParentFirst prioritization with filtering based on depth of the shard.
* Shards that have depth > maxDepth will be ignored and will not be returned by this prioritization.
*
* @param maxDepth any shard that is deeper than max depth, will be excluded from processing
*/
public ParentsFirstShardPrioritization(int maxDepth) {
/* Depth 0 means that shard is completed or cannot be found,
* it is impossible to process such shards.
*/
if (maxDepth <= 0) {
throw new IllegalArgumentException("Max depth cannot be negative or zero. Provided value: " + maxDepth);
}
this.maxDepth = maxDepth;
}
@Override
public List<ShardInfo> prioritize(List<ShardInfo> original) {
Map<String, ShardInfo> shards = new HashMap<>();
for (ShardInfo shardInfo : original) {
shards.put(shardInfo.getShardId(),
shardInfo);
}
Map<String, SortingNode> processedNodes = new HashMap<>();
for (ShardInfo shardInfo : original) {
populateDepth(shardInfo.getShardId(),
shards,
processedNodes);
}
List<ShardInfo> orderedInfos = new ArrayList<>(original.size());
List<SortingNode> orderedNodes = new ArrayList<>(processedNodes.values());
Collections.sort(orderedNodes);
for (SortingNode sortingTreeNode : orderedNodes) {
// don't process shards with depth > maxDepth
if (sortingTreeNode.getDepth() <= maxDepth) {
orderedInfos.add(sortingTreeNode.shardInfo);
}
}
return orderedInfos;
}
private int populateDepth(String shardId,
Map<String, ShardInfo> shards,
Map<String, SortingNode> processedNodes) {
SortingNode processed = processedNodes.get(shardId);
if (processed != null) {
if (processed == PROCESSING_NODE) {
throw new IllegalArgumentException("Circular dependency detected. Shard Id "
+ shardId + " is processed twice");
}
return processed.getDepth();
}
ShardInfo shardInfo = shards.get(shardId);
if (shardInfo == null) {
// parent doesn't exist in our list, so this shard is root-level node
return 0;
}
if (shardInfo.isCompleted()) {
// we treat completed shards as 0-level
return 0;
}
// storing processing node to make sure we track progress and avoid circular dependencies
processedNodes.put(shardId, PROCESSING_NODE);
int maxParentDepth = 0;
for (String parentId : shardInfo.getParentShardIds()) {
maxParentDepth = Math.max(maxParentDepth,
populateDepth(parentId,
shards,
processedNodes));
}
int currentNodeLevel = maxParentDepth + 1;
SortingNode previousValue = processedNodes.put(shardId,
new SortingNode(shardInfo,
currentNodeLevel));
if (previousValue != PROCESSING_NODE) {
throw new IllegalStateException("Validation failed. Depth for shardId " + shardId + " was populated twice");
}
return currentNodeLevel;
}
/**
* Class to store depth of shards during prioritization.
*/
private static class SortingNode implements
Comparable<SortingNode> {
private final ShardInfo shardInfo;
private final int depth;
public SortingNode(ShardInfo shardInfo,
int depth) {
this.shardInfo = shardInfo;
this.depth = depth;
}
public int getDepth() {
return depth;
}
@Override
public int compareTo(SortingNode o) {
return Integer.compare(depth,
o.depth);
}
}
}

View file

@ -297,36 +297,35 @@ class ShardConsumer {
*/ */
// CHECKSTYLE:OFF CyclomaticComplexity // CHECKSTYLE:OFF CyclomaticComplexity
void updateState(boolean taskCompletedSuccessfully) { void updateState(boolean taskCompletedSuccessfully) {
if (currentState == ShardConsumerState.SHUTDOWN_COMPLETE) {
// Shutdown was completed and there nothing we can do after that
return;
}
if ((currentTask == null) && beginShutdown) {
// Shard didn't start any tasks and can be shutdown fast
currentState = ShardConsumerState.SHUTDOWN_COMPLETE;
return;
}
if (beginShutdown && currentState != ShardConsumerState.SHUTTING_DOWN) {
// Shard received signal to start shutdown.
// Whatever task we were working on should be stopped and shutdown task should be executed
currentState = ShardConsumerState.SHUTTING_DOWN;
return;
}
switch (currentState) { switch (currentState) {
case WAITING_ON_PARENT_SHARDS: case WAITING_ON_PARENT_SHARDS:
if (taskCompletedSuccessfully && TaskType.BLOCK_ON_PARENT_SHARDS.equals(currentTask.getTaskType())) { if (taskCompletedSuccessfully && TaskType.BLOCK_ON_PARENT_SHARDS.equals(currentTask.getTaskType())) {
if (beginShutdown) { currentState = ShardConsumerState.INITIALIZING;
currentState = ShardConsumerState.SHUTTING_DOWN;
} else {
currentState = ShardConsumerState.INITIALIZING;
}
} else if ((currentTask == null) && beginShutdown) {
currentState = ShardConsumerState.SHUTDOWN_COMPLETE;
} }
break; break;
case INITIALIZING: case INITIALIZING:
if (taskCompletedSuccessfully && TaskType.INITIALIZE.equals(currentTask.getTaskType())) { if (taskCompletedSuccessfully && TaskType.INITIALIZE.equals(currentTask.getTaskType())) {
if (beginShutdown) { currentState = ShardConsumerState.PROCESSING;
currentState = ShardConsumerState.SHUTTING_DOWN;
} else {
currentState = ShardConsumerState.PROCESSING;
}
} else if ((currentTask == null) && beginShutdown) {
currentState = ShardConsumerState.SHUTDOWN_COMPLETE;
} }
break; break;
case PROCESSING: case PROCESSING:
if (taskCompletedSuccessfully && TaskType.PROCESS.equals(currentTask.getTaskType())) { if (taskCompletedSuccessfully && TaskType.PROCESS.equals(currentTask.getTaskType())) {
if (beginShutdown) { currentState = ShardConsumerState.PROCESSING;
currentState = ShardConsumerState.SHUTTING_DOWN;
} else {
currentState = ShardConsumerState.PROCESSING;
}
} }
break; break;
case SHUTTING_DOWN: case SHUTTING_DOWN:
@ -335,8 +334,6 @@ class ShardConsumer {
currentState = ShardConsumerState.SHUTDOWN_COMPLETE; currentState = ShardConsumerState.SHUTDOWN_COMPLETE;
} }
break; break;
case SHUTDOWN_COMPLETE:
break;
default: default:
LOG.error("Unexpected state: " + currentState); LOG.error("Unexpected state: " + currentState);
break; break;

View file

@ -19,6 +19,8 @@ import java.util.Collections;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
import com.amazonaws.services.kinesis.clientlibrary.types.ExtendedSequenceNumber;
/** /**
* Used to pass shard related info among different classes and as a key to the map of shard consumers. * Used to pass shard related info among different classes and as a key to the map of shard consumers.
*/ */
@ -28,13 +30,18 @@ class ShardInfo {
private final String concurrencyToken; private final String concurrencyToken;
// Sorted list of parent shardIds. // Sorted list of parent shardIds.
private final List<String> parentShardIds; private final List<String> parentShardIds;
private final ExtendedSequenceNumber checkpoint;
/** /**
* @param shardId Kinesis shardId * @param shardId Kinesis shardId
* @param concurrencyToken Used to differentiate between lost and reclaimed leases * @param concurrencyToken Used to differentiate between lost and reclaimed leases
* @param parentShardIds Parent shards of the shard identified by Kinesis shardId * @param parentShardIds Parent shards of the shard identified by Kinesis shardId
* @param checkpoint the latest checkpoint from lease
*/ */
public ShardInfo(String shardId, String concurrencyToken, Collection<String> parentShardIds) { public ShardInfo(String shardId,
String concurrencyToken,
Collection<String> parentShardIds,
ExtendedSequenceNumber checkpoint) {
this.shardId = shardId; this.shardId = shardId;
this.concurrencyToken = concurrencyToken; this.concurrencyToken = concurrencyToken;
this.parentShardIds = new LinkedList<String>(); this.parentShardIds = new LinkedList<String>();
@ -44,6 +51,7 @@ class ShardInfo {
// ShardInfo stores parent shard Ids in canonical order in the parentShardIds list. // ShardInfo stores parent shard Ids in canonical order in the parentShardIds list.
// This makes it easy to check for equality in ShardInfo.equals method. // This makes it easy to check for equality in ShardInfo.equals method.
Collections.sort(this.parentShardIds); Collections.sort(this.parentShardIds);
this.checkpoint = checkpoint;
} }
/** /**
@ -67,6 +75,13 @@ class ShardInfo {
return new LinkedList<String>(parentShardIds); return new LinkedList<String>(parentShardIds);
} }
/**
* @return completion status of the shard
*/
protected boolean isCompleted() {
return ExtendedSequenceNumber.SHARD_END.equals(checkpoint);
}
/** /**
* {@inheritDoc} * {@inheritDoc}
*/ */
@ -77,6 +92,7 @@ class ShardInfo {
result = prime * result + ((concurrencyToken == null) ? 0 : concurrencyToken.hashCode()); result = prime * result + ((concurrencyToken == null) ? 0 : concurrencyToken.hashCode());
result = prime * result + ((parentShardIds == null) ? 0 : parentShardIds.hashCode()); result = prime * result + ((parentShardIds == null) ? 0 : parentShardIds.hashCode());
result = prime * result + ((shardId == null) ? 0 : shardId.hashCode()); result = prime * result + ((shardId == null) ? 0 : shardId.hashCode());
result = prime * result + ((checkpoint == null) ? 0 : checkpoint.hashCode());
return result; return result;
} }
@ -126,6 +142,13 @@ class ShardInfo {
} else if (!shardId.equals(other.shardId)) { } else if (!shardId.equals(other.shardId)) {
return false; return false;
} }
if (checkpoint == null) {
if (other.checkpoint != null) {
return false;
}
} else if (!checkpoint.equals(other.checkpoint)) {
return false;
}
return true; return true;
} }
@ -135,7 +158,44 @@ class ShardInfo {
@Override @Override
public String toString() { public String toString() {
return "ShardInfo [shardId=" + shardId + ", concurrencyToken=" + concurrencyToken + ", parentShardIds=" return "ShardInfo [shardId=" + shardId + ", concurrencyToken=" + concurrencyToken + ", parentShardIds="
+ parentShardIds + "]"; + parentShardIds + ", checkpoint=" + checkpoint + "]";
}
/**
* Builder class for ShardInfo.
*/
public static class Builder {
private String shardId;
private String concurrencyToken;
private List<String> parentShardIds = Collections.emptyList();
private ExtendedSequenceNumber checkpoint = ExtendedSequenceNumber.LATEST;
public Builder() {
}
public Builder withShardId(String shardId) {
this.shardId = shardId;
return this;
}
public Builder withConcurrencyToken(String concurrencyToken) {
this.concurrencyToken = concurrencyToken;
return this;
}
public Builder withParentShards(List<String> parentShardIds) {
this.parentShardIds = parentShardIds;
return this;
}
public Builder withCheckpoint(ExtendedSequenceNumber checkpoint) {
this.checkpoint = checkpoint;
return this;
}
public ShardInfo build() {
return new ShardInfo(shardId, concurrencyToken, parentShardIds, checkpoint);
}
} }
} }

View file

@ -0,0 +1,19 @@
package com.amazonaws.services.kinesis.clientlibrary.lib.worker;
import java.util.List;
/**
* Provides logic to prioritize or filter shards before their execution.
*/
public interface ShardPrioritization {
/**
* Returns new list of shards ordered based on their priority.
* Resulted list may have fewer shards compared to original list
*
* @param original
* list of shards needed to be prioritized
* @return new list that contains only shards that should be processed
*/
List<ShardInfo> prioritize(List<ShardInfo> original);
}

View file

@ -38,6 +38,7 @@ import com.amazonaws.services.kinesis.AmazonKinesisClient;
import com.amazonaws.services.kinesis.clientlibrary.interfaces.ICheckpoint; import com.amazonaws.services.kinesis.clientlibrary.interfaces.ICheckpoint;
import com.amazonaws.services.kinesis.clientlibrary.interfaces.v2.IRecordProcessor; import com.amazonaws.services.kinesis.clientlibrary.interfaces.v2.IRecordProcessor;
import com.amazonaws.services.kinesis.clientlibrary.interfaces.v2.IRecordProcessorFactory; import com.amazonaws.services.kinesis.clientlibrary.interfaces.v2.IRecordProcessorFactory;
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.Worker.Builder;
import com.amazonaws.services.kinesis.clientlibrary.proxies.KinesisProxyFactory; import com.amazonaws.services.kinesis.clientlibrary.proxies.KinesisProxyFactory;
import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason; import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason;
import com.amazonaws.services.kinesis.leases.exceptions.LeasingException; import com.amazonaws.services.kinesis.leases.exceptions.LeasingException;
@ -80,6 +81,8 @@ public class Worker implements Runnable {
private final KinesisClientLibLeaseCoordinator leaseCoordinator; private final KinesisClientLibLeaseCoordinator leaseCoordinator;
private final ShardSyncTaskManager controlServer; private final ShardSyncTaskManager controlServer;
private final ShardPrioritization shardPrioritization;
private volatile boolean shutdown; private volatile boolean shutdown;
private volatile long shutdownStartTimeMillis; private volatile long shutdownStartTimeMillis;
@ -231,7 +234,8 @@ public class Worker implements Runnable {
execService, execService,
metricsFactory, metricsFactory,
config.getTaskBackoffTimeMillis(), config.getTaskBackoffTimeMillis(),
config.getFailoverTimeMillis()); config.getFailoverTimeMillis(),
config.getShardPrioritizationStrategy());
// If a region name was explicitly specified, use it as the region for Amazon Kinesis and Amazon DynamoDB. // If a region name was explicitly specified, use it as the region for Amazon Kinesis and Amazon DynamoDB.
if (config.getRegionName() != null) { if (config.getRegionName() != null) {
Region region = RegionUtils.getRegion(config.getRegionName()); Region region = RegionUtils.getRegion(config.getRegionName());
@ -271,6 +275,7 @@ public class Worker implements Runnable {
* consumption) * consumption)
* @param metricsFactory Metrics factory used to emit metrics * @param metricsFactory Metrics factory used to emit metrics
* @param taskBackoffTimeMillis Backoff period when tasks encounter an exception * @param taskBackoffTimeMillis Backoff period when tasks encounter an exception
* @param shardPrioritization Provides prioritization logic to decide which available shards process first
*/ */
// NOTE: This has package level access solely for testing // NOTE: This has package level access solely for testing
// CHECKSTYLE:IGNORE ParameterNumber FOR NEXT 10 LINES // CHECKSTYLE:IGNORE ParameterNumber FOR NEXT 10 LINES
@ -286,7 +291,8 @@ public class Worker implements Runnable {
ExecutorService execService, ExecutorService execService,
IMetricsFactory metricsFactory, IMetricsFactory metricsFactory,
long taskBackoffTimeMillis, long taskBackoffTimeMillis,
long failoverTimeMillis) { long failoverTimeMillis,
ShardPrioritization shardPrioritization) {
this.applicationName = applicationName; this.applicationName = applicationName;
this.recordProcessorFactory = recordProcessorFactory; this.recordProcessorFactory = recordProcessorFactory;
this.streamConfig = streamConfig; this.streamConfig = streamConfig;
@ -308,6 +314,7 @@ public class Worker implements Runnable {
executorService); executorService);
this.taskBackoffTimeMillis = taskBackoffTimeMillis; this.taskBackoffTimeMillis = taskBackoffTimeMillis;
this.failoverTimeMillis = failoverTimeMillis; this.failoverTimeMillis = failoverTimeMillis;
this.shardPrioritization = shardPrioritization;
} }
/** /**
@ -449,12 +456,13 @@ public class Worker implements Runnable {
private List<ShardInfo> getShardInfoForAssignments() { private List<ShardInfo> getShardInfoForAssignments() {
List<ShardInfo> assignedStreamShards = leaseCoordinator.getCurrentAssignments(); List<ShardInfo> assignedStreamShards = leaseCoordinator.getCurrentAssignments();
List<ShardInfo> prioritizedShards = shardPrioritization.prioritize(assignedStreamShards);
if ((assignedStreamShards != null) && (!assignedStreamShards.isEmpty())) { if ((prioritizedShards != null) && (!prioritizedShards.isEmpty())) {
if (wlog.isInfoEnabled()) { if (wlog.isInfoEnabled()) {
StringBuilder builder = new StringBuilder(); StringBuilder builder = new StringBuilder();
boolean firstItem = true; boolean firstItem = true;
for (ShardInfo shardInfo : assignedStreamShards) { for (ShardInfo shardInfo : prioritizedShards) {
if (!firstItem) { if (!firstItem) {
builder.append(", "); builder.append(", ");
} }
@ -467,7 +475,7 @@ public class Worker implements Runnable {
wlog.info("No activities assigned"); wlog.info("No activities assigned");
} }
return assignedStreamShards; return prioritizedShards;
} }
/** /**
@ -780,6 +788,7 @@ public class Worker implements Runnable {
private AmazonCloudWatch cloudWatchClient; private AmazonCloudWatch cloudWatchClient;
private IMetricsFactory metricsFactory; private IMetricsFactory metricsFactory;
private ExecutorService execService; private ExecutorService execService;
private ShardPrioritization shardPrioritization;
/** /**
* Default constructor. * Default constructor.
@ -879,6 +888,19 @@ public class Worker implements Runnable {
return this; return this;
} }
/**
* Provides logic how to prioritize shard processing.
*
* @param shardPrioritization
* shardPrioritization is responsible to order shards before processing
*
* @return A reference to this updated object so that method calls can be chained together.
*/
public Builder shardPrioritization(ShardPrioritization shardPrioritization) {
this.shardPrioritization = shardPrioritization;
return this;
}
/** /**
* Build the Worker instance. * Build the Worker instance.
* *
@ -937,6 +959,9 @@ public class Worker implements Runnable {
if (metricsFactory == null) { if (metricsFactory == null) {
metricsFactory = getMetricsFactory(cloudWatchClient, config); metricsFactory = getMetricsFactory(cloudWatchClient, config);
} }
if (shardPrioritization == null) {
shardPrioritization = new ParentsFirstShardPrioritization(1);
}
return new Worker(config.getApplicationName(), return new Worker(config.getApplicationName(),
recordProcessorFactory, recordProcessorFactory,
@ -965,7 +990,8 @@ public class Worker implements Runnable {
execService, execService,
metricsFactory, metricsFactory,
config.getTaskBackoffTimeMillis(), config.getTaskBackoffTimeMillis(),
config.getFailoverTimeMillis()); config.getFailoverTimeMillis(),
shardPrioritization);
} }
} }

View file

@ -20,12 +20,11 @@ import static org.mockito.Mockito.when;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import junit.framework.Assert;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.junit.After; import org.junit.After;
import org.junit.AfterClass; import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.Before; import org.junit.Before;
import org.junit.BeforeClass; import org.junit.BeforeClass;
import org.junit.Test; import org.junit.Test;
@ -47,7 +46,7 @@ public class BlockOnParentShardTaskTest {
private final String shardId = "shardId-97"; private final String shardId = "shardId-97";
private final String concurrencyToken = "testToken"; private final String concurrencyToken = "testToken";
private final List<String> emptyParentShardIds = new ArrayList<String>(); private final List<String> emptyParentShardIds = new ArrayList<String>();
ShardInfo defaultShardInfo = new ShardInfo(shardId, concurrencyToken, emptyParentShardIds); ShardInfo defaultShardInfo = new ShardInfo(shardId, concurrencyToken, emptyParentShardIds, ExtendedSequenceNumber.TRIM_HORIZON);
/** /**
* @throws java.lang.Exception * @throws java.lang.Exception
@ -122,14 +121,14 @@ public class BlockOnParentShardTaskTest {
// test single parent // test single parent
parentShardIds.add(parent1ShardId); parentShardIds.add(parent1ShardId);
shardInfo = new ShardInfo(shardId, concurrencyToken, parentShardIds); shardInfo = new ShardInfo(shardId, concurrencyToken, parentShardIds, ExtendedSequenceNumber.TRIM_HORIZON);
task = new BlockOnParentShardTask(shardInfo, leaseManager, backoffTimeInMillis); task = new BlockOnParentShardTask(shardInfo, leaseManager, backoffTimeInMillis);
result = task.call(); result = task.call();
Assert.assertNull(result.getException()); Assert.assertNull(result.getException());
// test two parents // test two parents
parentShardIds.add(parent2ShardId); parentShardIds.add(parent2ShardId);
shardInfo = new ShardInfo(shardId, concurrencyToken, parentShardIds); shardInfo = new ShardInfo(shardId, concurrencyToken, parentShardIds, ExtendedSequenceNumber.TRIM_HORIZON);
task = new BlockOnParentShardTask(shardInfo, leaseManager, backoffTimeInMillis); task = new BlockOnParentShardTask(shardInfo, leaseManager, backoffTimeInMillis);
result = task.call(); result = task.call();
Assert.assertNull(result.getException()); Assert.assertNull(result.getException());
@ -164,14 +163,14 @@ public class BlockOnParentShardTaskTest {
// test single parent // test single parent
parentShardIds.add(parent1ShardId); parentShardIds.add(parent1ShardId);
shardInfo = new ShardInfo(shardId, concurrencyToken, parentShardIds); shardInfo = new ShardInfo(shardId, concurrencyToken, parentShardIds, ExtendedSequenceNumber.TRIM_HORIZON);
task = new BlockOnParentShardTask(shardInfo, leaseManager, backoffTimeInMillis); task = new BlockOnParentShardTask(shardInfo, leaseManager, backoffTimeInMillis);
result = task.call(); result = task.call();
Assert.assertNotNull(result.getException()); Assert.assertNotNull(result.getException());
// test two parents // test two parents
parentShardIds.add(parent2ShardId); parentShardIds.add(parent2ShardId);
shardInfo = new ShardInfo(shardId, concurrencyToken, parentShardIds); shardInfo = new ShardInfo(shardId, concurrencyToken, parentShardIds, ExtendedSequenceNumber.TRIM_HORIZON);
task = new BlockOnParentShardTask(shardInfo, leaseManager, backoffTimeInMillis); task = new BlockOnParentShardTask(shardInfo, leaseManager, backoffTimeInMillis);
result = task.call(); result = task.call();
Assert.assertNotNull(result.getException()); Assert.assertNotNull(result.getException());
@ -191,7 +190,7 @@ public class BlockOnParentShardTaskTest {
String parentShardId = "shardId-1"; String parentShardId = "shardId-1";
List<String> parentShardIds = new ArrayList<>(); List<String> parentShardIds = new ArrayList<>();
parentShardIds.add(parentShardId); parentShardIds.add(parentShardId);
ShardInfo shardInfo = new ShardInfo(shardId, concurrencyToken, parentShardIds); ShardInfo shardInfo = new ShardInfo(shardId, concurrencyToken, parentShardIds, ExtendedSequenceNumber.TRIM_HORIZON);
TaskResult result = null; TaskResult result = null;
KinesisClientLease parentLease = new KinesisClientLease(); KinesisClientLease parentLease = new KinesisClientLease();
ILeaseManager<KinesisClientLease> leaseManager = mock(ILeaseManager.class); ILeaseManager<KinesisClientLease> leaseManager = mock(ILeaseManager.class);

View file

@ -48,7 +48,7 @@ public class KinesisDataFetcherTest {
private static final int MAX_RECORDS = 1; private static final int MAX_RECORDS = 1;
private static final String SHARD_ID = "shardId-1"; private static final String SHARD_ID = "shardId-1";
private static final String AT_SEQUENCE_NUMBER = ShardIteratorType.AT_SEQUENCE_NUMBER.toString(); private static final String AT_SEQUENCE_NUMBER = ShardIteratorType.AT_SEQUENCE_NUMBER.toString();
private static final ShardInfo SHARD_INFO = new ShardInfo(SHARD_ID, null, null); private static final ShardInfo SHARD_INFO = new ShardInfo(SHARD_ID, null, null, null);
private static final InitialPositionInStreamExtended INITIAL_POSITION_LATEST = private static final InitialPositionInStreamExtended INITIAL_POSITION_LATEST =
InitialPositionInStreamExtended.newInitialPosition(InitialPositionInStream.LATEST); InitialPositionInStreamExtended.newInitialPosition(InitialPositionInStream.LATEST);
private static final InitialPositionInStreamExtended INITIAL_POSITION_TRIM_HORIZON = private static final InitialPositionInStreamExtended INITIAL_POSITION_TRIM_HORIZON =

View file

@ -0,0 +1,162 @@
package com.amazonaws.services.kinesis.clientlibrary.lib.worker;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import org.junit.Test;
public class ParentsFirstShardPrioritizationUnitTest {
@Test(expected = IllegalArgumentException.class)
public void testMaxDepthNegativeShouldFail() {
new ParentsFirstShardPrioritization(-1);
}
@Test(expected = IllegalArgumentException.class)
public void testMaxDepthZeroShouldFail() {
new ParentsFirstShardPrioritization(0);
}
@Test
public void testMaxDepthPositiveShouldNotFail() {
new ParentsFirstShardPrioritization(1);
}
@Test
public void testSorting() {
Random random = new Random(987654);
int numberOfShards = 7;
List<String> shardIdsDependencies = new ArrayList<>();
shardIdsDependencies.add("unknown");
List<ShardInfo> original = new ArrayList<>();
for (int shardNumber = 0; shardNumber < numberOfShards; shardNumber++) {
String shardId = shardId(shardNumber);
original.add(shardInfo(shardId, shardIdsDependencies));
shardIdsDependencies.add(shardId);
}
ParentsFirstShardPrioritization ordering = new ParentsFirstShardPrioritization(Integer.MAX_VALUE);
// shuffle original list as it is already ordered in right way
Collections.shuffle(original, random);
List<ShardInfo> ordered = ordering.prioritize(original);
assertEquals(numberOfShards, ordered.size());
for (int shardNumber = 0; shardNumber < numberOfShards; shardNumber++) {
String shardId = shardId(shardNumber);
assertEquals(shardId, ordered.get(shardNumber).getShardId());
}
}
@Test
public void testSortingAndFiltering() {
Random random = new Random(45677);
int numberOfShards = 10;
List<String> shardIdsDependencies = new ArrayList<>();
shardIdsDependencies.add("unknown");
List<ShardInfo> original = new ArrayList<>();
for (int shardNumber = 0; shardNumber < numberOfShards; shardNumber++) {
String shardId = shardId(shardNumber);
original.add(shardInfo(shardId, shardIdsDependencies));
shardIdsDependencies.add(shardId);
}
int maxDepth = 3;
ParentsFirstShardPrioritization ordering = new ParentsFirstShardPrioritization(maxDepth);
// shuffle original list as it is already ordered in right way
Collections.shuffle(original, random);
List<ShardInfo> ordered = ordering.prioritize(original);
// in this case every shard has its own level, so we don't expect to
// have more shards than max depth
assertEquals(maxDepth, ordered.size());
for (int shardNumber = 0; shardNumber < maxDepth; shardNumber++) {
String shardId = shardId(shardNumber);
assertEquals(shardId, ordered.get(shardNumber).getShardId());
}
}
@Test
public void testSimpleOrdering() {
Random random = new Random(1234);
int numberOfShards = 10;
String parentId = "unknown";
List<ShardInfo> original = new ArrayList<>();
for (int shardNumber = 0; shardNumber < numberOfShards; shardNumber++) {
String shardId = shardId(shardNumber);
original.add(shardInfo(shardId, parentId));
parentId = shardId;
}
ParentsFirstShardPrioritization ordering = new ParentsFirstShardPrioritization(Integer.MAX_VALUE);
// shuffle original list as it is already ordered in right way
Collections.shuffle(original, random);
List<ShardInfo> ordered = ordering.prioritize(original);
assertEquals(numberOfShards, ordered.size());
for (int shardNumber = 0; shardNumber < numberOfShards; shardNumber++) {
String shardId = shardId(shardNumber);
assertEquals(shardId, ordered.get(shardNumber).getShardId());
}
}
/**
* This should be impossible as shards don't have circular dependencies,
* but this code should handle it properly and fail
*/
@Test
public void testCircularDependencyBetweenShards() {
Random random = new Random(13468798);
int numberOfShards = 10;
// shard-0 will point in middle shard (shard-5) in current test
String parentId = shardId(numberOfShards / 2);
List<ShardInfo> original = new ArrayList<>();
for (int shardNumber = 0; shardNumber < numberOfShards; shardNumber++) {
String shardId = shardId(shardNumber);
original.add(shardInfo(shardId, parentId));
parentId = shardId;
}
ParentsFirstShardPrioritization ordering = new ParentsFirstShardPrioritization(Integer.MAX_VALUE);
// shuffle original list as it is already ordered in right way
Collections.shuffle(original, random);
try {
ordering.prioritize(original);
fail("Processing should fail in case we have circular dependency");
} catch (IllegalArgumentException expected) {
}
}
private String shardId(int shardNumber) {
return "shardId-" + shardNumber;
}
private static ShardInfo shardInfo(String shardId, List<String> parentShardIds) {
// copy into new list just in case ShardInfo will stop doing it
List<String> newParentShardIds = new ArrayList<>(parentShardIds);
return new ShardInfo.Builder()
.withShardId(shardId)
.withParentShards(newParentShardIds)
.build();
}
private static ShardInfo shardInfo(String shardId, String... parentShardIds) {
return new ShardInfo.Builder()
.withShardId(shardId)
.withParentShards(Arrays.asList(parentShardIds))
.build();
}
}

View file

@ -87,7 +87,7 @@ public class ProcessTaskTest {
new StreamConfig(null, maxRecords, idleTimeMillis, callProcessRecordsForEmptyRecordList, new StreamConfig(null, maxRecords, idleTimeMillis, callProcessRecordsForEmptyRecordList,
skipCheckpointValidationValue, skipCheckpointValidationValue,
INITIAL_POSITION_LATEST); INITIAL_POSITION_LATEST);
final ShardInfo shardInfo = new ShardInfo(shardId, null, null); final ShardInfo shardInfo = new ShardInfo(shardId, null, null, null);
processTask = new ProcessTask( processTask = new ProcessTask(
shardInfo, config, mockRecordProcessor, mockCheckpointer, mockDataFetcher, taskBackoffTimeMillis); shardInfo, config, mockRecordProcessor, mockCheckpointer, mockDataFetcher, taskBackoffTimeMillis);
} }

View file

@ -75,7 +75,7 @@ public class RecordProcessorCheckpointerTest {
*/ */
@Test @Test
public final void testCheckpoint() throws Exception { public final void testCheckpoint() throws Exception {
ShardInfo shardInfo = new ShardInfo(shardId, testConcurrencyToken, null); ShardInfo shardInfo = new ShardInfo(shardId, testConcurrencyToken, null, ExtendedSequenceNumber.TRIM_HORIZON);
// First call to checkpoint // First call to checkpoint
RecordProcessorCheckpointer processingCheckpointer = RecordProcessorCheckpointer processingCheckpointer =
@ -98,7 +98,7 @@ public class RecordProcessorCheckpointerTest {
*/ */
@Test @Test
public final void testCheckpointRecord() throws Exception { public final void testCheckpointRecord() throws Exception {
ShardInfo shardInfo = new ShardInfo(shardId, testConcurrencyToken, null); ShardInfo shardInfo = new ShardInfo(shardId, testConcurrencyToken, null, ExtendedSequenceNumber.TRIM_HORIZON);
SequenceNumberValidator sequenceNumberValidator = SequenceNumberValidator sequenceNumberValidator =
new SequenceNumberValidator(null, shardId, false); new SequenceNumberValidator(null, shardId, false);
RecordProcessorCheckpointer processingCheckpointer = RecordProcessorCheckpointer processingCheckpointer =
@ -117,7 +117,7 @@ public class RecordProcessorCheckpointerTest {
*/ */
@Test @Test
public final void testCheckpointSubRecord() throws Exception { public final void testCheckpointSubRecord() throws Exception {
ShardInfo shardInfo = new ShardInfo(shardId, testConcurrencyToken, null); ShardInfo shardInfo = new ShardInfo(shardId, testConcurrencyToken, null, ExtendedSequenceNumber.TRIM_HORIZON);
SequenceNumberValidator sequenceNumberValidator = SequenceNumberValidator sequenceNumberValidator =
new SequenceNumberValidator(null, shardId, false); new SequenceNumberValidator(null, shardId, false);
RecordProcessorCheckpointer processingCheckpointer = RecordProcessorCheckpointer processingCheckpointer =
@ -137,7 +137,7 @@ public class RecordProcessorCheckpointerTest {
*/ */
@Test @Test
public final void testCheckpointSequenceNumber() throws Exception { public final void testCheckpointSequenceNumber() throws Exception {
ShardInfo shardInfo = new ShardInfo(shardId, testConcurrencyToken, null); ShardInfo shardInfo = new ShardInfo(shardId, testConcurrencyToken, null, ExtendedSequenceNumber.TRIM_HORIZON);
SequenceNumberValidator sequenceNumberValidator = SequenceNumberValidator sequenceNumberValidator =
new SequenceNumberValidator(null, shardId, false); new SequenceNumberValidator(null, shardId, false);
RecordProcessorCheckpointer processingCheckpointer = RecordProcessorCheckpointer processingCheckpointer =
@ -155,7 +155,7 @@ public class RecordProcessorCheckpointerTest {
*/ */
@Test @Test
public final void testCheckpointExtendedSequenceNumber() throws Exception { public final void testCheckpointExtendedSequenceNumber() throws Exception {
ShardInfo shardInfo = new ShardInfo(shardId, testConcurrencyToken, null); ShardInfo shardInfo = new ShardInfo(shardId, testConcurrencyToken, null, ExtendedSequenceNumber.TRIM_HORIZON);
SequenceNumberValidator sequenceNumberValidator = SequenceNumberValidator sequenceNumberValidator =
new SequenceNumberValidator(null, shardId, false); new SequenceNumberValidator(null, shardId, false);
RecordProcessorCheckpointer processingCheckpointer = RecordProcessorCheckpointer processingCheckpointer =
@ -173,7 +173,7 @@ public class RecordProcessorCheckpointerTest {
*/ */
@Test @Test
public final void testUpdate() throws Exception { public final void testUpdate() throws Exception {
ShardInfo shardInfo = new ShardInfo(shardId, testConcurrencyToken, null); ShardInfo shardInfo = new ShardInfo(shardId, testConcurrencyToken, null, ExtendedSequenceNumber.TRIM_HORIZON);
RecordProcessorCheckpointer checkpointer = new RecordProcessorCheckpointer(shardInfo, checkpoint, null); RecordProcessorCheckpointer checkpointer = new RecordProcessorCheckpointer(shardInfo, checkpoint, null);
@ -193,7 +193,7 @@ public class RecordProcessorCheckpointerTest {
*/ */
@Test @Test
public final void testClientSpecifiedCheckpoint() throws Exception { public final void testClientSpecifiedCheckpoint() throws Exception {
ShardInfo shardInfo = new ShardInfo(shardId, testConcurrencyToken, null); ShardInfo shardInfo = new ShardInfo(shardId, testConcurrencyToken, null, ExtendedSequenceNumber.TRIM_HORIZON);
SequenceNumberValidator validator = mock(SequenceNumberValidator.class); SequenceNumberValidator validator = mock(SequenceNumberValidator.class);
Mockito.doNothing().when(validator).validateSequenceNumber(anyString()); Mockito.doNothing().when(validator).validateSequenceNumber(anyString());
@ -290,7 +290,7 @@ public class RecordProcessorCheckpointerTest {
@SuppressWarnings("serial") @SuppressWarnings("serial")
@Test @Test
public final void testMixedCheckpointCalls() throws Exception { public final void testMixedCheckpointCalls() throws Exception {
ShardInfo shardInfo = new ShardInfo(shardId, testConcurrencyToken, null); ShardInfo shardInfo = new ShardInfo(shardId, testConcurrencyToken, null, ExtendedSequenceNumber.TRIM_HORIZON);
SequenceNumberValidator validator = mock(SequenceNumberValidator.class); SequenceNumberValidator validator = mock(SequenceNumberValidator.class);
Mockito.doNothing().when(validator).validateSequenceNumber(anyString()); Mockito.doNothing().when(validator).validateSequenceNumber(anyString());

View file

@ -92,7 +92,7 @@ public class ShardConsumerTest {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@Test @Test
public final void testInitializationStateUponFailure() throws Exception { public final void testInitializationStateUponFailure() throws Exception {
ShardInfo shardInfo = new ShardInfo("s-0-0", "testToken", null); ShardInfo shardInfo = new ShardInfo("s-0-0", "testToken", null, ExtendedSequenceNumber.TRIM_HORIZON);
ICheckpoint checkpoint = mock(ICheckpoint.class); ICheckpoint checkpoint = mock(ICheckpoint.class);
when(checkpoint.getCheckpoint(anyString())).thenThrow(NullPointerException.class); when(checkpoint.getCheckpoint(anyString())).thenThrow(NullPointerException.class);
@ -141,7 +141,7 @@ public class ShardConsumerTest {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@Test @Test
public final void testInitializationStateUponSubmissionFailure() throws Exception { public final void testInitializationStateUponSubmissionFailure() throws Exception {
ShardInfo shardInfo = new ShardInfo("s-0-0", "testToken", null); ShardInfo shardInfo = new ShardInfo("s-0-0", "testToken", null, ExtendedSequenceNumber.TRIM_HORIZON);
ICheckpoint checkpoint = mock(ICheckpoint.class); ICheckpoint checkpoint = mock(ICheckpoint.class);
ExecutorService spyExecutorService = spy(executorService); ExecutorService spyExecutorService = spy(executorService);
@ -189,7 +189,7 @@ public class ShardConsumerTest {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@Test @Test
public final void testRecordProcessorThrowable() throws Exception { public final void testRecordProcessorThrowable() throws Exception {
ShardInfo shardInfo = new ShardInfo("s-0-0", "testToken", null); ShardInfo shardInfo = new ShardInfo("s-0-0", "testToken", null, ExtendedSequenceNumber.TRIM_HORIZON);
ICheckpoint checkpoint = mock(ICheckpoint.class); ICheckpoint checkpoint = mock(ICheckpoint.class);
IRecordProcessor processor = mock(IRecordProcessor.class); IRecordProcessor processor = mock(IRecordProcessor.class);
IKinesisProxy streamProxy = mock(IKinesisProxy.class); IKinesisProxy streamProxy = mock(IKinesisProxy.class);
@ -289,7 +289,7 @@ public class ShardConsumerTest {
callProcessRecordsForEmptyRecordList, callProcessRecordsForEmptyRecordList,
skipCheckpointValidationValue, INITIAL_POSITION_LATEST); skipCheckpointValidationValue, INITIAL_POSITION_LATEST);
ShardInfo shardInfo = new ShardInfo(streamShardId, testConcurrencyToken, null); ShardInfo shardInfo = new ShardInfo(streamShardId, testConcurrencyToken, null, null);
ShardConsumer consumer = ShardConsumer consumer =
new ShardConsumer(shardInfo, new ShardConsumer(shardInfo,
streamConfig, streamConfig,
@ -379,7 +379,7 @@ public class ShardConsumerTest {
skipCheckpointValidationValue, skipCheckpointValidationValue,
atTimestamp); atTimestamp);
ShardInfo shardInfo = new ShardInfo(streamShardId, testConcurrencyToken, null); ShardInfo shardInfo = new ShardInfo(streamShardId, testConcurrencyToken, null, ExtendedSequenceNumber.TRIM_HORIZON);
ShardConsumer consumer = ShardConsumer consumer =
new ShardConsumer(shardInfo, new ShardConsumer(shardInfo,
streamConfig, streamConfig,

View file

@ -20,11 +20,12 @@ import java.util.List;
import java.util.Set; import java.util.Set;
import java.util.UUID; import java.util.UUID;
import junit.framework.Assert; import org.junit.Assert;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import com.amazonaws.services.kinesis.clientlibrary.types.ExtendedSequenceNumber;
public class ShardInfoTest { public class ShardInfoTest {
private static final String CONCURRENCY_TOKEN = UUID.randomUUID().toString(); private static final String CONCURRENCY_TOKEN = UUID.randomUUID().toString();
private static final String SHARD_ID = "shardId-test"; private static final String SHARD_ID = "shardId-test";
@ -37,12 +38,12 @@ public class ShardInfoTest {
parentShardIds.add("shard-1"); parentShardIds.add("shard-1");
parentShardIds.add("shard-2"); parentShardIds.add("shard-2");
testShardInfo = new ShardInfo(SHARD_ID, CONCURRENCY_TOKEN, parentShardIds); testShardInfo = new ShardInfo(SHARD_ID, CONCURRENCY_TOKEN, parentShardIds, ExtendedSequenceNumber.LATEST);
} }
@Test @Test
public void testPacboyShardInfoEqualsWithSameArgs() { public void testPacboyShardInfoEqualsWithSameArgs() {
ShardInfo equalShardInfo = new ShardInfo(SHARD_ID, CONCURRENCY_TOKEN, parentShardIds); ShardInfo equalShardInfo = new ShardInfo(SHARD_ID, CONCURRENCY_TOKEN, parentShardIds, ExtendedSequenceNumber.LATEST);
Assert.assertTrue("Equal should return true for arguments all the same", testShardInfo.equals(equalShardInfo)); Assert.assertTrue("Equal should return true for arguments all the same", testShardInfo.equals(equalShardInfo));
} }
@ -53,18 +54,18 @@ public class ShardInfoTest {
@Test @Test
public void testPacboyShardInfoEqualsForShardId() { public void testPacboyShardInfoEqualsForShardId() {
ShardInfo diffShardInfo = new ShardInfo("shardId-diff", CONCURRENCY_TOKEN, parentShardIds); ShardInfo diffShardInfo = new ShardInfo("shardId-diff", CONCURRENCY_TOKEN, parentShardIds, ExtendedSequenceNumber.LATEST);
Assert.assertFalse("Equal should return false with different shard id", diffShardInfo.equals(testShardInfo)); Assert.assertFalse("Equal should return false with different shard id", diffShardInfo.equals(testShardInfo));
diffShardInfo = new ShardInfo(null, CONCURRENCY_TOKEN, parentShardIds); diffShardInfo = new ShardInfo(null, CONCURRENCY_TOKEN, parentShardIds, ExtendedSequenceNumber.LATEST);
Assert.assertFalse("Equal should return false with null shard id", diffShardInfo.equals(testShardInfo)); Assert.assertFalse("Equal should return false with null shard id", diffShardInfo.equals(testShardInfo));
} }
@Test @Test
public void testPacboyShardInfoEqualsForfToken() { public void testPacboyShardInfoEqualsForfToken() {
ShardInfo diffShardInfo = new ShardInfo(SHARD_ID, UUID.randomUUID().toString(), parentShardIds); ShardInfo diffShardInfo = new ShardInfo(SHARD_ID, UUID.randomUUID().toString(), parentShardIds, ExtendedSequenceNumber.LATEST);
Assert.assertFalse("Equal should return false with different concurrency token", Assert.assertFalse("Equal should return false with different concurrency token",
diffShardInfo.equals(testShardInfo)); diffShardInfo.equals(testShardInfo));
diffShardInfo = new ShardInfo(SHARD_ID, null, parentShardIds); diffShardInfo = new ShardInfo(SHARD_ID, null, parentShardIds, ExtendedSequenceNumber.LATEST);
Assert.assertFalse("Equal should return false for null concurrency token", diffShardInfo.equals(testShardInfo)); Assert.assertFalse("Equal should return false for null concurrency token", diffShardInfo.equals(testShardInfo));
} }
@ -74,7 +75,7 @@ public class ShardInfoTest {
differentlyOrderedParentShardIds.add("shard-2"); differentlyOrderedParentShardIds.add("shard-2");
differentlyOrderedParentShardIds.add("shard-1"); differentlyOrderedParentShardIds.add("shard-1");
ShardInfo shardInfoWithDifferentlyOrderedParentShardIds = ShardInfo shardInfoWithDifferentlyOrderedParentShardIds =
new ShardInfo(SHARD_ID, CONCURRENCY_TOKEN, differentlyOrderedParentShardIds); new ShardInfo(SHARD_ID, CONCURRENCY_TOKEN, differentlyOrderedParentShardIds, ExtendedSequenceNumber.LATEST);
Assert.assertTrue("Equal should return true even with parent shard Ids reordered", Assert.assertTrue("Equal should return true even with parent shard Ids reordered",
shardInfoWithDifferentlyOrderedParentShardIds.equals(testShardInfo)); shardInfoWithDifferentlyOrderedParentShardIds.equals(testShardInfo));
} }
@ -84,16 +85,24 @@ public class ShardInfoTest {
Set<String> diffParentIds = new HashSet<>(); Set<String> diffParentIds = new HashSet<>();
diffParentIds.add("shard-3"); diffParentIds.add("shard-3");
diffParentIds.add("shard-4"); diffParentIds.add("shard-4");
ShardInfo diffShardInfo = new ShardInfo(SHARD_ID, CONCURRENCY_TOKEN, diffParentIds); ShardInfo diffShardInfo = new ShardInfo(SHARD_ID, CONCURRENCY_TOKEN, diffParentIds, ExtendedSequenceNumber.LATEST);
Assert.assertFalse("Equal should return false with different parent shard Ids", Assert.assertFalse("Equal should return false with different parent shard Ids",
diffShardInfo.equals(testShardInfo)); diffShardInfo.equals(testShardInfo));
diffShardInfo = new ShardInfo(SHARD_ID, CONCURRENCY_TOKEN, null); diffShardInfo = new ShardInfo(SHARD_ID, CONCURRENCY_TOKEN, null, ExtendedSequenceNumber.LATEST);
Assert.assertFalse("Equal should return false with null parent shard Ids", diffShardInfo.equals(testShardInfo)); Assert.assertFalse("Equal should return false with null parent shard Ids", diffShardInfo.equals(testShardInfo));
} }
@Test
public void testPacboyShardInfoEqualsForCheckpoint() {
ShardInfo diffShardInfo = new ShardInfo(SHARD_ID, CONCURRENCY_TOKEN, parentShardIds, ExtendedSequenceNumber.SHARD_END);
Assert.assertFalse("Equal should return false with different checkpoint", diffShardInfo.equals(testShardInfo));
diffShardInfo = new ShardInfo(SHARD_ID, CONCURRENCY_TOKEN, parentShardIds, null);
Assert.assertFalse("Equal should return false with null checkpoint", diffShardInfo.equals(testShardInfo));
}
@Test @Test
public void testPacboyShardInfoSameHashCode() { public void testPacboyShardInfoSameHashCode() {
ShardInfo equalShardInfo = new ShardInfo(SHARD_ID, CONCURRENCY_TOKEN, parentShardIds); ShardInfo equalShardInfo = new ShardInfo(SHARD_ID, CONCURRENCY_TOKEN, parentShardIds, ExtendedSequenceNumber.LATEST);
Assert.assertTrue("Shard info objects should have same hashCode for the same arguments", Assert.assertTrue("Shard info objects should have same hashCode for the same arguments",
equalShardInfo.hashCode() == testShardInfo.hashCode()); equalShardInfo.hashCode() == testShardInfo.hashCode());
} }

View file

@ -20,10 +20,9 @@ import static org.mockito.Mockito.when;
import java.util.HashSet; import java.util.HashSet;
import java.util.Set; import java.util.Set;
import junit.framework.Assert;
import org.junit.After; import org.junit.After;
import org.junit.AfterClass; import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.Before; import org.junit.Before;
import org.junit.BeforeClass; import org.junit.BeforeClass;
import org.junit.Test; import org.junit.Test;
@ -50,7 +49,8 @@ public class ShutdownTaskTest {
String defaultShardId = "shardId-0000397840"; String defaultShardId = "shardId-0000397840";
ShardInfo defaultShardInfo = new ShardInfo(defaultShardId, ShardInfo defaultShardInfo = new ShardInfo(defaultShardId,
defaultConcurrencyToken, defaultConcurrencyToken,
defaultParentShardIds); defaultParentShardIds,
ExtendedSequenceNumber.LATEST);
IRecordProcessor defaultRecordProcessor = new TestStreamlet(); IRecordProcessor defaultRecordProcessor = new TestStreamlet();
/** /**

View file

@ -105,6 +105,7 @@ public class WorkerTest {
InitialPositionInStreamExtended.newInitialPosition(InitialPositionInStream.LATEST); InitialPositionInStreamExtended.newInitialPosition(InitialPositionInStream.LATEST);
private static final InitialPositionInStreamExtended INITIAL_POSITION_TRIM_HORIZON = private static final InitialPositionInStreamExtended INITIAL_POSITION_TRIM_HORIZON =
InitialPositionInStreamExtended.newInitialPosition(InitialPositionInStream.TRIM_HORIZON); InitialPositionInStreamExtended.newInitialPosition(InitialPositionInStream.TRIM_HORIZON);
private final ShardPrioritization shardPrioritization = new NoOpShardPrioritization();
// CHECKSTYLE:IGNORE AnonInnerLengthCheck FOR NEXT 50 LINES // CHECKSTYLE:IGNORE AnonInnerLengthCheck FOR NEXT 50 LINES
private static final com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorFactory SAMPLE_RECORD_PROCESSOR_FACTORY = private static final com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorFactory SAMPLE_RECORD_PROCESSOR_FACTORY =
@ -192,14 +193,15 @@ public class WorkerTest {
execService, execService,
nullMetricsFactory, nullMetricsFactory,
taskBackoffTimeMillis, taskBackoffTimeMillis,
failoverTimeMillis); failoverTimeMillis,
ShardInfo shardInfo = new ShardInfo(dummyKinesisShardId, testConcurrencyToken, null); shardPrioritization);
ShardInfo shardInfo = new ShardInfo(dummyKinesisShardId, testConcurrencyToken, null, ExtendedSequenceNumber.TRIM_HORIZON);
ShardConsumer consumer = worker.createOrGetShardConsumer(shardInfo, streamletFactory); ShardConsumer consumer = worker.createOrGetShardConsumer(shardInfo, streamletFactory);
Assert.assertNotNull(consumer); Assert.assertNotNull(consumer);
ShardConsumer consumer2 = worker.createOrGetShardConsumer(shardInfo, streamletFactory); ShardConsumer consumer2 = worker.createOrGetShardConsumer(shardInfo, streamletFactory);
Assert.assertSame(consumer, consumer2); Assert.assertSame(consumer, consumer2);
ShardInfo shardInfoWithSameShardIdButDifferentConcurrencyToken = ShardInfo shardInfoWithSameShardIdButDifferentConcurrencyToken =
new ShardInfo(dummyKinesisShardId, anotherConcurrencyToken, null); new ShardInfo(dummyKinesisShardId, anotherConcurrencyToken, null, ExtendedSequenceNumber.TRIM_HORIZON);
ShardConsumer consumer3 = ShardConsumer consumer3 =
worker.createOrGetShardConsumer(shardInfoWithSameShardIdButDifferentConcurrencyToken, streamletFactory); worker.createOrGetShardConsumer(shardInfoWithSameShardIdButDifferentConcurrencyToken, streamletFactory);
Assert.assertNotNull(consumer3); Assert.assertNotNull(consumer3);
@ -241,12 +243,13 @@ public class WorkerTest {
execService, execService,
nullMetricsFactory, nullMetricsFactory,
taskBackoffTimeMillis, taskBackoffTimeMillis,
failoverTimeMillis); failoverTimeMillis,
shardPrioritization);
ShardInfo shardInfo1 = new ShardInfo(dummyKinesisShardId, concurrencyToken, null); ShardInfo shardInfo1 = new ShardInfo(dummyKinesisShardId, concurrencyToken, null, ExtendedSequenceNumber.TRIM_HORIZON);
ShardInfo duplicateOfShardInfo1ButWithAnotherConcurrencyToken = ShardInfo duplicateOfShardInfo1ButWithAnotherConcurrencyToken =
new ShardInfo(dummyKinesisShardId, anotherConcurrencyToken, null); new ShardInfo(dummyKinesisShardId, anotherConcurrencyToken, null, ExtendedSequenceNumber.TRIM_HORIZON);
ShardInfo shardInfo2 = new ShardInfo(anotherDummyKinesisShardId, concurrencyToken, null); ShardInfo shardInfo2 = new ShardInfo(anotherDummyKinesisShardId, concurrencyToken, null, ExtendedSequenceNumber.TRIM_HORIZON);
ShardConsumer consumerOfShardInfo1 = worker.createOrGetShardConsumer(shardInfo1, streamletFactory); ShardConsumer consumerOfShardInfo1 = worker.createOrGetShardConsumer(shardInfo1, streamletFactory);
ShardConsumer consumerOfDuplicateOfShardInfo1ButWithAnotherConcurrencyToken = ShardConsumer consumerOfDuplicateOfShardInfo1ButWithAnotherConcurrencyToken =
@ -297,7 +300,8 @@ public class WorkerTest {
execService, execService,
nullMetricsFactory, nullMetricsFactory,
taskBackoffTimeMillis, taskBackoffTimeMillis,
failoverTimeMillis); failoverTimeMillis,
shardPrioritization);
worker.run(); worker.run();
Assert.assertTrue(count > 0); Assert.assertTrue(count > 0);
} }
@ -745,7 +749,8 @@ public class WorkerTest {
executorService, executorService,
metricsFactory, metricsFactory,
taskBackoffTimeMillis, taskBackoffTimeMillis,
failoverTimeMillis); failoverTimeMillis,
shardPrioritization);
WorkerThread workerThread = new WorkerThread(worker); WorkerThread workerThread = new WorkerThread(worker);
workerThread.start(); workerThread.start();