#185 reverts upstream changes

Alex Senger 2024-09-10 17:22:50 +02:00
parent de0c50cc32
commit 58ce4ba9f5
No known key found for this signature in database
GPG key ID: 0B4A96F8AF6934CF
4 changed files with 24 additions and 382 deletions

View file

@@ -2,7 +2,6 @@ package consumer

 import (
 	"context"
-	"fmt"
 	"log/slog"
 	"sync"
 	"time"
@@ -16,7 +15,6 @@ func NewAllGroup(kinesis kinesisClient, store Store, streamName string, logger *
 	return &AllGroup{
 		kinesis: kinesis,
 		shards: make(map[string]types.Shard),
-		shardsClosed: make(map[string]chan struct{}),
 		streamName: streamName,
 		slog: logger,
 		Store: store,
@@ -34,12 +32,11 @@ type AllGroup struct {
 	shardMu sync.Mutex
 	shards map[string]types.Shard
-	shardsClosed map[string]chan struct{}
 }

 // Start is a blocking operation which will loop and attempt to find new
 // shards on a regular cadence.
-func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) error {
+func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) {
 	// Note: while ticker is a rather naive approach to this problem,
 	// it actually simplifies a few things. i.e. If we miss a new shard
 	// while AWS is re-sharding we'll pick it up max 30 seconds later.
@@ -52,51 +49,21 @@ func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) error {
 	var ticker = time.NewTicker(30 * time.Second)

 	for {
-		err := g.findNewShards(ctx, shardc)
-		if err != nil {
-			ticker.Stop()
-			return err
-		}
+		g.findNewShards(ctx, shardc)

 		select {
 		case <-ctx.Done():
 			ticker.Stop()
-			return nil
+			return
 		case <-ticker.C:
 		}
 	}
 }
-
-func (g *AllGroup) CloseShard(ctx context.Context, shardID string) error {
-	g.shardMu.Lock()
-	defer g.shardMu.Unlock()
-	c, ok := g.shardsClosed[shardID]
-	if !ok {
-		return fmt.Errorf("closing unknown shard ID %q", shardID)
-	}
-	close(c)
-	return nil
-}
-
-func waitForCloseChannel(ctx context.Context, c <-chan struct{}) bool {
-	if c == nil {
-		// no channel means we haven't seen this shard in listShards, so it
-		// probably fell off the TRIM_HORIZON, and we can assume it's fully processed.
-		return true
-	}
-	select {
-	case <-ctx.Done():
-		return false
-	case <-c:
-		// the channel has been processed and closed by the consumer (CloseShard has been called)
-		return true
-	}
-}

 // findNewShards pulls the list of shards from the Kinesis API
 // and uses a local cache to determine if we are already processing
 // a particular shard.
-func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) error {
+func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) {
 	g.shardMu.Lock()
 	defer g.shardMu.Unlock()
@@ -105,39 +72,14 @@ func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) error {
 	shards, err := listShards(ctx, g.kinesis, g.streamName)
 	if err != nil {
 		g.slog.ErrorContext(ctx, "list shards", slog.String("error", err.Error()))
-		return err
+		return
 	}

-	// We do two `for` loops, since we have to set up all the `shardClosed`
-	// channels before we start using any of them. It's highly probable
-	// that Kinesis provides us the shards in dependency order (parents
-	// before children), but it doesn't appear to be a guarantee.
 	for _, shard := range shards {
 		if _, ok := g.shards[*shard.ShardId]; ok {
 			continue
 		}
 		g.shards[*shard.ShardId] = shard
-		g.shardsClosed[*shard.ShardId] = make(chan struct{})
-	}
-	for _, shard := range shards {
-		shard := shard // Shadow shard, since we use it in goroutine
-		var parent1, parent2 <-chan struct{}
-		if shard.ParentShardId != nil {
-			parent1 = g.shardsClosed[*shard.ParentShardId]
-		}
-		if shard.AdjacentParentShardId != nil {
-			parent2 = g.shardsClosed[*shard.AdjacentParentShardId]
-		}
-		go func() {
-			// Asynchronously wait for all parents of this shard to be processed
-			// before providing it out to our client. Kinesis guarantees that a
-			// given partition key's data will be provided to clients in-order,
-			// but when splits or joins happen, we need to process all parents prior
-			// to processing children or that ordering guarantee is not maintained.
-			if waitForCloseChannel(ctx, parent1) && waitForCloseChannel(ctx, parent2) {
-				shardc <- shard
-			}
-		}()
+		shardc <- shard
 	}
-	return nil
 }

View file

@@ -120,11 +120,7 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
 	)

 	go func() {
-		err := c.group.Start(ctx, shardc)
-		if err != nil {
-			errc <- fmt.Errorf("error starting scan: %w", err)
-			cancel()
-		}
+		c.group.Start(ctx, shardc)
 		<-ctx.Done()
 		close(shardc)
 	}()
