#185 reverts upstream changes

This commit is contained in:
Alex Senger 2024-09-10 17:22:50 +02:00
parent de0c50cc32
commit 58ce4ba9f5
No known key found for this signature in database
GPG key ID: 0B4A96F8AF6934CF
4 changed files with 24 additions and 382 deletions

View file

@ -2,7 +2,6 @@ package consumer
import (
"context"
"fmt"
"log/slog"
"sync"
"time"
@ -14,12 +13,11 @@ import (
// all shards on a stream
func NewAllGroup(kinesis kinesisClient, store Store, streamName string, logger *slog.Logger) *AllGroup {
return &AllGroup{
kinesis: kinesis,
shards: make(map[string]types.Shard),
shardsClosed: make(map[string]chan struct{}),
streamName: streamName,
slog: logger,
Store: store,
kinesis: kinesis,
shards: make(map[string]types.Shard),
streamName: streamName,
slog: logger,
Store: store,
}
}
@ -32,14 +30,13 @@ type AllGroup struct {
slog *slog.Logger
Store
shardMu sync.Mutex
shards map[string]types.Shard
shardsClosed map[string]chan struct{}
shardMu sync.Mutex
shards map[string]types.Shard
}
// Start is a blocking operation which will loop and attempt to find new
// shards on a regular cadence.
func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) error {
func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) {
// Note: while ticker is a rather naive approach to this problem,
// it actually simplifies a few things. i.e. If we miss a new shard
// while AWS is re-sharding we'll pick it up max 30 seconds later.
@ -52,51 +49,21 @@ func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) error {
var ticker = time.NewTicker(30 * time.Second)
for {
err := g.findNewShards(ctx, shardc)
if err != nil {
ticker.Stop()
return err
}
g.findNewShards(ctx, shardc)
select {
case <-ctx.Done():
ticker.Stop()
return nil
return
case <-ticker.C:
}
}
}
func (g *AllGroup) CloseShard(ctx context.Context, shardID string) error {
g.shardMu.Lock()
defer g.shardMu.Unlock()
c, ok := g.shardsClosed[shardID]
if !ok {
return fmt.Errorf("closing unknown shard ID %q", shardID)
}
close(c)
return nil
}
func waitForCloseChannel(ctx context.Context, c <-chan struct{}) bool {
if c == nil {
// no channel means we haven't seen this shard in listShards, so it
// probably fell off the TRIM_HORIZON, and we can assume it's fully processed.
return true
}
select {
case <-ctx.Done():
return false
case <-c:
// the channel has been processed and closed by the consumer (CloseShard has been called)
return true
}
}
// findNewShards pulls the list of shards from the Kinesis API
// and uses a local cache to determine if we are already processing
// a particular shard.
func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) error {
func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) {
g.shardMu.Lock()
defer g.shardMu.Unlock()
@ -105,39 +72,14 @@ func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) e
shards, err := listShards(ctx, g.kinesis, g.streamName)
if err != nil {
g.slog.ErrorContext(ctx, "list shards", slog.String("error", err.Error()))
return err
return
}
// We do two `for` loops, since we have to set up all the `shardClosed`
// channels before we start using any of them. It's highly probable
// that Kinesis provides us the shards in dependency order (parents
// before children), but it doesn't appear to be a guarantee.
for _, shard := range shards {
if _, ok := g.shards[*shard.ShardId]; ok {
continue
}
g.shards[*shard.ShardId] = shard
g.shardsClosed[*shard.ShardId] = make(chan struct{})
shardc <- shard
}
for _, shard := range shards {
shard := shard // Shadow shard, since we use it in goroutine
var parent1, parent2 <-chan struct{}
if shard.ParentShardId != nil {
parent1 = g.shardsClosed[*shard.ParentShardId]
}
if shard.AdjacentParentShardId != nil {
parent2 = g.shardsClosed[*shard.AdjacentParentShardId]
}
go func() {
// Asynchronously wait for all parents of this shard to be processed
// before providing it out to our client. Kinesis guarantees that a
// given partition key's data will be provided to clients in-order,
// but when splits or joins happen, we need to process all parents prior
// to processing children or that ordering guarantee is not maintained.
if waitForCloseChannel(ctx, parent1) && waitForCloseChannel(ctx, parent2) {
shardc <- shard
}
}()
}
return nil
}

