Better Handling of Describe Stream Throttling

Improve the handling of describe stream throttling by no longer
triggering a null pointer exception when all requests are throttled.
Also store the last position reached, and always restart from
there.
This commit is contained in:
Pfifer, Justin 2017-03-21 08:04:51 -07:00 committed by Justin Pfifer
parent 92e8c28995
commit e45f59c73b
2 changed files with 195 additions and 22 deletions

View file

@ -23,6 +23,7 @@ import java.util.List;
import java.util.Set; import java.util.Set;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
import lombok.Data;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
@ -61,6 +62,7 @@ public class KinesisProxy implements IKinesisProxyExtended {
private AmazonKinesis client; private AmazonKinesis client;
private AWSCredentialsProvider credentialsProvider; private AWSCredentialsProvider credentialsProvider;
private AtomicReference<List<Shard>> listOfShardsSinceLastGet = new AtomicReference<>(); private AtomicReference<List<Shard>> listOfShardsSinceLastGet = new AtomicReference<>();
private ShardIterationState shardIterationState = null;
private final String streamName; private final String streamName;
@ -169,9 +171,12 @@ public class KinesisProxy implements IKinesisProxyExtended {
describeStreamRequest.setStreamName(streamName); describeStreamRequest.setStreamName(streamName);
describeStreamRequest.setExclusiveStartShardId(startShardId); describeStreamRequest.setExclusiveStartShardId(startShardId);
DescribeStreamResult response = null; DescribeStreamResult response = null;
LimitExceededException lastException = null;
int remainingRetryTimes = this.maxDescribeStreamRetryAttempts; int remainingRetryTimes = this.maxDescribeStreamRetryAttempts;
// Call DescribeStream, with backoff and retries (if we get LimitExceededException). // Call DescribeStream, with backoff and retries (if we get LimitExceededException).
while ((remainingRetryTimes >= 0) && (response == null)) { while (response == null) {
try { try {
response = client.describeStream(describeStreamRequest); response = client.describeStream(describeStreamRequest);
} catch (LimitExceededException le) { } catch (LimitExceededException le) {
@ -182,8 +187,15 @@ public class KinesisProxy implements IKinesisProxyExtended {
} catch (InterruptedException ie) { } catch (InterruptedException ie) {
LOG.debug("Stream " + streamName + " : Sleep was interrupted ", ie); LOG.debug("Stream " + streamName + " : Sleep was interrupted ", ie);
} }
lastException = le;
} }
remainingRetryTimes--; remainingRetryTimes--;
if (remainingRetryTimes <= 0 && response == null) {
if (lastException != null) {
throw lastException;
}
throw new IllegalStateException("Received null from DescribeStream call.");
}
} }
if (StreamStatus.ACTIVE.toString().equals(response.getStreamDescription().getStreamStatus()) if (StreamStatus.ACTIVE.toString().equals(response.getStreamDescription().getStreamStatus())
@ -220,14 +232,15 @@ public class KinesisProxy implements IKinesisProxyExtended {
* {@inheritDoc} * {@inheritDoc}
*/ */
@Override @Override
public List<Shard> getShardList() { public synchronized List<Shard> getShardList() {
List<Shard> result = new ArrayList<Shard>();
DescribeStreamResult response = null; DescribeStreamResult response;
String lastShardId = null; if (shardIterationState == null) {
shardIterationState = new ShardIterationState();
}
do { do {
response = getStreamInfo(lastShardId); response = getStreamInfo(shardIterationState.getLastShardId());
if (response == null) { if (response == null) {
/* /*
@ -236,13 +249,12 @@ public class KinesisProxy implements IKinesisProxyExtended {
*/ */
return null; return null;
} else { } else {
List<Shard> shards = response.getStreamDescription().getShards(); shardIterationState.update(response.getStreamDescription().getShards());
result.addAll(shards);
lastShardId = shards.get(shards.size() - 1).getShardId();
} }
} while (response.getStreamDescription().isHasMoreShards()); } while (response.getStreamDescription().isHasMoreShards());
this.listOfShardsSinceLastGet.set(result); this.listOfShardsSinceLastGet.set(shardIterationState.getCollected());
return result;
return shardIterationState.getAndReset();
} }
/** /**
@ -344,4 +356,33 @@ public class KinesisProxy implements IKinesisProxyExtended {
return response; return response;
} }
@Data
static class ShardIterationState {
private List<Shard> collected;
private String lastShardId;
public ShardIterationState() {
collected = new ArrayList<>();
}
public void update(List<Shard> shards) {
if (shards == null || shards.isEmpty()) {
return;
}
collected.addAll(shards);
Shard lastShard = shards.get(shards.size() - 1);
if (lastShardId == null || lastShardId.compareTo(lastShard.getShardId()) < 0) {
lastShardId = lastShard.getShardId();
}
}
public List<Shard> getAndReset() {
List<Shard> result = collected;
collected = new ArrayList<>();
lastShardId = null;
return result;
}
}
} }

View file

@ -5,19 +5,25 @@ import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasProperty; import static org.hamcrest.Matchers.hasProperty;
import static org.hamcrest.Matchers.isA; import static org.hamcrest.Matchers.isA;
import static org.hamcrest.Matchers.nullValue; import static org.hamcrest.Matchers.nullValue;
import static org.junit.Assert.assertThat;
import static org.mockito.Matchers.any; import static org.mockito.Matchers.any;
import static org.mockito.Matchers.argThat; import static org.mockito.Matchers.argThat;
import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet; import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Set;
import com.amazonaws.AmazonServiceException; import org.hamcrest.Description;
import org.hamcrest.TypeSafeDiagnosingMatcher;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
@ -25,6 +31,7 @@ import org.mockito.ArgumentMatcher;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.runners.MockitoJUnitRunner; import org.mockito.runners.MockitoJUnitRunner;
import com.amazonaws.AmazonServiceException;
import com.amazonaws.auth.AWSCredentialsProvider; import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.services.kinesis.AmazonKinesisClient; import com.amazonaws.services.kinesis.AmazonKinesisClient;
import com.amazonaws.services.kinesis.model.DescribeStreamRequest; import com.amazonaws.services.kinesis.model.DescribeStreamRequest;
@ -37,13 +44,11 @@ import com.amazonaws.services.kinesis.model.ShardIteratorType;
import com.amazonaws.services.kinesis.model.StreamDescription; import com.amazonaws.services.kinesis.model.StreamDescription;
import com.amazonaws.services.kinesis.model.StreamStatus; import com.amazonaws.services.kinesis.model.StreamStatus;
import junit.framework.Assert;
@RunWith(MockitoJUnitRunner.class) @RunWith(MockitoJUnitRunner.class)
public class KinesisProxyTest { public class KinesisProxyTest {
private static final String TEST_STRING = "TestString"; private static final String TEST_STRING = "TestString";
private static final long BACKOFF_TIME = 10L; private static final long BACKOFF_TIME = 10L;
private static final int RETRY_TIMES = 50; private static final int RETRY_TIMES = 3;
@Mock @Mock
private AmazonKinesisClient mockClient; private AmazonKinesisClient mockClient;
@ -51,6 +56,13 @@ public class KinesisProxyTest {
private AWSCredentialsProvider mockCredentialsProvider; private AWSCredentialsProvider mockCredentialsProvider;
@Mock @Mock
private GetShardIteratorResult shardIteratorResult; private GetShardIteratorResult shardIteratorResult;
@Mock
private DescribeStreamResult describeStreamResult;
@Mock
private StreamDescription streamDescription;
@Mock
private Shard shard;
private KinesisProxy proxy; private KinesisProxy proxy;
// Test shards for verifying. // Test shards for verifying.
@ -83,10 +95,10 @@ public class KinesisProxyTest {
DescribeStreamResult responseFinal = createGetStreamInfoResponse(shards.subList(2, shards.size()), false); DescribeStreamResult responseFinal = createGetStreamInfoResponse(shards.subList(2, shards.size()), false);
doReturn(responseWithMoreData).when(mockClient).describeStream(argThat(new IsRequestWithStartShardId(null))); doReturn(responseWithMoreData).when(mockClient).describeStream(argThat(new IsRequestWithStartShardId(null)));
doReturn(responseFinal).when(mockClient) doReturn(responseFinal).when(mockClient)
.describeStream(argThat(new IsRequestWithStartShardId(shards.get(1).getShardId()))); .describeStream(argThat(new OldIsRequestWithStartShardId(shards.get(1).getShardId())));
Set<String> resultShardIdSets = proxy.getAllShardIds(); Set<String> resultShardIdSets = proxy.getAllShardIds();
Assert.assertTrue("Result set should equal to Test set", shardIdSet.equals(resultShardIdSets)); assertThat("Result set should equal to Test set", shardIdSet, equalTo(resultShardIdSets));
} }
@Test @Test
@ -96,10 +108,10 @@ public class KinesisProxyTest {
// Second call describeStream returning shards list. // Second call describeStream returning shards list.
DescribeStreamResult response = createGetStreamInfoResponse(shards, false); DescribeStreamResult response = createGetStreamInfoResponse(shards, false);
doThrow(new LimitExceededException("Test Exception")).doReturn(response).when(mockClient) doThrow(new LimitExceededException("Test Exception")).doReturn(response).when(mockClient)
.describeStream(argThat(new IsRequestWithStartShardId(null))); .describeStream(argThat(new OldIsRequestWithStartShardId(null)));
Set<String> resultShardIdSet = proxy.getAllShardIds(); Set<String> resultShardIdSet = proxy.getAllShardIds();
Assert.assertTrue("Result set should equal to Test set", shardIdSet.equals(resultShardIdSet)); assertThat("Result set should equal to Test set", shardIdSet, equalTo(resultShardIdSet));
} }
@Test @Test
@ -132,6 +144,88 @@ public class KinesisProxyTest {
.and(hasProperty("shardIteratorType", nullValue(String.class))))); .and(hasProperty("shardIteratorType", nullValue(String.class)))));
} }
@Test(expected = AmazonServiceException.class)
public void testGetStreamInfoFails() throws Exception {
when(mockClient.describeStream(any(DescribeStreamRequest.class))).thenThrow(new AmazonServiceException("Test"));
proxy.getShardList();
verify(mockClient).describeStream(any(DescribeStreamRequest.class));
}
@Test
public void testGetStreamInfoThrottledOnce() throws Exception {
when(mockClient.describeStream(any(DescribeStreamRequest.class))).thenThrow(new LimitExceededException("Test"))
.thenReturn(describeStreamResult);
when(describeStreamResult.getStreamDescription()).thenReturn(streamDescription);
when(streamDescription.getHasMoreShards()).thenReturn(false);
when(streamDescription.getStreamStatus()).thenReturn(StreamStatus.ACTIVE.name());
List<Shard> expectedShards = Collections.singletonList(shard);
when(streamDescription.getShards()).thenReturn(expectedShards);
List<Shard> actualShards = proxy.getShardList();
assertThat(actualShards, equalTo(expectedShards));
verify(mockClient, times(2)).describeStream(any(DescribeStreamRequest.class));
verify(describeStreamResult, times(3)).getStreamDescription();
verify(streamDescription).getStreamStatus();
verify(streamDescription).isHasMoreShards();
}
@Test(expected = LimitExceededException.class)
public void testGetStreamInfoThrottledAll() throws Exception {
when(mockClient.describeStream(any(DescribeStreamRequest.class))).thenThrow(new LimitExceededException("Test"));
proxy.getShardList();
}
@Test
public void testGetStreamInfoStoresOffset() throws Exception {
when(describeStreamResult.getStreamDescription()).thenReturn(streamDescription);
when(streamDescription.getStreamStatus()).thenReturn(StreamStatus.ACTIVE.name());
Shard shard1 = mock(Shard.class);
Shard shard2 = mock(Shard.class);
Shard shard3 = mock(Shard.class);
List<Shard> shardList1 = Collections.singletonList(shard1);
List<Shard> shardList2 = Collections.singletonList(shard2);
List<Shard> shardList3 = Collections.singletonList(shard3);
String shardId1 = "ShardId-0001";
String shardId2 = "ShardId-0002";
String shardId3 = "ShardId-0003";
when(shard1.getShardId()).thenReturn(shardId1);
when(shard2.getShardId()).thenReturn(shardId2);
when(shard3.getShardId()).thenReturn(shardId3);
when(streamDescription.getShards()).thenReturn(shardList1).thenReturn(shardList2).thenReturn(shardList3);
when(streamDescription.isHasMoreShards()).thenReturn(true, true, false);
when(mockClient.describeStream(argThat(describeWithoutShardId()))).thenReturn(describeStreamResult);
when(mockClient.describeStream(argThat(describeWithShardId(shardId1))))
.thenThrow(new LimitExceededException("1"), new LimitExceededException("2"),
new LimitExceededException("3"))
.thenReturn(describeStreamResult);
when(mockClient.describeStream(argThat(describeWithShardId(shardId2)))).thenReturn(describeStreamResult);
boolean limitExceeded = false;
try {
proxy.getShardList();
} catch (LimitExceededException le) {
limitExceeded = true;
}
assertThat(limitExceeded, equalTo(true));
List<Shard> actualShards = proxy.getShardList();
List<Shard> expectedShards = Arrays.asList(shard1, shard2, shard3);
assertThat(actualShards, equalTo(expectedShards));
verify(mockClient).describeStream(argThat(describeWithoutShardId()));
verify(mockClient, times(4)).describeStream(argThat(describeWithShardId(shardId1)));
verify(mockClient).describeStream(argThat(describeWithShardId(shardId2)));
}
private DescribeStreamResult createGetStreamInfoResponse(List<Shard> shards1, boolean isHasMoreShards) { private DescribeStreamResult createGetStreamInfoResponse(List<Shard> shards1, boolean isHasMoreShards) {
// Create stream description // Create stream description
StreamDescription description = new StreamDescription(); StreamDescription description = new StreamDescription();
@ -145,14 +239,52 @@ public class KinesisProxyTest {
return response; return response;
} }
private IsRequestWithStartShardId describeWithoutShardId() {
return describeWithShardId(null);
}
private IsRequestWithStartShardId describeWithShardId(String shardId) {
return new IsRequestWithStartShardId(shardId);
}
// Matcher for testing describe stream request with specific start shard ID. // Matcher for testing describe stream request with specific start shard ID.
private static class IsRequestWithStartShardId extends ArgumentMatcher<DescribeStreamRequest> { private static class IsRequestWithStartShardId extends TypeSafeDiagnosingMatcher<DescribeStreamRequest> {
private final String shardId; private final String shardId;
public IsRequestWithStartShardId(String shardId) { public IsRequestWithStartShardId(String shardId) {
this.shardId = shardId; this.shardId = shardId;
} }
@Override
protected boolean matchesSafely(DescribeStreamRequest item, Description mismatchDescription) {
if (shardId == null) {
if (item.getExclusiveStartShardId() != null) {
mismatchDescription.appendText("Expected starting shard id of null, but was ")
.appendValue(item.getExclusiveStartShardId());
return false;
}
} else if (!shardId.equals(item.getExclusiveStartShardId())) {
mismatchDescription.appendValue(shardId).appendText(" doesn't match expected ")
.appendValue(item.getExclusiveStartShardId());
return false;
}
return true;
}
@Override
public void describeTo(Description description) {
description.appendText("A DescribeStreamRequest with a starting shard if of ").appendValue(shardId);
}
}
private static class OldIsRequestWithStartShardId extends ArgumentMatcher<DescribeStreamRequest> {
private final String shardId;
public OldIsRequestWithStartShardId(String shardId) {
this.shardId = shardId;
}
@Override @Override
public boolean matches(Object request) { public boolean matches(Object request) {
String startShardId = ((DescribeStreamRequest) request).getExclusiveStartShardId(); String startShardId = ((DescribeStreamRequest) request).getExclusiveStartShardId();