From e45f59c73b2da5e18a6e5d69cd2ddaf03f834a73 Mon Sep 17 00:00:00 2001 From: "Pfifer, Justin" Date: Tue, 21 Mar 2017 08:04:51 -0700 Subject: [PATCH] Better Handling of Describe Stream Throttling Improve the handling of describe stream throttling by no longer triggering a null pointer exception when all requests are throttled. Also store the last position reached, and always restart from there. --- .../clientlibrary/proxies/KinesisProxy.java | 65 ++++++-- .../proxies/KinesisProxyTest.java | 152 ++++++++++++++++-- 2 files changed, 195 insertions(+), 22 deletions(-) diff --git a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/proxies/KinesisProxy.java b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/proxies/KinesisProxy.java index de330dc9..1e6fb1df 100644 --- a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/proxies/KinesisProxy.java +++ b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/proxies/KinesisProxy.java @@ -23,6 +23,7 @@ import java.util.List; import java.util.Set; import java.util.concurrent.atomic.AtomicReference; +import lombok.Data; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -61,6 +62,7 @@ public class KinesisProxy implements IKinesisProxyExtended { private AmazonKinesis client; private AWSCredentialsProvider credentialsProvider; private AtomicReference> listOfShardsSinceLastGet = new AtomicReference<>(); + private ShardIterationState shardIterationState = null; private final String streamName; @@ -163,15 +165,18 @@ public class KinesisProxy implements IKinesisProxyExtended { */ @Override public DescribeStreamResult getStreamInfo(String startShardId) - throws ResourceNotFoundException, LimitExceededException { + throws ResourceNotFoundException, LimitExceededException { final DescribeStreamRequest describeStreamRequest = new DescribeStreamRequest(); describeStreamRequest.setRequestCredentials(credentialsProvider.getCredentials()); describeStreamRequest.setStreamName(streamName); describeStreamRequest.setExclusiveStartShardId(startShardId); DescribeStreamResult response = null; + + LimitExceededException lastException = null; + int remainingRetryTimes = this.maxDescribeStreamRetryAttempts; // Call DescribeStream, with backoff and retries (if we get LimitExceededException). - while ((remainingRetryTimes >= 0) && (response == null)) { + while (response == null) { try { response = client.describeStream(describeStreamRequest); } catch (LimitExceededException le) { @@ -182,8 +187,15 @@ public class KinesisProxy implements IKinesisProxyExtended { } catch (InterruptedException ie) { LOG.debug("Stream " + streamName + " : Sleep was interrupted ", ie); } + lastException = le; } remainingRetryTimes--; + if (remainingRetryTimes <= 0 && response == null) { + if (lastException != null) { + throw lastException; + } + throw new IllegalStateException("Received null from DescribeStream call."); + } } if (StreamStatus.ACTIVE.toString().equals(response.getStreamDescription().getStreamStatus()) @@ -220,14 +232,15 @@ public class KinesisProxy implements IKinesisProxyExtended { * {@inheritDoc} */ @Override - public List getShardList() { - List result = new ArrayList(); + public synchronized List getShardList() { - DescribeStreamResult response = null; - String lastShardId = null; + DescribeStreamResult response; + if (shardIterationState == null) { + shardIterationState = new ShardIterationState(); + } do { - response = getStreamInfo(lastShardId); + response = getStreamInfo(shardIterationState.getLastShardId()); if (response == null) { /* @@ -236,13 +249,12 @@ public class KinesisProxy implements IKinesisProxyExtended { */ return null; } else { - List shards = response.getStreamDescription().getShards(); - result.addAll(shards); - lastShardId = shards.get(shards.size() - 1).getShardId(); + shardIterationState.update(response.getStreamDescription().getShards()); } } while (response.getStreamDescription().isHasMoreShards()); - this.listOfShardsSinceLastGet.set(result); - return result; + this.listOfShardsSinceLastGet.set(shardIterationState.getCollected()); + + return shardIterationState.getAndReset(); } /** @@ -344,4 +356,33 @@ public class KinesisProxy implements IKinesisProxyExtended { return response; } + @Data + static class ShardIterationState { + + private List collected; + private String lastShardId; + + public ShardIterationState() { + collected = new ArrayList<>(); + } + + public void update(List shards) { + if (shards == null || shards.isEmpty()) { + return; + } + collected.addAll(shards); + Shard lastShard = shards.get(shards.size() - 1); + if (lastShardId == null || lastShardId.compareTo(lastShard.getShardId()) < 0) { + lastShardId = lastShard.getShardId(); + } + } + + public List getAndReset() { + List result = collected; + collected = new ArrayList<>(); + lastShardId = null; + return result; + } + } + } diff --git a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/proxies/KinesisProxyTest.java b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/proxies/KinesisProxyTest.java index 2c1107b2..db0e3d0c 100644 --- a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/proxies/KinesisProxyTest.java +++ b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/proxies/KinesisProxyTest.java @@ -5,19 +5,25 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasProperty; import static org.hamcrest.Matchers.isA; import static org.hamcrest.Matchers.nullValue; +import static org.junit.Assert.assertThat; import static org.mockito.Matchers.any; import static org.mockito.Matchers.argThat; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.Set; -import com.amazonaws.AmazonServiceException; +import org.hamcrest.Description; +import org.hamcrest.TypeSafeDiagnosingMatcher; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -25,6 +31,7 @@ import org.mockito.ArgumentMatcher; import org.mockito.Mock; import org.mockito.runners.MockitoJUnitRunner; +import com.amazonaws.AmazonServiceException; import com.amazonaws.auth.AWSCredentialsProvider; import com.amazonaws.services.kinesis.AmazonKinesisClient; import com.amazonaws.services.kinesis.model.DescribeStreamRequest; @@ -37,13 +44,11 @@ import com.amazonaws.services.kinesis.model.ShardIteratorType; import com.amazonaws.services.kinesis.model.StreamDescription; import com.amazonaws.services.kinesis.model.StreamStatus; -import junit.framework.Assert; - @RunWith(MockitoJUnitRunner.class) public class KinesisProxyTest { private static final String TEST_STRING = "TestString"; private static final long BACKOFF_TIME = 10L; - private static final int RETRY_TIMES = 50; + private static final int RETRY_TIMES = 3; @Mock private AmazonKinesisClient mockClient; @@ -51,6 +56,13 @@ public class KinesisProxyTest { private AWSCredentialsProvider mockCredentialsProvider; @Mock private GetShardIteratorResult shardIteratorResult; + @Mock + private DescribeStreamResult describeStreamResult; + @Mock + private StreamDescription streamDescription; + @Mock + private Shard shard; + private KinesisProxy proxy; // Test shards for verifying. @@ -83,10 +95,10 @@ public class KinesisProxyTest { DescribeStreamResult responseFinal = createGetStreamInfoResponse(shards.subList(2, shards.size()), false); doReturn(responseWithMoreData).when(mockClient).describeStream(argThat(new IsRequestWithStartShardId(null))); doReturn(responseFinal).when(mockClient) - .describeStream(argThat(new IsRequestWithStartShardId(shards.get(1).getShardId()))); + .describeStream(argThat(new OldIsRequestWithStartShardId(shards.get(1).getShardId()))); Set resultShardIdSets = proxy.getAllShardIds(); - Assert.assertTrue("Result set should equal to Test set", shardIdSet.equals(resultShardIdSets)); + assertThat("Result set should equal to Test set", shardIdSet, equalTo(resultShardIdSets)); } @Test @@ -96,10 +108,10 @@ public class KinesisProxyTest { // Second call describeStream returning shards list. DescribeStreamResult response = createGetStreamInfoResponse(shards, false); doThrow(new LimitExceededException("Test Exception")).doReturn(response).when(mockClient) - .describeStream(argThat(new IsRequestWithStartShardId(null))); + .describeStream(argThat(new OldIsRequestWithStartShardId(null))); - Set resultShardIdSet = proxy.getAllShardIds(); - Assert.assertTrue("Result set should equal to Test set", shardIdSet.equals(resultShardIdSet)); + Set resultShardIdSet = proxy.getAllShardIds(); + assertThat("Result set should equal to Test set", shardIdSet, equalTo(resultShardIdSet)); } @Test @@ -132,6 +144,88 @@ public class KinesisProxyTest { .and(hasProperty("shardIteratorType", nullValue(String.class))))); } + @Test(expected = AmazonServiceException.class) + public void testGetStreamInfoFails() throws Exception { + when(mockClient.describeStream(any(DescribeStreamRequest.class))).thenThrow(new AmazonServiceException("Test")); + proxy.getShardList(); + verify(mockClient).describeStream(any(DescribeStreamRequest.class)); + } + + @Test + public void testGetStreamInfoThrottledOnce() throws Exception { + when(mockClient.describeStream(any(DescribeStreamRequest.class))).thenThrow(new LimitExceededException("Test")) + .thenReturn(describeStreamResult); + when(describeStreamResult.getStreamDescription()).thenReturn(streamDescription); + when(streamDescription.getHasMoreShards()).thenReturn(false); + when(streamDescription.getStreamStatus()).thenReturn(StreamStatus.ACTIVE.name()); + List expectedShards = Collections.singletonList(shard); + when(streamDescription.getShards()).thenReturn(expectedShards); + + List actualShards = proxy.getShardList(); + + assertThat(actualShards, equalTo(expectedShards)); + + verify(mockClient, times(2)).describeStream(any(DescribeStreamRequest.class)); + verify(describeStreamResult, times(3)).getStreamDescription(); + verify(streamDescription).getStreamStatus(); + verify(streamDescription).isHasMoreShards(); + } + + @Test(expected = LimitExceededException.class) + public void testGetStreamInfoThrottledAll() throws Exception { + when(mockClient.describeStream(any(DescribeStreamRequest.class))).thenThrow(new LimitExceededException("Test")); + + proxy.getShardList(); + } + + @Test + public void testGetStreamInfoStoresOffset() throws Exception { + when(describeStreamResult.getStreamDescription()).thenReturn(streamDescription); + when(streamDescription.getStreamStatus()).thenReturn(StreamStatus.ACTIVE.name()); + Shard shard1 = mock(Shard.class); + Shard shard2 = mock(Shard.class); + Shard shard3 = mock(Shard.class); + List shardList1 = Collections.singletonList(shard1); + List shardList2 = Collections.singletonList(shard2); + List shardList3 = Collections.singletonList(shard3); + + String shardId1 = "ShardId-0001"; + String shardId2 = "ShardId-0002"; + String shardId3 = "ShardId-0003"; + + when(shard1.getShardId()).thenReturn(shardId1); + when(shard2.getShardId()).thenReturn(shardId2); + when(shard3.getShardId()).thenReturn(shardId3); + + when(streamDescription.getShards()).thenReturn(shardList1).thenReturn(shardList2).thenReturn(shardList3); + when(streamDescription.isHasMoreShards()).thenReturn(true, true, false); + when(mockClient.describeStream(argThat(describeWithoutShardId()))).thenReturn(describeStreamResult); + + when(mockClient.describeStream(argThat(describeWithShardId(shardId1)))) + .thenThrow(new LimitExceededException("1"), new LimitExceededException("2"), + new LimitExceededException("3")) + .thenReturn(describeStreamResult); + + when(mockClient.describeStream(argThat(describeWithShardId(shardId2)))).thenReturn(describeStreamResult); + + boolean limitExceeded = false; + try { + proxy.getShardList(); + } catch (LimitExceededException le) { + limitExceeded = true; + } + assertThat(limitExceeded, equalTo(true)); + List actualShards = proxy.getShardList(); + List expectedShards = Arrays.asList(shard1, shard2, shard3); + + assertThat(actualShards, equalTo(expectedShards)); + + verify(mockClient).describeStream(argThat(describeWithoutShardId())); + verify(mockClient, times(4)).describeStream(argThat(describeWithShardId(shardId1))); + verify(mockClient).describeStream(argThat(describeWithShardId(shardId2))); + + } + private DescribeStreamResult createGetStreamInfoResponse(List shards1, boolean isHasMoreShards) { // Create stream description StreamDescription description = new StreamDescription(); @@ -145,14 +239,52 @@ public class KinesisProxyTest { return response; } + private IsRequestWithStartShardId describeWithoutShardId() { + return describeWithShardId(null); + } + + private IsRequestWithStartShardId describeWithShardId(String shardId) { + return new IsRequestWithStartShardId(shardId); + } + // Matcher for testing describe stream request with specific start shard ID. - private static class IsRequestWithStartShardId extends ArgumentMatcher { + private static class IsRequestWithStartShardId extends TypeSafeDiagnosingMatcher { private final String shardId; public IsRequestWithStartShardId(String shardId) { this.shardId = shardId; } + @Override + protected boolean matchesSafely(DescribeStreamRequest item, Description mismatchDescription) { + if (shardId == null) { + if (item.getExclusiveStartShardId() != null) { + mismatchDescription.appendText("Expected starting shard id of null, but was ") + .appendValue(item.getExclusiveStartShardId()); + return false; + } + } else if (!shardId.equals(item.getExclusiveStartShardId())) { + mismatchDescription.appendValue(shardId).appendText(" doesn't match expected ") + .appendValue(item.getExclusiveStartShardId()); + return false; + } + + return true; + } + + @Override + public void describeTo(Description description) { + description.appendText("A DescribeStreamRequest with a starting shard if of ").appendValue(shardId); + } + } + + private static class OldIsRequestWithStartShardId extends ArgumentMatcher { + private final String shardId; + + public OldIsRequestWithStartShardId(String shardId) { + this.shardId = shardId; + } + @Override public boolean matches(Object request) { String startShardId = ((DescribeStreamRequest) request).getExclusiveStartShardId();