View file

@ -120,11 +120,7 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
)
go func() {
err := c.group.Start(ctx, shardc)
if err != nil {
errc <- fmt.Errorf("error starting scan: %w", err)
cancel()
}
c.group.Start(ctx, shardc)
<-ctx.Done()
close(shardc)
}()
@ -135,19 +131,13 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
wg.Add(1)
go func(shardID string) {
defer wg.Done()
var err error
if err = c.ScanShard(ctx, shardID, fn); err != nil {
err = fmt.Errorf("shard %s error: %w", shardID, err)
} else if closeable, ok := c.group.(CloseableGroup); !ok {
// group doesn't allow closure, skip calling CloseShard
} else if err = closeable.CloseShard(ctx, shardID); err != nil {
err = fmt.Errorf("shard closed CloseableGroup error: %w", err)
}
if err != nil {
if err := c.ScanShard(ctx, shardID, fn); err != nil {
select {
case errc <- fmt.Errorf("shard %s error: %w", shardID, err):
// first error to occur
cancel()
default:
// error has already occurred
}
}
}(aws.ToString(shard.ShardId))
@ -161,10 +151,6 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
return <-errc
}
func (c *Consumer) scanSingleShard(ctx context.Context, shardID string, fn ScanFunc) error {
return nil
}
// ScanShard loops over records on a specific shard, calls the callback func
// for each record and checkpoints the progress of scan.
func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error {
@ -254,7 +240,9 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
c.logger.DebugContext(ctx, "shard closed", slog.String("shard-id", shardID))
if c.shardClosedHandler != nil {
if err := c.shardClosedHandler(c.streamName, shardID); err != nil {
err := c.shardClosedHandler(c.streamName, shardID)
if err != nil {
return fmt.Errorf("shard closed handler error: %w", err)
}
}
@ -306,10 +294,7 @@ func (c *Consumer) getShardIterator(ctx context.Context, streamName, shardID, se
}
res, err := c.client.GetShardIterator(ctx, params)
if err != nil {
return nil, err
}
return res.ShardIterator, nil
return res.ShardIterator, err
}
func isRetriableError(err error) bool {

View file

@ -2,12 +2,9 @@ package consumer
import (
"context"
"errors"
"fmt"
"math/rand"
"sync"
"testing"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/kinesis"
@ -27,15 +24,6 @@ var records = []types.Record{
},
}
// Implement logger to wrap testing.T.Log.
type testLogger struct {
t *testing.T
}
func (t *testLogger) Log(args ...interface{}) {
t.t.Log(args...)
}
func TestNew(t *testing.T) {
if _, err := New("myStreamName"); err != nil {
t.Fatalf("new consumer error: %v", err)
@ -72,7 +60,6 @@ func TestScan(t *testing.T) {
WithClient(client),
WithCounter(ctr),
WithStore(cp),
WithLogger(&testLogger{t}),
)
if err != nil {
t.Fatalf("new consumer error: %v", err)
@ -111,71 +98,6 @@ func TestScan(t *testing.T) {
}
}
func TestScan_ListShardsError(t *testing.T) {
mockError := errors.New("mock list shards error")
client := &kinesisClientMock{
listShardsMock: func(ctx context.Context, params *kinesis.ListShardsInput, optFns ...func(*kinesis.Options)) (*kinesis.ListShardsOutput, error) {
return nil, mockError
},
}
// use cancel func to signal shutdown
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
var res string
var fn = func(r *Record) error {
res += string(r.Data)
cancel() // simulate cancellation while processing first record
return nil
}
c, err := New("myStreamName", WithClient(client))
if err != nil {
t.Fatalf("new consumer error: %v", err)
}
err = c.Scan(ctx, fn)
if !errors.Is(err, mockError) {
t.Errorf("expected an error from listShards, but instead got %v", err)
}
}
func TestScan_GetShardIteratorError(t *testing.T) {
mockError := errors.New("mock get shard iterator error")
client := &kinesisClientMock{
listShardsMock: func(ctx context.Context, params *kinesis.ListShardsInput, optFns ...func(*kinesis.Options)) (*kinesis.ListShardsOutput, error) {
return &kinesis.ListShardsOutput{
Shards: []types.Shard{
{ShardId: aws.String("myShard")},
},
}, nil
},
getShardIteratorMock: func(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) {
return nil, mockError
},
}
// use cancel func to signal shutdown
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
var res string
var fn = func(r *Record) error {
res += string(r.Data)
cancel() // simulate cancellation while processing first record
return nil
}
c, err := New("myStreamName", WithClient(client))
if err != nil {
t.Fatalf("new consumer error: %v", err)
}
err = c.Scan(ctx, fn)
if !errors.Is(err, mockError) {
t.Errorf("expected an error from getShardIterator, but instead got %v", err)
}
}
func TestScanShard(t *testing.T) {
var client = &kinesisClientMock{
getShardIteratorMock: func(_ context.Context, _ *kinesis.GetShardIteratorInput, _ ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) {
@ -200,7 +122,6 @@ func TestScanShard(t *testing.T) {
WithClient(client),
WithCounter(ctr),
WithStore(cp),
WithLogger(&testLogger{t}),
)
if err != nil {
t.Fatalf("new consumer error: %v", err)
@ -380,8 +301,7 @@ func TestScanShard_ShardIsClosed_WithShardClosedHandler(t *testing.T) {
WithClient(client),
WithShardClosedHandler(func(_, _ string) error {
return fmt.Errorf("closed shard error")
}),
WithLogger(&testLogger{t}))
}))
if err != nil {
t.Fatalf("new consumer error: %v", err)
}
@ -415,7 +335,7 @@ func TestScanShard_GetRecordsError(t *testing.T) {
return nil
}
c, err := New("myStreamName", WithClient(client), WithLogger(&testLogger{t}))
c, err := New("myStreamName", WithClient(client))
if err != nil {
t.Fatalf("new consumer error: %v", err)
}
@ -464,201 +384,3 @@ func (fc *fakeCounter) Add(_ string, count int64) {
fc.counter += count
}
func TestScan_PreviousParentsBeforeTrimHorizon(t *testing.T) {
client := &kinesisClientMock{
getShardIteratorMock: func(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) {
return &kinesis.GetShardIteratorOutput{
ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
}, nil
},
getRecordsMock: func(ctx context.Context, params *kinesis.GetRecordsInput, optFns ...func(*kinesis.Options)) (*kinesis.GetRecordsOutput, error) {
return &kinesis.GetRecordsOutput{
NextShardIterator: nil,
Records: records,
}, nil
},
listShardsMock: func(ctx context.Context, params *kinesis.ListShardsInput, optFns ...func(*kinesis.Options)) (*kinesis.ListShardsOutput, error) {
return &kinesis.ListShardsOutput{
Shards: []types.Shard{
{
ShardId: aws.String("myShard"),
ParentShardId: aws.String("myOldParent"),
AdjacentParentShardId: aws.String("myOldAdjacentParent"),
},
},
}, nil
},
}
var (
cp = store.New()
ctr = &fakeCounter{}
)
c, err := New("myStreamName",
WithClient(client),
WithCounter(ctr),
WithStore(cp),
WithLogger(&testLogger{t}),
)
if err != nil {
t.Fatalf("new consumer error: %v", err)
}
var (
ctx, cancel = context.WithCancel(context.Background())
res string
)
var fn = func(r *Record) error {
res += string(r.Data)
if string(r.Data) == "lastData" {
cancel()
}
return nil
}
if err := c.Scan(ctx, fn); err != nil {
t.Errorf("scan returned unexpected error %v", err)
}
if res != "firstDatalastData" {
t.Errorf("callback error expected %s, got %s", "firstDatalastData", res)
}
if val := ctr.Get(); val != 2 {
t.Errorf("counter error expected %d, got %d", 2, val)
}
val, err := cp.GetCheckpoint("myStreamName", "myShard")
if err != nil && val != "lastSeqNum" {
t.Errorf("checkout error expected %s, got %s", "lastSeqNum", val)
}
}
func TestScan_ParentChildOrdering(t *testing.T) {
// We create a set of shards where shard1 split into (shard2,shard3), then (shard2,shard3) merged into shard4.
client := &kinesisClientMock{
getShardIteratorMock: func(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) {
return &kinesis.GetShardIteratorOutput{
ShardIterator: aws.String(*params.ShardId + "iter"),
}, nil
},
getRecordsMock: func(ctx context.Context, params *kinesis.GetRecordsInput, optFns ...func(*kinesis.Options)) (*kinesis.GetRecordsOutput, error) {
switch *params.ShardIterator {
case "shard1iter":
return &kinesis.GetRecordsOutput{
NextShardIterator: nil,
Records: []types.Record{
{
Data: []byte("shard1data"),
SequenceNumber: aws.String("shard1num"),
},
},
}, nil
case "shard2iter":
return &kinesis.GetRecordsOutput{
NextShardIterator: nil,
Records: []types.Record{},
}, nil
case "shard3iter":
return &kinesis.GetRecordsOutput{
NextShardIterator: nil,
Records: []types.Record{
{
Data: []byte("shard3data"),
SequenceNumber: aws.String("shard3num"),
},
},
}, nil
case "shard4iter":
return &kinesis.GetRecordsOutput{
NextShardIterator: nil,
Records: []types.Record{
{
Data: []byte("shard4data"),
SequenceNumber: aws.String("shard4num"),
},
},
}, nil
default:
panic("got unexpected iterator")
}
},
listShardsMock: func(ctx context.Context, params *kinesis.ListShardsInput, optFns ...func(*kinesis.Options)) (*kinesis.ListShardsOutput, error) {
// Intentionally misorder these to test resiliance to ordering issues from ListShards.
return &kinesis.ListShardsOutput{
Shards: []types.Shard{
{
ShardId: aws.String("shard3"),
ParentShardId: aws.String("shard1"),
},
{
ShardId: aws.String("shard1"),
ParentShardId: aws.String("shard0"), // not otherwise referenced, parent ordering should ignore this
},
{
ShardId: aws.String("shard4"),
ParentShardId: aws.String("shard2"),
AdjacentParentShardId: aws.String("shard3"),
},
{
ShardId: aws.String("shard2"),
ParentShardId: aws.String("shard1"),
},
},
}, nil
},
}
var (
cp = store.New()
ctr = &fakeCounter{}
)
c, err := New("myStreamName",
WithClient(client),
WithCounter(ctr),
WithStore(cp),
WithLogger(&testLogger{t}),
)
if err != nil {
t.Fatalf("new consumer error: %v", err)
}
var (
ctx, cancel = context.WithCancel(context.Background())
res string
)
rand.Seed(time.Now().UnixNano())
var fn = func(r *Record) error {
res += string(r.Data)
time.Sleep(time.Duration(rand.Int()%100) * time.Millisecond)
if string(r.Data) == "shard4data" {
cancel()
}
return nil
}
if err := c.Scan(ctx, fn); err != nil {
t.Errorf("scan returned unexpected error %v", err)
}
if want := "shard1datashard3datashard4data"; res != want {
t.Errorf("callback error expected %s, got %s", want, res)
}
if val := ctr.Get(); val != 3 {
t.Errorf("counter error expected %d, got %d", 2, val)
}
val, err := cp.GetCheckpoint("myStreamName", "shard4data")
if err != nil && val != "shard4num" {
t.Errorf("checkout error expected %s, got %s", "shard4num", val)
}
}

View file

@ -8,14 +8,7 @@ import (
// Group interface used to manage which shard to process
type Group interface {
Start(ctx context.Context, shardc chan types.Shard) error
Start(ctx context.Context, shardc chan types.Shard)
GetCheckpoint(streamName, shardID string) (string, error)
SetCheckpoint(streamName, shardID, sequenceNumber string) error
}
type CloseableGroup interface {
Group
// Allows shard processors to tell the group when the shard has been
// fully processed. Should be called only once per shardID.
CloseShard(ctx context.Context, shardID string) error
}