@@ -135,19 +131,13 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
 		wg.Add(1)
 		go func(shardID string) {
 			defer wg.Done()
-			var err error
-			if err = c.ScanShard(ctx, shardID, fn); err != nil {
-				err = fmt.Errorf("shard %s error: %w", shardID, err)
-			} else if closeable, ok := c.group.(CloseableGroup); !ok {
-				// group doesn't allow closure, skip calling CloseShard
-			} else if err = closeable.CloseShard(ctx, shardID); err != nil {
-				err = fmt.Errorf("shard closed CloseableGroup error: %w", err)
-			}
-			if err != nil {
+			if err := c.ScanShard(ctx, shardID, fn); err != nil {
 				select {
-				case errc <- err:
+				case errc <- fmt.Errorf("shard %s error: %w", shardID, err):
+					// first error to occur
 					cancel()
 				default:
+					// error has already occurred
 				}
 			}
 		}(aws.ToString(shard.ShardId))
@@ -161,10 +151,6 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
 	return <-errc
 }

-func (c *Consumer) scanSingleShard(ctx context.Context, shardID string, fn ScanFunc) error {
-	return nil
-}
-
 // ScanShard loops over records on a specific shard, calls the callback func
 // for each record and checkpoints the progress of scan.
 func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error {
@@ -254,7 +240,9 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error {
 			c.logger.DebugContext(ctx, "shard closed", slog.String("shard-id", shardID))

 			if c.shardClosedHandler != nil {
-				if err := c.shardClosedHandler(c.streamName, shardID); err != nil {
+				err := c.shardClosedHandler(c.streamName, shardID)
+				if err != nil {
 					return fmt.Errorf("shard closed handler error: %w", err)
 				}
 			}
@@ -306,10 +294,7 @@ func (c *Consumer) getShardIterator(ctx context.Context, streamName, shardID, se
 	}

 	res, err := c.client.GetShardIterator(ctx, params)
-	if err != nil {
-		return nil, err
-	}
-
-	return res.ShardIterator, nil
+	return res.ShardIterator, err
 }

 func isRetriableError(err error) bool {

View file

@@ -2,12 +2,9 @@ package consumer

 import (
 	"context"
-	"errors"
 	"fmt"
-	"math/rand"
 	"sync"
 	"testing"
-	"time"

 	"github.com/aws/aws-sdk-go-v2/aws"
 	"github.com/aws/aws-sdk-go-v2/service/kinesis"
@@ -27,15 +24,6 @@ var records = []types.Record{
 	},
 }

-// Implement logger to wrap testing.T.Log.
-type testLogger struct {
-	t *testing.T
-}
-
-func (t *testLogger) Log(args ...interface{}) {
-	t.t.Log(args...)
-}
-
 func TestNew(t *testing.T) {
 	if _, err := New("myStreamName"); err != nil {
 		t.Fatalf("new consumer error: %v", err)
@@ -72,7 +60,6 @@ func TestScan(t *testing.T) {
 		WithClient(client),
 		WithCounter(ctr),
 		WithStore(cp),
-		WithLogger(&testLogger{t}),
 	)
 	if err != nil {
 		t.Fatalf("new consumer error: %v", err)
@@ -111,71 +98,6 @@ func TestScan(t *testing.T) {
 	}
 }

-func TestScan_ListShardsError(t *testing.T) {
-	mockError := errors.New("mock list shards error")
-	client := &kinesisClientMock{
-		listShardsMock: func(ctx context.Context, params *kinesis.ListShardsInput, optFns ...func(*kinesis.Options)) (*kinesis.ListShardsOutput, error) {
-			return nil, mockError
-		},
-	}
-
-	// use cancel func to signal shutdown
-	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
-
-	var res string
-	var fn = func(r *Record) error {
-		res += string(r.Data)
-		cancel() // simulate cancellation while processing first record
-		return nil
-	}
-
-	c, err := New("myStreamName", WithClient(client))
-	if err != nil {
-		t.Fatalf("new consumer error: %v", err)
-	}
-
-	err = c.Scan(ctx, fn)
-	if !errors.Is(err, mockError) {
-		t.Errorf("expected an error from listShards, but instead got %v", err)
-	}
-}
-
-func TestScan_GetShardIteratorError(t *testing.T) {
-	mockError := errors.New("mock get shard iterator error")
-	client := &kinesisClientMock{
-		listShardsMock: func(ctx context.Context, params *kinesis.ListShardsInput, optFns ...func(*kinesis.Options)) (*kinesis.ListShardsOutput, error) {
-			return &kinesis.ListShardsOutput{
-				Shards: []types.Shard{
-					{ShardId: aws.String("myShard")},
-				},
-			}, nil
-		},
-		getShardIteratorMock: func(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) {
-			return nil, mockError
-		},
-	}
-
-	// use cancel func to signal shutdown
-	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
-
-	var res string
-	var fn = func(r *Record) error {
-		res += string(r.Data)
-		cancel() // simulate cancellation while processing first record
-		return nil
-	}
-
-	c, err := New("myStreamName", WithClient(client))
-	if err != nil {
-		t.Fatalf("new consumer error: %v", err)
-	}
-
-	err = c.Scan(ctx, fn)
-	if !errors.Is(err, mockError) {
-		t.Errorf("expected an error from getShardIterator, but instead got %v", err)
-	}
-}
-
 func TestScanShard(t *testing.T) {
 	var client = &kinesisClientMock{
 		getShardIteratorMock: func(_ context.Context, _ *kinesis.GetShardIteratorInput, _ ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) {
@@ -200,7 +122,6 @@ func TestScanShard(t *testing.T) {
 		WithClient(client),
 		WithCounter(ctr),
 		WithStore(cp),
-		WithLogger(&testLogger{t}),
 	)
 	if err != nil {
 		t.Fatalf("new consumer error: %v", err)
@@ -380,8 +301,7 @@ func TestScanShard_ShardIsClosed_WithShardClosedHandler(t *testing.T) {
 		WithClient(client),
 		WithShardClosedHandler(func(_, _ string) error {
 			return fmt.Errorf("closed shard error")
-		}),
-		WithLogger(&testLogger{t}))
+		}))
 	if err != nil {
 		t.Fatalf("new consumer error: %v", err)
 	}
@@ -415,7 +335,7 @@ func TestScanShard_GetRecordsError(t *testing.T) {
 		return nil
 	}

-	c, err := New("myStreamName", WithClient(client), WithLogger(&testLogger{t}))
+	c, err := New("myStreamName", WithClient(client))
 	if err != nil {
 		t.Fatalf("new consumer error: %v", err)
 	}
@@ -464,201 +384,3 @@ func (fc *fakeCounter) Add(_ string, count int64) {
 	fc.counter += count
 }
-
-func TestScan_PreviousParentsBeforeTrimHorizon(t *testing.T) {
-	client := &kinesisClientMock{
-		getShardIteratorMock: func(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) {
-			return &kinesis.GetShardIteratorOutput{
-				ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
-			}, nil
-		},
-		getRecordsMock: func(ctx context.Context, params *kinesis.GetRecordsInput, optFns ...func(*kinesis.Options)) (*kinesis.GetRecordsOutput, error) {
-			return &kinesis.GetRecordsOutput{
-				NextShardIterator: nil,
-				Records: records,
-			}, nil
-		},
-		listShardsMock: func(ctx context.Context, params *kinesis.ListShardsInput, optFns ...func(*kinesis.Options)) (*kinesis.ListShardsOutput, error) {
-			return &kinesis.ListShardsOutput{
-				Shards: []types.Shard{
-					{
-						ShardId: aws.String("myShard"),
-						ParentShardId: aws.String("myOldParent"),
-						AdjacentParentShardId: aws.String("myOldAdjacentParent"),
-					},
-				},
-			}, nil
-		},
-	}
-
-	var (
-		cp = store.New()
-		ctr = &fakeCounter{}
-	)
-
-	c, err := New("myStreamName",
-		WithClient(client),
-		WithCounter(ctr),
-		WithStore(cp),
-		WithLogger(&testLogger{t}),
-	)
-	if err != nil {
-		t.Fatalf("new consumer error: %v", err)
-	}
-
-	var (
-		ctx, cancel = context.WithCancel(context.Background())
-		res string
-	)
-
-	var fn = func(r *Record) error {
-		res += string(r.Data)
-		if string(r.Data) == "lastData" {
-			cancel()
-		}
-		return nil
-	}
-
-	if err := c.Scan(ctx, fn); err != nil {
-		t.Errorf("scan returned unexpected error %v", err)
-	}
-
-	if res != "firstDatalastData" {
-		t.Errorf("callback error expected %s, got %s", "firstDatalastData", res)
-	}
-
-	if val := ctr.Get(); val != 2 {
-		t.Errorf("counter error expected %d, got %d", 2, val)
-	}
-
-	val, err := cp.GetCheckpoint("myStreamName", "myShard")
-	if err != nil && val != "lastSeqNum" {
-		t.Errorf("checkout error expected %s, got %s", "lastSeqNum", val)
-	}
-}
-
-func TestScan_ParentChildOrdering(t *testing.T) {
-	// We create a set of shards where shard1 split into (shard2,shard3), then (shard2,shard3) merged into shard4.
-	client := &kinesisClientMock{
-		getShardIteratorMock: func(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) {
-			return &kinesis.GetShardIteratorOutput{
-				ShardIterator: aws.String(*params.ShardId + "iter"),
-			}, nil
-		},
-		getRecordsMock: func(ctx context.Context, params *kinesis.GetRecordsInput, optFns ...func(*kinesis.Options)) (*kinesis.GetRecordsOutput, error) {
-			switch *params.ShardIterator {
-			case "shard1iter":
-				return &kinesis.GetRecordsOutput{
-					NextShardIterator: nil,
-					Records: []types.Record{
-						{
-							Data: []byte("shard1data"),
-							SequenceNumber: aws.String("shard1num"),
-						},
-					},
-				}, nil
-			case "shard2iter":
-				return &kinesis.GetRecordsOutput{
-					NextShardIterator: nil,
-					Records: []types.Record{},
-				}, nil
-			case "shard3iter":
-				return &kinesis.GetRecordsOutput{
-					NextShardIterator: nil,
-					Records: []types.Record{
-						{
-							Data: []byte("shard3data"),
-							SequenceNumber: aws.String("shard3num"),
-						},
-					},
-				}, nil
-			case "shard4iter":
-				return &kinesis.GetRecordsOutput{
-					NextShardIterator: nil,
-					Records: []types.Record{
-						{
-							Data: []byte("shard4data"),
-							SequenceNumber: aws.String("shard4num"),
-						},
-					},
-				}, nil
-			default:
-				panic("got unexpected iterator")
-			}
-		},
-		listShardsMock: func(ctx context.Context, params *kinesis.ListShardsInput, optFns ...func(*kinesis.Options)) (*kinesis.ListShardsOutput, error) {
-			// Intentionally misorder these to test resiliance to ordering issues from ListShards.
-			return &kinesis.ListShardsOutput{
-				Shards: []types.Shard{
-					{
-						ShardId: aws.String("shard3"),
-						ParentShardId: aws.String("shard1"),
-					},
-					{
-						ShardId: aws.String("shard1"),
-						ParentShardId: aws.String("shard0"), // not otherwise referenced, parent ordering should ignore this
-					},
-					{
-						ShardId: aws.String("shard4"),
-						ParentShardId: aws.String("shard2"),
-						AdjacentParentShardId: aws.String("shard3"),
-					},
-					{
-						ShardId: aws.String("shard2"),
-						ParentShardId: aws.String("shard1"),
-					},
-				},
-			}, nil
-		},
-	}
-
-	var (
-		cp = store.New()
-		ctr = &fakeCounter{}
-	)
-
-	c, err := New("myStreamName",
-		WithClient(client),
-		WithCounter(ctr),
-		WithStore(cp),
-		WithLogger(&testLogger{t}),
-	)
-	if err != nil {
-		t.Fatalf("new consumer error: %v", err)
-	}
-
-	var (
-		ctx, cancel = context.WithCancel(context.Background())
-		res string
-	)
-
-	rand.Seed(time.Now().UnixNano())
-	var fn = func(r *Record) error {
-		res += string(r.Data)
-		time.Sleep(time.Duration(rand.Int()%100) * time.Millisecond)
-		if string(r.Data) == "shard4data" {
-			cancel()
-		}
-		return nil
-	}
-
-	if err := c.Scan(ctx, fn); err != nil {
-		t.Errorf("scan returned unexpected error %v", err)
-	}
-
-	if want := "shard1datashard3datashard4data"; res != want {
-		t.Errorf("callback error expected %s, got %s", want, res)
-	}
-
-	if val := ctr.Get(); val != 3 {
-		t.Errorf("counter error expected %d, got %d", 2, val)
-	}
-
-	val, err := cp.GetCheckpoint("myStreamName", "shard4data")
-	if err != nil && val != "shard4num" {
-		t.Errorf("checkout error expected %s, got %s", "shard4num", val)
-	}
-}

View file

@@ -8,14 +8,7 @@ import (

 // Group interface used to manage which shard to process
 type Group interface {
-	Start(ctx context.Context, shardc chan types.Shard) error
+	Start(ctx context.Context, shardc chan types.Shard)
 	GetCheckpoint(streamName, shardID string) (string, error)
 	SetCheckpoint(streamName, shardID, sequenceNumber string) error
 }
-
-type CloseableGroup interface {
-	Group
-	// Allows shard processors to tell the group when the shard has been
-	// fully processed. Should be called only once per shardID.
-	CloseShard(ctx context.Context, shardID string) error
-}