Removing cached shard progress, adding guardrails for duplicate shard responses. (#811)

Co-authored-by: Joshua Kim <kimjos@amazon.com>
This commit is contained in:
Joshua Kim 2021-05-03 13:50:54 -07:00 committed by GitHub
parent f38dd18ed1
commit f2b9006a98
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 90 additions and 85 deletions

View file

@ -22,6 +22,7 @@ import java.util.ArrayList;
import java.util.Date;
import java.util.EnumSet;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
@ -29,6 +30,7 @@ import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.stream.Collectors;
import com.amazonaws.services.kinesis.clientlibrary.utils.RequestUtil;
import com.amazonaws.services.kinesis.model.ShardFilter;
import com.amazonaws.util.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
@ -59,7 +61,6 @@ import com.amazonaws.services.kinesis.model.ShardIteratorType;
import com.amazonaws.services.kinesis.model.StreamStatus;
import lombok.AccessLevel;
import lombok.Data;
import lombok.Getter;
import lombok.Setter;
@ -82,8 +83,6 @@ public class KinesisProxy implements IKinesisProxyExtended {
private AmazonKinesis client;
private AWSCredentialsProvider credentialsProvider;
private ShardIterationState shardIterationState = null;
@Setter(AccessLevel.PACKAGE)
private volatile Map<String, Shard> cachedShardMap = null;
@Setter(AccessLevel.PACKAGE)
@ -442,10 +441,8 @@ public class KinesisProxy implements IKinesisProxyExtended {
*/
@Override
public synchronized List<Shard> getShardListWithFilter(ShardFilter shardFilter) {
if (shardIterationState == null) {
shardIterationState = new ShardIterationState();
}
final List<Shard> shards = new ArrayList<>();
final List<String> requestIds = new ArrayList<>();
if (isKinesisClient) {
ListShardsResult result;
String nextToken = null;
@ -460,16 +457,18 @@ public class KinesisProxy implements IKinesisProxyExtended {
*/
return null;
} else {
shardIterationState.update(result.getShards());
shards.addAll(result.getShards());
requestIds.add(RequestUtil.requestId(result));
nextToken = result.getNextToken();
}
} while (StringUtils.isNotEmpty(result.getNextToken()));
} else {
DescribeStreamResult response;
String lastShardId = null;
do {
response = getStreamInfo(shardIterationState.getLastShardId());
response = getStreamInfo(lastShardId);
if (response == null) {
/*
@ -478,16 +477,26 @@ public class KinesisProxy implements IKinesisProxyExtended {
*/
return null;
} else {
shardIterationState.update(response.getStreamDescription().getShards());
final List<Shard> pageOfShards = response.getStreamDescription().getShards();
shards.addAll(pageOfShards);
requestIds.add(RequestUtil.requestId(response));
final Shard lastShard = pageOfShards.get(pageOfShards.size() - 1);
if (lastShardId == null || lastShardId.compareTo(lastShard.getShardId()) < 0) {
lastShardId = lastShard.getShardId();
}
}
} while (response.getStreamDescription().isHasMoreShards());
}
List<Shard> shards = shardIterationState.getShards();
this.cachedShardMap = shards.stream().collect(Collectors.toMap(Shard::getShardId, Function.identity()));
final List<Shard> dedupedShards = new ArrayList<>(new LinkedHashSet<>(shards));
if (dedupedShards.size() < shards.size()) {
LOG.warn("Found duplicate shards in response when sync'ing from Kinesis. " +
"Request ids - " + requestIds + ". Response - " + shards);
}
this.cachedShardMap = dedupedShards.stream().collect(Collectors.toMap(Shard::getShardId, Function.identity()));
this.lastCacheUpdateTime = Instant.now();
shardIterationState = new ShardIterationState();
return shards;
return dedupedShards;
}
/**
@ -617,27 +626,4 @@ public class KinesisProxy implements IKinesisProxyExtended {
final PutRecordResult response = client.putRecord(putRecordRequest);
return response;
}
@Data
static class ShardIterationState {
private List<Shard> shards;
private String lastShardId;
public ShardIterationState() {
shards = new ArrayList<>();
}
public void update(List<Shard> shards) {
if (shards == null || shards.isEmpty()) {
return;
}
this.shards.addAll(shards);
Shard lastShard = shards.get(shards.size() - 1);
if (lastShardId == null || lastShardId.compareTo(lastShard.getShardId()) < 0) {
lastShardId = lastShard.getShardId();
}
}
}
}

View file

@ -0,0 +1,24 @@
package com.amazonaws.services.kinesis.clientlibrary.utils;
import com.amazonaws.AmazonWebServiceResult;
/**
* Helper class to parse metadata from AWS requests.
*/
public class RequestUtil {
private static final String DEFAULT_REQUEST_ID = "NONE";
/**
* Get the requestId associated with a request.
*
* @param result
* @return the requestId for a request, or "NONE" if one is not available.
*/
public static String requestId(AmazonWebServiceResult result) {
if (result == null || result.getSdkResponseMetadata() == null || result.getSdkResponseMetadata().getRequestId() == null) {
return DEFAULT_REQUEST_ID;
}
return result.getSdkResponseMetadata().getRequestId();
}
}

View file

@ -24,6 +24,7 @@ import static org.hamcrest.Matchers.nullValue;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.fail;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.argThat;
import static org.mockito.Mockito.doReturn;
@ -92,6 +93,7 @@ public class KinesisProxyTest {
private static final String SHARD_4 = "shard-4";
private static final String NOT_CACHED_SHARD = "ShardId-0005";
private static final String NEVER_PRESENT_SHARD = "ShardId-0010";
private static final String REQUEST_ID = "requestId";
@Mock
private AmazonKinesis mockClient;
@ -249,54 +251,6 @@ public class KinesisProxyTest {
ddbProxy.getShardList();
}
@Test
public void testGetStreamInfoStoresOffset() throws Exception {
when(describeStreamResult.getStreamDescription()).thenReturn(streamDescription);
when(streamDescription.getStreamStatus()).thenReturn(StreamStatus.ACTIVE.name());
Shard shard1 = mock(Shard.class);
Shard shard2 = mock(Shard.class);
Shard shard3 = mock(Shard.class);
List<Shard> shardList1 = Collections.singletonList(shard1);
List<Shard> shardList2 = Collections.singletonList(shard2);
List<Shard> shardList3 = Collections.singletonList(shard3);
String shardId1 = "ShardId-0001";
String shardId2 = "ShardId-0002";
String shardId3 = "ShardId-0003";
when(shard1.getShardId()).thenReturn(shardId1);
when(shard2.getShardId()).thenReturn(shardId2);
when(shard3.getShardId()).thenReturn(shardId3);
when(streamDescription.getShards()).thenReturn(shardList1).thenReturn(shardList2).thenReturn(shardList3);
when(streamDescription.isHasMoreShards()).thenReturn(true, true, false);
when(mockDDBStreamClient.describeStream(argThat(describeWithoutShardId()))).thenReturn(describeStreamResult);
when(mockDDBStreamClient.describeStream(argThat(describeWithShardId(shardId1))))
.thenThrow(new LimitExceededException("1"), new LimitExceededException("2"),
new LimitExceededException("3"))
.thenReturn(describeStreamResult);
when(mockDDBStreamClient.describeStream(argThat(describeWithShardId(shardId2)))).thenReturn(describeStreamResult);
boolean limitExceeded = false;
try {
ddbProxy.getShardList();
} catch (LimitExceededException le) {
limitExceeded = true;
}
assertThat(limitExceeded, equalTo(true));
List<Shard> actualShards = ddbProxy.getShardList();
List<Shard> expectedShards = Arrays.asList(shard1, shard2, shard3);
assertThat(actualShards, equalTo(expectedShards));
verify(mockDDBStreamClient).describeStream(argThat(describeWithoutShardId()));
verify(mockDDBStreamClient, times(4)).describeStream(argThat(describeWithShardId(shardId1)));
verify(mockDDBStreamClient).describeStream(argThat(describeWithShardId(shardId2)));
}
@Test
public void testListShardsWithMoreDataAvailable() {
ListShardsResult responseWithMoreData = new ListShardsResult().withShards(shards.subList(0, 2)).withNextToken(NEXT_TOKEN);
@ -483,6 +437,47 @@ public class KinesisProxyTest {
verify(mockClient).listShards(any());
}
/**
* Tests that if we fail halfway through a listShards call, we fail gracefully and subsequent calls are not
* affected by the failure of the first request.
*/
@Test
public void testNoDuplicateShardsInPartialFailure() {
proxy.setCachedShardMap(null);
ListShardsResult firstPage = new ListShardsResult().withShards(shards.subList(0, 2)).withNextToken(NEXT_TOKEN);
ListShardsResult lastPage = new ListShardsResult().withShards(shards.subList(2, shards.size())).withNextToken(null);
when(mockClient.listShards(any()))
.thenReturn(firstPage).thenThrow(new RuntimeException("Failed!"))
.thenReturn(firstPage).thenReturn(lastPage);
try {
proxy.getShardList();
fail("First ListShards call should have failed!");
} catch (Exception e) {
// Do nothing
}
assertEquals(shards, proxy.getShardList());
}
/**
* Tests that if we receive any duplicate shard responses from the service during a shard sync, we dedup the response
* and continue gracefully.
*/
@Test
public void testDuplicateShardResponseDedupedGracefully() {
proxy.setCachedShardMap(null);
List<Shard> duplicateShards = new ArrayList<>(shards);
duplicateShards.addAll(shards);
ListShardsResult pageOfShards = new ListShardsResult().withShards(duplicateShards).withNextToken(null);
when(mockClient.listShards(any())).thenReturn(pageOfShards);
proxy.getShardList();
assertEquals(shards, proxy.getShardList());
}
private void mockListShardsForSingleResponse(List<Shard> shards) {
when(mockClient.listShards(any())).thenReturn(listShardsResult);
when(listShardsResult.getShards()).thenReturn(shards);