Maintain parent/child shard ordering across shard splits/merges. (#155)
Kinesis allows clients to rely on an invariant that, for a given partition key, the order of records added to the stream will be maintained. IE: given an input `pkey=x,val=1 pkey=x,val=2 pkey=x,val=3`, the values `1,2,3` will be seen in that order when processed by clients, so long as clients are careful. It does so by putting all records for a single partition key into a single shard, then maintaining ordering within that shard.
However, shards can be split and merge, to distribute load better and handle per-shard throughput limits. Kinesis does this currently by (one or many times) splitting a single shard into two or by merging two adjacent shards into one. When this occurs, Kinesis still allows for ordering consistency by detailing shard parent/child relationships within its `listShards` outputs. A split shard A will create children B and C, both with `ParentShardId=A`. A merging of shards A and B into C will create a new shard C with `ParentShardId=A,AdjacentParentShardId=B`. So long as clients fully process all records in parents (including adjacent parents) before processing the new shard, ordering will be maintained.
`kinesis-consumer` currently doesn't do this. Instead, upon the initial (and subsequent) `listShards` call, all visible shards immediately begin processing. Considering this case, where shards split, then merge, and each shard `X` contains a single record `rX`:
```
time ->
B
/ \
A D
\ /
C
```
record `rD` should be processed after both `rB` and `rC` are processed, and both `rB` and `rC` should wait for `rA` to be processed. By starting goroutines immediately, any ordering of `{rA,rB,rC,rD}` might occur within the original code.
This PR utilizes the `AllGroup` as a book-keeper of fully processed shards, with the `Consumer` calling `CloseShard` once it has finished a shard. `AllGroup` doesn't release a shard for processing until its parents have fully been processed, and the consumer just processes the shards it receives as it used to.
This PR created a new `CloseableGroup` interface rather than append to the existing `Group` interface to maintain backwards compatibility in existing code that may already implement the `Group` interface elsewhere. Different `Group` implementations don't get the ordering described above, but the default `Consumer` does.
This commit is contained in:
parent
188bdff278
commit
6720a01733
4 changed files with 296 additions and 16 deletions
69
allgroup.go
69
allgroup.go
|
|
@ -2,6 +2,7 @@ package consumer
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
|
|
@ -12,11 +13,12 @@ import (
|
|||
// all shards on a stream
|
||||
func NewAllGroup(ksis kinesisClient, store Store, streamName string, logger Logger) *AllGroup {
|
||||
return &AllGroup{
|
||||
ksis: ksis,
|
||||
shards: make(map[string]types.Shard),
|
||||
streamName: streamName,
|
||||
logger: logger,
|
||||
Store: store,
|
||||
ksis: ksis,
|
||||
shards: make(map[string]types.Shard),
|
||||
shardsClosed: make(map[string]chan struct{}),
|
||||
streamName: streamName,
|
||||
logger: logger,
|
||||
Store: store,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -29,8 +31,9 @@ type AllGroup struct {
|
|||
logger Logger
|
||||
Store
|
||||
|
||||
shardMu sync.Mutex
|
||||
shards map[string]types.Shard
|
||||
shardMu sync.Mutex
|
||||
shards map[string]types.Shard
|
||||
shardsClosed map[string]chan struct{}
|
||||
}
|
||||
|
||||
// Start is a blocking operation which will loop and attempt to find new
|
||||
|
|
@ -59,6 +62,32 @@ func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) {
|
|||
}
|
||||
}
|
||||
|
||||
func (g *AllGroup) CloseShard(ctx context.Context, shardID string) error {
|
||||
g.shardMu.Lock()
|
||||
defer g.shardMu.Unlock()
|
||||
c, ok := g.shardsClosed[shardID]
|
||||
if !ok {
|
||||
return fmt.Errorf("closing unknown shard ID %q", shardID)
|
||||
}
|
||||
close(c)
|
||||
return nil
|
||||
}
|
||||
|
||||
func waitForCloseChannel(ctx context.Context, c <-chan struct{}) bool {
|
||||
if c == nil {
|
||||
// no channel means we haven't seen this shard in listShards, so it
|
||||
// probably fell off the TRIM_HORIZON, and we can assume it's fully processed.
|
||||
return true
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
case <-c:
|
||||
// the channel has been processed and closed by the consumer (CloseShard has been called)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// findNewShards pulls the list of shards from the Kinesis API
|
||||
// and uses a local cache to determine if we are already processing
|
||||
// a particular shard.
|
||||
|
|
@ -74,11 +103,35 @@ func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) {
|
|||
return
|
||||
}
|
||||
|
||||
// We do two `for` loops, since we have to set up all the `shardClosed`
|
||||
// channels before we start using any of them. It's highly probable
|
||||
// that Kinesis provides us the shards in dependency order (parents
|
||||
// before children), but it doesn't appear to be a guarantee.
|
||||
for _, shard := range shards {
|
||||
if _, ok := g.shards[*shard.ShardId]; ok {
|
||||
continue
|
||||
}
|
||||
g.shards[*shard.ShardId] = shard
|
||||
shardc <- shard
|
||||
g.shardsClosed[*shard.ShardId] = make(chan struct{})
|
||||
}
|
||||
for _, shard := range shards {
|
||||
shard := shard // Shadow shard, since we use it in goroutine
|
||||
var parent1, parent2 <-chan struct{}
|
||||
if shard.ParentShardId != nil {
|
||||
parent1 = g.shardsClosed[*shard.ParentShardId]
|
||||
}
|
||||
if shard.AdjacentParentShardId != nil {
|
||||
parent2 = g.shardsClosed[*shard.AdjacentParentShardId]
|
||||
}
|
||||
go func() {
|
||||
// Asynchronously wait for all parents of this shard to be processed
|
||||
// before providing it out to our client. Kinesis guarantees that a
|
||||
// given partition key's data will be provided to clients in-order,
|
||||
// but when splits or joins happen, we need to process all parents prior
|
||||
// to processing children or that ordering guarantee is not maintained.
|
||||
if waitForCloseChannel(ctx, parent1) && waitForCloseChannel(ctx, parent2) {
|
||||
shardc <- shard
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
|
|
|||
20
consumer.go
20
consumer.go
|
|
@ -118,13 +118,19 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
|
|||
wg.Add(1)
|
||||
go func(shardID string) {
|
||||
defer wg.Done()
|
||||
if err := c.ScanShard(ctx, shardID, fn); err != nil {
|
||||
var err error
|
||||
if err = c.ScanShard(ctx, shardID, fn); err != nil {
|
||||
err = fmt.Errorf("shard %s error: %w", shardID, err)
|
||||
} else if closeable, ok := c.group.(CloseableGroup); !ok {
|
||||
// group doesn't allow closure, skip calling CloseShard
|
||||
} else if err = closeable.CloseShard(ctx, shardID); err != nil {
|
||||
err = fmt.Errorf("shard closed CloseableGroup error: %w", err)
|
||||
}
|
||||
if err != nil {
|
||||
select {
|
||||
case errc <- fmt.Errorf("shard %s error: %w", shardID, err):
|
||||
// first error to occur
|
||||
cancel()
|
||||
default:
|
||||
// error has already occurred
|
||||
}
|
||||
}
|
||||
}(aws.ToString(shard.ShardId))
|
||||
|
|
@ -138,6 +144,10 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
|
|||
return <-errc
|
||||
}
|
||||
|
||||
func (c *Consumer) scanSingleShard(ctx context.Context, shardID string, fn ScanFunc) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ScanShard loops over records on a specific shard, calls the callback func
|
||||
// for each record and checkpoints the progress of scan.
|
||||
func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error {
|
||||
|
|
@ -218,9 +228,7 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
|
|||
c.logger.Log("[CONSUMER] shard closed:", shardID)
|
||||
|
||||
if c.shardClosedHandler != nil {
|
||||
err := c.shardClosedHandler(c.streamName, shardID)
|
||||
|
||||
if err != nil {
|
||||
if err := c.shardClosedHandler(c.streamName, shardID); err != nil {
|
||||
return fmt.Errorf("shard closed handler error: %w", err)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
216
consumer_test.go
216
consumer_test.go
|
|
@ -3,8 +3,10 @@ package consumer
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/service/kinesis"
|
||||
|
|
@ -24,6 +26,15 @@ var records = []types.Record{
|
|||
},
|
||||
}
|
||||
|
||||
// Implement logger to wrap testing.T.Log.
|
||||
type testLogger struct {
|
||||
t *testing.T
|
||||
}
|
||||
|
||||
func (t *testLogger) Log(args ...interface{}) {
|
||||
t.t.Log(args...)
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
if _, err := New("myStreamName"); err != nil {
|
||||
t.Fatalf("new consumer error: %v", err)
|
||||
|
|
@ -60,6 +71,7 @@ func TestScan(t *testing.T) {
|
|||
WithClient(client),
|
||||
WithCounter(ctr),
|
||||
WithStore(cp),
|
||||
WithLogger(&testLogger{t}),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("new consumer error: %v", err)
|
||||
|
|
@ -122,6 +134,7 @@ func TestScanShard(t *testing.T) {
|
|||
WithClient(client),
|
||||
WithCounter(ctr),
|
||||
WithStore(cp),
|
||||
WithLogger(&testLogger{t}),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("new consumer error: %v", err)
|
||||
|
|
@ -301,7 +314,8 @@ func TestScanShard_ShardIsClosed_WithShardClosedHandler(t *testing.T) {
|
|||
WithClient(client),
|
||||
WithShardClosedHandler(func(streamName, shardID string) error {
|
||||
return fmt.Errorf("closed shard error")
|
||||
}))
|
||||
}),
|
||||
WithLogger(&testLogger{t}))
|
||||
if err != nil {
|
||||
t.Fatalf("new consumer error: %v", err)
|
||||
}
|
||||
|
|
@ -335,7 +349,7 @@ func TestScanShard_GetRecordsError(t *testing.T) {
|
|||
return nil
|
||||
}
|
||||
|
||||
c, err := New("myStreamName", WithClient(client))
|
||||
c, err := New("myStreamName", WithClient(client), WithLogger(&testLogger{t}))
|
||||
if err != nil {
|
||||
t.Fatalf("new consumer error: %v", err)
|
||||
}
|
||||
|
|
@ -384,3 +398,201 @@ func (fc *fakeCounter) Add(streamName string, count int64) {
|
|||
|
||||
fc.counter += count
|
||||
}
|
||||
|
||||
func TestScan_PreviousParentsBeforeTrimHorizon(t *testing.T) {
|
||||
client := &kinesisClientMock{
|
||||
getShardIteratorMock: func(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) {
|
||||
return &kinesis.GetShardIteratorOutput{
|
||||
ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
|
||||
}, nil
|
||||
},
|
||||
getRecordsMock: func(ctx context.Context, params *kinesis.GetRecordsInput, optFns ...func(*kinesis.Options)) (*kinesis.GetRecordsOutput, error) {
|
||||
return &kinesis.GetRecordsOutput{
|
||||
NextShardIterator: nil,
|
||||
Records: records,
|
||||
}, nil
|
||||
},
|
||||
listShardsMock: func(ctx context.Context, params *kinesis.ListShardsInput, optFns ...func(*kinesis.Options)) (*kinesis.ListShardsOutput, error) {
|
||||
return &kinesis.ListShardsOutput{
|
||||
Shards: []types.Shard{
|
||||
{
|
||||
ShardId: aws.String("myShard"),
|
||||
ParentShardId: aws.String("myOldParent"),
|
||||
AdjacentParentShardId: aws.String("myOldAdjacentParent"),
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
var (
|
||||
cp = store.New()
|
||||
ctr = &fakeCounter{}
|
||||
)
|
||||
|
||||
c, err := New("myStreamName",
|
||||
WithClient(client),
|
||||
WithCounter(ctr),
|
||||
WithStore(cp),
|
||||
WithLogger(&testLogger{t}),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("new consumer error: %v", err)
|
||||
}
|
||||
|
||||
var (
|
||||
ctx, cancel = context.WithCancel(context.Background())
|
||||
res string
|
||||
)
|
||||
|
||||
var fn = func(r *Record) error {
|
||||
res += string(r.Data)
|
||||
|
||||
if string(r.Data) == "lastData" {
|
||||
cancel()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := c.Scan(ctx, fn); err != nil {
|
||||
t.Errorf("scan returned unexpected error %v", err)
|
||||
}
|
||||
|
||||
if res != "firstDatalastData" {
|
||||
t.Errorf("callback error expected %s, got %s", "firstDatalastData", res)
|
||||
}
|
||||
|
||||
if val := ctr.Get(); val != 2 {
|
||||
t.Errorf("counter error expected %d, got %d", 2, val)
|
||||
}
|
||||
|
||||
val, err := cp.GetCheckpoint("myStreamName", "myShard")
|
||||
if err != nil && val != "lastSeqNum" {
|
||||
t.Errorf("checkout error expected %s, got %s", "lastSeqNum", val)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScan_ParentChildOrdering(t *testing.T) {
|
||||
// We create a set of shards where shard1 split into (shard2,shard3), then (shard2,shard3) merged into shard4.
|
||||
client := &kinesisClientMock{
|
||||
getShardIteratorMock: func(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) {
|
||||
return &kinesis.GetShardIteratorOutput{
|
||||
ShardIterator: aws.String(*params.ShardId + "iter"),
|
||||
}, nil
|
||||
},
|
||||
getRecordsMock: func(ctx context.Context, params *kinesis.GetRecordsInput, optFns ...func(*kinesis.Options)) (*kinesis.GetRecordsOutput, error) {
|
||||
switch *params.ShardIterator {
|
||||
case "shard1iter":
|
||||
return &kinesis.GetRecordsOutput{
|
||||
NextShardIterator: nil,
|
||||
Records: []types.Record{
|
||||
{
|
||||
Data: []byte("shard1data"),
|
||||
SequenceNumber: aws.String("shard1num"),
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
case "shard2iter":
|
||||
return &kinesis.GetRecordsOutput{
|
||||
NextShardIterator: nil,
|
||||
Records: []types.Record{},
|
||||
}, nil
|
||||
case "shard3iter":
|
||||
return &kinesis.GetRecordsOutput{
|
||||
NextShardIterator: nil,
|
||||
Records: []types.Record{
|
||||
{
|
||||
Data: []byte("shard3data"),
|
||||
SequenceNumber: aws.String("shard3num"),
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
case "shard4iter":
|
||||
return &kinesis.GetRecordsOutput{
|
||||
NextShardIterator: nil,
|
||||
Records: []types.Record{
|
||||
{
|
||||
Data: []byte("shard4data"),
|
||||
SequenceNumber: aws.String("shard4num"),
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
default:
|
||||
panic("got unexpected iterator")
|
||||
}
|
||||
},
|
||||
listShardsMock: func(ctx context.Context, params *kinesis.ListShardsInput, optFns ...func(*kinesis.Options)) (*kinesis.ListShardsOutput, error) {
|
||||
// Intentionally misorder these to test resiliance to ordering issues from ListShards.
|
||||
return &kinesis.ListShardsOutput{
|
||||
Shards: []types.Shard{
|
||||
{
|
||||
ShardId: aws.String("shard3"),
|
||||
ParentShardId: aws.String("shard1"),
|
||||
},
|
||||
{
|
||||
ShardId: aws.String("shard1"),
|
||||
ParentShardId: aws.String("shard0"), // not otherwise referenced, parent ordering should ignore this
|
||||
},
|
||||
{
|
||||
ShardId: aws.String("shard4"),
|
||||
ParentShardId: aws.String("shard2"),
|
||||
AdjacentParentShardId: aws.String("shard3"),
|
||||
},
|
||||
{
|
||||
ShardId: aws.String("shard2"),
|
||||
ParentShardId: aws.String("shard1"),
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
var (
|
||||
cp = store.New()
|
||||
ctr = &fakeCounter{}
|
||||
)
|
||||
|
||||
c, err := New("myStreamName",
|
||||
WithClient(client),
|
||||
WithCounter(ctr),
|
||||
WithStore(cp),
|
||||
WithLogger(&testLogger{t}),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("new consumer error: %v", err)
|
||||
}
|
||||
|
||||
var (
|
||||
ctx, cancel = context.WithCancel(context.Background())
|
||||
res string
|
||||
)
|
||||
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
|
||||
var fn = func(r *Record) error {
|
||||
res += string(r.Data)
|
||||
time.Sleep(time.Duration(rand.Int()%100) * time.Millisecond)
|
||||
|
||||
if string(r.Data) == "shard4data" {
|
||||
cancel()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := c.Scan(ctx, fn); err != nil {
|
||||
t.Errorf("scan returned unexpected error %v", err)
|
||||
}
|
||||
|
||||
if want := "shard1datashard3datashard4data"; res != want {
|
||||
t.Errorf("callback error expected %s, got %s", want, res)
|
||||
}
|
||||
|
||||
if val := ctr.Get(); val != 3 {
|
||||
t.Errorf("counter error expected %d, got %d", 2, val)
|
||||
}
|
||||
|
||||
val, err := cp.GetCheckpoint("myStreamName", "shard4data")
|
||||
if err != nil && val != "shard4num" {
|
||||
t.Errorf("checkout error expected %s, got %s", "shard4num", val)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
7
group.go
7
group.go
|
|
@ -12,3 +12,10 @@ type Group interface {
|
|||
GetCheckpoint(streamName, shardID string) (string, error)
|
||||
SetCheckpoint(streamName, shardID, sequenceNumber string) error
|
||||
}
|
||||
|
||||
type CloseableGroup interface {
|
||||
Group
|
||||
// Allows shard processors to tell the group when the shard has been
|
||||
// fully processed. Should be called only once per shardID.
|
||||
CloseShard(ctx context.Context, shardID string) error
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue