diff --git a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/proxies/KinesisProxy.java b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/proxies/KinesisProxy.java index c1d7f10d..c6b641bb 100644 --- a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/proxies/KinesisProxy.java +++ b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/proxies/KinesisProxy.java @@ -22,6 +22,7 @@ import java.util.ArrayList; import java.util.Date; import java.util.EnumSet; import java.util.HashSet; +import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.Set; @@ -29,6 +30,7 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Function; import java.util.stream.Collectors; +import com.amazonaws.services.kinesis.clientlibrary.utils.RequestUtil; import com.amazonaws.services.kinesis.model.ShardFilter; import com.amazonaws.util.CollectionUtils; import org.apache.commons.lang3.StringUtils; @@ -59,7 +61,6 @@ import com.amazonaws.services.kinesis.model.ShardIteratorType; import com.amazonaws.services.kinesis.model.StreamStatus; import lombok.AccessLevel; -import lombok.Data; import lombok.Getter; import lombok.Setter; @@ -82,8 +83,6 @@ public class KinesisProxy implements IKinesisProxyExtended { private AmazonKinesis client; private AWSCredentialsProvider credentialsProvider; - private ShardIterationState shardIterationState = null; - @Setter(AccessLevel.PACKAGE) private volatile Map cachedShardMap = null; @Setter(AccessLevel.PACKAGE) @@ -442,10 +441,8 @@ public class KinesisProxy implements IKinesisProxyExtended { */ @Override public synchronized List getShardListWithFilter(ShardFilter shardFilter) { - if (shardIterationState == null) { - shardIterationState = new ShardIterationState(); - } - + final List shards = new ArrayList<>(); + final List requestIds = new ArrayList<>(); if (isKinesisClient) { ListShardsResult result; String nextToken = null; @@ -460,16 +457,18 @@ public class KinesisProxy implements IKinesisProxyExtended { */ return null; } else { - shardIterationState.update(result.getShards()); + shards.addAll(result.getShards()); + requestIds.add(RequestUtil.requestId(result)); nextToken = result.getNextToken(); } } while (StringUtils.isNotEmpty(result.getNextToken())); } else { DescribeStreamResult response; + String lastShardId = null; do { - response = getStreamInfo(shardIterationState.getLastShardId()); + response = getStreamInfo(lastShardId); if (response == null) { /* @@ -478,16 +477,26 @@ public class KinesisProxy implements IKinesisProxyExtended { */ return null; } else { - shardIterationState.update(response.getStreamDescription().getShards()); + final List pageOfShards = response.getStreamDescription().getShards(); + shards.addAll(pageOfShards); + requestIds.add(RequestUtil.requestId(response)); + + final Shard lastShard = pageOfShards.get(pageOfShards.size() - 1); + if (lastShardId == null || lastShardId.compareTo(lastShard.getShardId()) < 0) { + lastShardId = lastShard.getShardId(); + } } } while (response.getStreamDescription().isHasMoreShards()); } - List shards = shardIterationState.getShards(); - this.cachedShardMap = shards.stream().collect(Collectors.toMap(Shard::getShardId, Function.identity())); + final List dedupedShards = new ArrayList<>(new LinkedHashSet<>(shards)); + if (dedupedShards.size() < shards.size()) { + LOG.warn("Found duplicate shards in response when sync'ing from Kinesis. " + + "Request ids - " + requestIds + ". Response - " + shards); + } + this.cachedShardMap = dedupedShards.stream().collect(Collectors.toMap(Shard::getShardId, Function.identity())); this.lastCacheUpdateTime = Instant.now(); - shardIterationState = new ShardIterationState(); - return shards; + return dedupedShards; } /** @@ -617,27 +626,4 @@ public class KinesisProxy implements IKinesisProxyExtended { final PutRecordResult response = client.putRecord(putRecordRequest); return response; } - - @Data - static class ShardIterationState { - - private List shards; - private String lastShardId; - - public ShardIterationState() { - shards = new ArrayList<>(); - } - - public void update(List shards) { - if (shards == null || shards.isEmpty()) { - return; - } - this.shards.addAll(shards); - Shard lastShard = shards.get(shards.size() - 1); - if (lastShardId == null || lastShardId.compareTo(lastShard.getShardId()) < 0) { - lastShardId = lastShard.getShardId(); - } - } - } - } diff --git a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/utils/RequestUtil.java b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/utils/RequestUtil.java new file mode 100644 index 00000000..cac65d45 --- /dev/null +++ b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/utils/RequestUtil.java @@ -0,0 +1,24 @@ +package com.amazonaws.services.kinesis.clientlibrary.utils; + +import com.amazonaws.AmazonWebServiceResult; + +/** + * Helper class to parse metadata from AWS requests. + */ +public class RequestUtil { + private static final String DEFAULT_REQUEST_ID = "NONE"; + + /** + * Get the requestId associated with a request. + * + * @param result + * @return the requestId for a request, or "NONE" if one is not available. + */ + public static String requestId(AmazonWebServiceResult result) { + if (result == null || result.getSdkResponseMetadata() == null || result.getSdkResponseMetadata().getRequestId() == null) { + return DEFAULT_REQUEST_ID; + } + + return result.getSdkResponseMetadata().getRequestId(); + } +} diff --git a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/proxies/KinesisProxyTest.java b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/proxies/KinesisProxyTest.java index 75feb19e..76671176 100644 --- a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/proxies/KinesisProxyTest.java +++ b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/proxies/KinesisProxyTest.java @@ -24,6 +24,7 @@ import static org.hamcrest.Matchers.nullValue; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertThat; +import static org.junit.Assert.fail; import static org.mockito.Matchers.any; import static org.mockito.Matchers.argThat; import static org.mockito.Mockito.doReturn; @@ -92,6 +93,7 @@ public class KinesisProxyTest { private static final String SHARD_4 = "shard-4"; private static final String NOT_CACHED_SHARD = "ShardId-0005"; private static final String NEVER_PRESENT_SHARD = "ShardId-0010"; + private static final String REQUEST_ID = "requestId"; @Mock private AmazonKinesis mockClient; @@ -249,54 +251,6 @@ public class KinesisProxyTest { ddbProxy.getShardList(); } - @Test - public void testGetStreamInfoStoresOffset() throws Exception { - when(describeStreamResult.getStreamDescription()).thenReturn(streamDescription); - when(streamDescription.getStreamStatus()).thenReturn(StreamStatus.ACTIVE.name()); - Shard shard1 = mock(Shard.class); - Shard shard2 = mock(Shard.class); - Shard shard3 = mock(Shard.class); - List shardList1 = Collections.singletonList(shard1); - List shardList2 = Collections.singletonList(shard2); - List shardList3 = Collections.singletonList(shard3); - - String shardId1 = "ShardId-0001"; - String shardId2 = "ShardId-0002"; - String shardId3 = "ShardId-0003"; - - when(shard1.getShardId()).thenReturn(shardId1); - when(shard2.getShardId()).thenReturn(shardId2); - when(shard3.getShardId()).thenReturn(shardId3); - - when(streamDescription.getShards()).thenReturn(shardList1).thenReturn(shardList2).thenReturn(shardList3); - when(streamDescription.isHasMoreShards()).thenReturn(true, true, false); - when(mockDDBStreamClient.describeStream(argThat(describeWithoutShardId()))).thenReturn(describeStreamResult); - - when(mockDDBStreamClient.describeStream(argThat(describeWithShardId(shardId1)))) - .thenThrow(new LimitExceededException("1"), new LimitExceededException("2"), - new LimitExceededException("3")) - .thenReturn(describeStreamResult); - - when(mockDDBStreamClient.describeStream(argThat(describeWithShardId(shardId2)))).thenReturn(describeStreamResult); - - boolean limitExceeded = false; - try { - ddbProxy.getShardList(); - } catch (LimitExceededException le) { - limitExceeded = true; - } - assertThat(limitExceeded, equalTo(true)); - List actualShards = ddbProxy.getShardList(); - List expectedShards = Arrays.asList(shard1, shard2, shard3); - - assertThat(actualShards, equalTo(expectedShards)); - - verify(mockDDBStreamClient).describeStream(argThat(describeWithoutShardId())); - verify(mockDDBStreamClient, times(4)).describeStream(argThat(describeWithShardId(shardId1))); - verify(mockDDBStreamClient).describeStream(argThat(describeWithShardId(shardId2))); - - } - @Test public void testListShardsWithMoreDataAvailable() { ListShardsResult responseWithMoreData = new ListShardsResult().withShards(shards.subList(0, 2)).withNextToken(NEXT_TOKEN); @@ -483,6 +437,47 @@ public class KinesisProxyTest { verify(mockClient).listShards(any()); } + /** + * Tests that if we fail halfway through a listShards call, we fail gracefully and subsequent calls are not + * affected by the failure of the first request. + */ + @Test + public void testNoDuplicateShardsInPartialFailure() { + proxy.setCachedShardMap(null); + + ListShardsResult firstPage = new ListShardsResult().withShards(shards.subList(0, 2)).withNextToken(NEXT_TOKEN); + ListShardsResult lastPage = new ListShardsResult().withShards(shards.subList(2, shards.size())).withNextToken(null); + + when(mockClient.listShards(any())) + .thenReturn(firstPage).thenThrow(new RuntimeException("Failed!")) + .thenReturn(firstPage).thenReturn(lastPage); + + try { + proxy.getShardList(); + fail("First ListShards call should have failed!"); + } catch (Exception e) { + // Do nothing + } + assertEquals(shards, proxy.getShardList()); + } + + /** + * Tests that if we receive any duplicate shard responses from the service during a shard sync, we dedup the response + * and continue gracefully. + */ + @Test + public void testDuplicateShardResponseDedupedGracefully() { + proxy.setCachedShardMap(null); + List duplicateShards = new ArrayList<>(shards); + duplicateShards.addAll(shards); + ListShardsResult pageOfShards = new ListShardsResult().withShards(duplicateShards).withNextToken(null); + + when(mockClient.listShards(any())).thenReturn(pageOfShards); + + proxy.getShardList(); + assertEquals(shards, proxy.getShardList()); + } + private void mockListShardsForSingleResponse(List shards) { when(mockClient.listShards(any())).thenReturn(listShardsResult); when(listShardsResult.getShards()).thenReturn(shards);