Merge branch 'harlow-master'
This commit is contained in:
commit
0ee8cc4623
4 changed files with 382 additions and 24 deletions
68
allgroup.go
68
allgroup.go
|
|
@ -2,6 +2,7 @@ package consumer
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
@ -15,6 +16,7 @@ func NewAllGroup(kinesis kinesisClient, store Store, streamName string, logger *
|
||||||
return &AllGroup{
|
return &AllGroup{
|
||||||
kinesis: kinesis,
|
kinesis: kinesis,
|
||||||
shards: make(map[string]types.Shard),
|
shards: make(map[string]types.Shard),
|
||||||
|
shardsClosed: make(map[string]chan struct{}),
|
||||||
streamName: streamName,
|
streamName: streamName,
|
||||||
slog: logger,
|
slog: logger,
|
||||||
Store: store,
|
Store: store,
|
||||||
|
|
@ -32,11 +34,12 @@ type AllGroup struct {
|
||||||
|
|
||||||
shardMu sync.Mutex
|
shardMu sync.Mutex
|
||||||
shards map[string]types.Shard
|
shards map[string]types.Shard
|
||||||
|
shardsClosed map[string]chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start is a blocking operation which will loop and attempt to find new
|
// Start is a blocking operation which will loop and attempt to find new
|
||||||
// shards on a regular cadence.
|
// shards on a regular cadence.
|
||||||
func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) {
|
func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) error {
|
||||||
// Note: while ticker is a rather naive approach to this problem,
|
// Note: while ticker is a rather naive approach to this problem,
|
||||||
// it actually simplifies a few things. i.e. If we miss a new shard
|
// it actually simplifies a few things. i.e. If we miss a new shard
|
||||||
// while AWS is re-sharding we'll pick it up max 30 seconds later.
|
// while AWS is re-sharding we'll pick it up max 30 seconds later.
|
||||||
|
|
@ -49,21 +52,51 @@ func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) {
|
||||||
var ticker = time.NewTicker(30 * time.Second)
|
var ticker = time.NewTicker(30 * time.Second)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
g.findNewShards(ctx, shardc)
|
err := g.findNewShards(ctx, shardc)
|
||||||
|
if err != nil {
|
||||||
|
ticker.Stop()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
ticker.Stop()
|
ticker.Stop()
|
||||||
return
|
return nil
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (g *AllGroup) CloseShard(ctx context.Context, shardID string) error {
|
||||||
|
g.shardMu.Lock()
|
||||||
|
defer g.shardMu.Unlock()
|
||||||
|
c, ok := g.shardsClosed[shardID]
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("closing unknown shard ID %q", shardID)
|
||||||
|
}
|
||||||
|
close(c)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func waitForCloseChannel(ctx context.Context, c <-chan struct{}) bool {
|
||||||
|
if c == nil {
|
||||||
|
// no channel means we haven't seen this shard in listShards, so it
|
||||||
|
// probably fell off the TRIM_HORIZON, and we can assume it's fully processed.
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return false
|
||||||
|
case <-c:
|
||||||
|
// the channel has been processed and closed by the consumer (CloseShard has been called)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// findNewShards pulls the list of shards from the Kinesis API
|
// findNewShards pulls the list of shards from the Kinesis API
|
||||||
// and uses a local cache to determine if we are already processing
|
// and uses a local cache to determine if we are already processing
|
||||||
// a particular shard.
|
// a particular shard.
|
||||||
func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) {
|
func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) error {
|
||||||
g.shardMu.Lock()
|
g.shardMu.Lock()
|
||||||
defer g.shardMu.Unlock()
|
defer g.shardMu.Unlock()
|
||||||
|
|
||||||
|
|
@ -72,14 +105,39 @@ func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) {
|
||||||
shards, err := listShards(ctx, g.kinesis, g.streamName)
|
shards, err := listShards(ctx, g.kinesis, g.streamName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
g.slog.ErrorContext(ctx, "list shards", slog.String("error", err.Error()))
|
g.slog.ErrorContext(ctx, "list shards", slog.String("error", err.Error()))
|
||||||
return
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// We do two `for` loops, since we have to set up all the `shardClosed`
|
||||||
|
// channels before we start using any of them. It's highly probable
|
||||||
|
// that Kinesis provides us the shards in dependency order (parents
|
||||||
|
// before children), but it doesn't appear to be a guarantee.
|
||||||
for _, shard := range shards {
|
for _, shard := range shards {
|
||||||
if _, ok := g.shards[*shard.ShardId]; ok {
|
if _, ok := g.shards[*shard.ShardId]; ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
g.shards[*shard.ShardId] = shard
|
g.shards[*shard.ShardId] = shard
|
||||||
|
g.shardsClosed[*shard.ShardId] = make(chan struct{})
|
||||||
|
}
|
||||||
|
for _, shard := range shards {
|
||||||
|
shard := shard // Shadow shard, since we use it in goroutine
|
||||||
|
var parent1, parent2 <-chan struct{}
|
||||||
|
if shard.ParentShardId != nil {
|
||||||
|
parent1 = g.shardsClosed[*shard.ParentShardId]
|
||||||
|
}
|
||||||
|
if shard.AdjacentParentShardId != nil {
|
||||||
|
parent2 = g.shardsClosed[*shard.AdjacentParentShardId]
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
// Asynchronously wait for all parents of this shard to be processed
|
||||||
|
// before providing it out to our client. Kinesis guarantees that a
|
||||||
|
// given partition key's data will be provided to clients in-order,
|
||||||
|
// but when splits or joins happen, we need to process all parents prior
|
||||||
|
// to processing children or that ordering guarantee is not maintained.
|
||||||
|
if waitForCloseChannel(ctx, parent1) && waitForCloseChannel(ctx, parent2) {
|
||||||
shardc <- shard
|
shardc <- shard
|
||||||
}
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
31
consumer.go
31
consumer.go
|
|
@ -120,7 +120,11 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
|
||||||
)
|
)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
c.group.Start(ctx, shardc)
|
err := c.group.Start(ctx, shardc)
|
||||||
|
if err != nil {
|
||||||
|
errc <- fmt.Errorf("error starting scan: %w", err)
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
<-ctx.Done()
|
<-ctx.Done()
|
||||||
close(shardc)
|
close(shardc)
|
||||||
}()
|
}()
|
||||||
|
|
@ -131,13 +135,19 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func(shardID string) {
|
go func(shardID string) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
if err := c.ScanShard(ctx, shardID, fn); err != nil {
|
var err error
|
||||||
|
if err = c.ScanShard(ctx, shardID, fn); err != nil {
|
||||||
|
err = fmt.Errorf("shard %s error: %w", shardID, err)
|
||||||
|
} else if closeable, ok := c.group.(CloseableGroup); !ok {
|
||||||
|
// group doesn't allow closure, skip calling CloseShard
|
||||||
|
} else if err = closeable.CloseShard(ctx, shardID); err != nil {
|
||||||
|
err = fmt.Errorf("shard closed CloseableGroup error: %w", err)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
select {
|
select {
|
||||||
case errc <- fmt.Errorf("shard %s error: %w", shardID, err):
|
case errc <- fmt.Errorf("shard %s error: %w", shardID, err):
|
||||||
// first error to occur
|
|
||||||
cancel()
|
cancel()
|
||||||
default:
|
default:
|
||||||
// error has already occurred
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}(aws.ToString(shard.ShardId))
|
}(aws.ToString(shard.ShardId))
|
||||||
|
|
@ -151,6 +161,10 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
|
||||||
return <-errc
|
return <-errc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Consumer) scanSingleShard(ctx context.Context, shardID string, fn ScanFunc) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// ScanShard loops over records on a specific shard, calls the callback func
|
// ScanShard loops over records on a specific shard, calls the callback func
|
||||||
// for each record and checkpoints the progress of scan.
|
// for each record and checkpoints the progress of scan.
|
||||||
func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error {
|
func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error {
|
||||||
|
|
@ -240,9 +254,7 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
|
||||||
c.logger.DebugContext(ctx, "shard closed", slog.String("shard-id", shardID))
|
c.logger.DebugContext(ctx, "shard closed", slog.String("shard-id", shardID))
|
||||||
|
|
||||||
if c.shardClosedHandler != nil {
|
if c.shardClosedHandler != nil {
|
||||||
err := c.shardClosedHandler(c.streamName, shardID)
|
if err := c.shardClosedHandler(c.streamName, shardID); err != nil {
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("shard closed handler error: %w", err)
|
return fmt.Errorf("shard closed handler error: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -294,7 +306,10 @@ func (c *Consumer) getShardIterator(ctx context.Context, streamName, shardID, se
|
||||||
}
|
}
|
||||||
|
|
||||||
res, err := c.client.GetShardIterator(ctx, params)
|
res, err := c.client.GetShardIterator(ctx, params)
|
||||||
return res.ShardIterator, err
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return res.ShardIterator, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func isRetriableError(err error) bool {
|
func isRetriableError(err error) bool {
|
||||||
|
|
|
||||||
282
consumer_test.go
282
consumer_test.go
|
|
@ -2,9 +2,12 @@ package consumer
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/aws/aws-sdk-go-v2/aws"
|
"github.com/aws/aws-sdk-go-v2/aws"
|
||||||
"github.com/aws/aws-sdk-go-v2/service/kinesis"
|
"github.com/aws/aws-sdk-go-v2/service/kinesis"
|
||||||
|
|
@ -24,6 +27,15 @@ var records = []types.Record{
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Implement logger to wrap testing.T.Log.
|
||||||
|
type testLogger struct {
|
||||||
|
t *testing.T
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *testLogger) Log(args ...interface{}) {
|
||||||
|
t.t.Log(args...)
|
||||||
|
}
|
||||||
|
|
||||||
func TestNew(t *testing.T) {
|
func TestNew(t *testing.T) {
|
||||||
if _, err := New("myStreamName"); err != nil {
|
if _, err := New("myStreamName"); err != nil {
|
||||||
t.Fatalf("new consumer error: %v", err)
|
t.Fatalf("new consumer error: %v", err)
|
||||||
|
|
@ -60,6 +72,7 @@ func TestScan(t *testing.T) {
|
||||||
WithClient(client),
|
WithClient(client),
|
||||||
WithCounter(ctr),
|
WithCounter(ctr),
|
||||||
WithStore(cp),
|
WithStore(cp),
|
||||||
|
WithLogger(&testLogger{t}),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("new consumer error: %v", err)
|
t.Fatalf("new consumer error: %v", err)
|
||||||
|
|
@ -98,6 +111,71 @@ func TestScan(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestScan_ListShardsError(t *testing.T) {
|
||||||
|
mockError := errors.New("mock list shards error")
|
||||||
|
client := &kinesisClientMock{
|
||||||
|
listShardsMock: func(ctx context.Context, params *kinesis.ListShardsInput, optFns ...func(*kinesis.Options)) (*kinesis.ListShardsOutput, error) {
|
||||||
|
return nil, mockError
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// use cancel func to signal shutdown
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||||
|
|
||||||
|
var res string
|
||||||
|
var fn = func(r *Record) error {
|
||||||
|
res += string(r.Data)
|
||||||
|
cancel() // simulate cancellation while processing first record
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
c, err := New("myStreamName", WithClient(client))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("new consumer error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = c.Scan(ctx, fn)
|
||||||
|
if !errors.Is(err, mockError) {
|
||||||
|
t.Errorf("expected an error from listShards, but instead got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestScan_GetShardIteratorError(t *testing.T) {
|
||||||
|
mockError := errors.New("mock get shard iterator error")
|
||||||
|
client := &kinesisClientMock{
|
||||||
|
listShardsMock: func(ctx context.Context, params *kinesis.ListShardsInput, optFns ...func(*kinesis.Options)) (*kinesis.ListShardsOutput, error) {
|
||||||
|
return &kinesis.ListShardsOutput{
|
||||||
|
Shards: []types.Shard{
|
||||||
|
{ShardId: aws.String("myShard")},
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
getShardIteratorMock: func(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) {
|
||||||
|
return nil, mockError
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// use cancel func to signal shutdown
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||||
|
|
||||||
|
var res string
|
||||||
|
var fn = func(r *Record) error {
|
||||||
|
res += string(r.Data)
|
||||||
|
cancel() // simulate cancellation while processing first record
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
c, err := New("myStreamName", WithClient(client))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("new consumer error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = c.Scan(ctx, fn)
|
||||||
|
if !errors.Is(err, mockError) {
|
||||||
|
t.Errorf("expected an error from getShardIterator, but instead got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestScanShard(t *testing.T) {
|
func TestScanShard(t *testing.T) {
|
||||||
var client = &kinesisClientMock{
|
var client = &kinesisClientMock{
|
||||||
getShardIteratorMock: func(_ context.Context, _ *kinesis.GetShardIteratorInput, _ ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) {
|
getShardIteratorMock: func(_ context.Context, _ *kinesis.GetShardIteratorInput, _ ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) {
|
||||||
|
|
@ -122,6 +200,7 @@ func TestScanShard(t *testing.T) {
|
||||||
WithClient(client),
|
WithClient(client),
|
||||||
WithCounter(ctr),
|
WithCounter(ctr),
|
||||||
WithStore(cp),
|
WithStore(cp),
|
||||||
|
WithLogger(&testLogger{t}),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("new consumer error: %v", err)
|
t.Fatalf("new consumer error: %v", err)
|
||||||
|
|
@ -301,7 +380,8 @@ func TestScanShard_ShardIsClosed_WithShardClosedHandler(t *testing.T) {
|
||||||
WithClient(client),
|
WithClient(client),
|
||||||
WithShardClosedHandler(func(_, _ string) error {
|
WithShardClosedHandler(func(_, _ string) error {
|
||||||
return fmt.Errorf("closed shard error")
|
return fmt.Errorf("closed shard error")
|
||||||
}))
|
}),
|
||||||
|
WithLogger(&testLogger{t}))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("new consumer error: %v", err)
|
t.Fatalf("new consumer error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
@ -335,7 +415,7 @@ func TestScanShard_GetRecordsError(t *testing.T) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
c, err := New("myStreamName", WithClient(client))
|
c, err := New("myStreamName", WithClient(client), WithLogger(&testLogger{t}))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("new consumer error: %v", err)
|
t.Fatalf("new consumer error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
@ -384,3 +464,201 @@ func (fc *fakeCounter) Add(_ string, count int64) {
|
||||||
|
|
||||||
fc.counter += count
|
fc.counter += count
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestScan_PreviousParentsBeforeTrimHorizon(t *testing.T) {
|
||||||
|
client := &kinesisClientMock{
|
||||||
|
getShardIteratorMock: func(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) {
|
||||||
|
return &kinesis.GetShardIteratorOutput{
|
||||||
|
ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
getRecordsMock: func(ctx context.Context, params *kinesis.GetRecordsInput, optFns ...func(*kinesis.Options)) (*kinesis.GetRecordsOutput, error) {
|
||||||
|
return &kinesis.GetRecordsOutput{
|
||||||
|
NextShardIterator: nil,
|
||||||
|
Records: records,
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
listShardsMock: func(ctx context.Context, params *kinesis.ListShardsInput, optFns ...func(*kinesis.Options)) (*kinesis.ListShardsOutput, error) {
|
||||||
|
return &kinesis.ListShardsOutput{
|
||||||
|
Shards: []types.Shard{
|
||||||
|
{
|
||||||
|
ShardId: aws.String("myShard"),
|
||||||
|
ParentShardId: aws.String("myOldParent"),
|
||||||
|
AdjacentParentShardId: aws.String("myOldAdjacentParent"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
var (
|
||||||
|
cp = store.New()
|
||||||
|
ctr = &fakeCounter{}
|
||||||
|
)
|
||||||
|
|
||||||
|
c, err := New("myStreamName",
|
||||||
|
WithClient(client),
|
||||||
|
WithCounter(ctr),
|
||||||
|
WithStore(cp),
|
||||||
|
WithLogger(&testLogger{t}),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("new consumer error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ctx, cancel = context.WithCancel(context.Background())
|
||||||
|
res string
|
||||||
|
)
|
||||||
|
|
||||||
|
var fn = func(r *Record) error {
|
||||||
|
res += string(r.Data)
|
||||||
|
|
||||||
|
if string(r.Data) == "lastData" {
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.Scan(ctx, fn); err != nil {
|
||||||
|
t.Errorf("scan returned unexpected error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if res != "firstDatalastData" {
|
||||||
|
t.Errorf("callback error expected %s, got %s", "firstDatalastData", res)
|
||||||
|
}
|
||||||
|
|
||||||
|
if val := ctr.Get(); val != 2 {
|
||||||
|
t.Errorf("counter error expected %d, got %d", 2, val)
|
||||||
|
}
|
||||||
|
|
||||||
|
val, err := cp.GetCheckpoint("myStreamName", "myShard")
|
||||||
|
if err != nil && val != "lastSeqNum" {
|
||||||
|
t.Errorf("checkout error expected %s, got %s", "lastSeqNum", val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestScan_ParentChildOrdering(t *testing.T) {
|
||||||
|
// We create a set of shards where shard1 split into (shard2,shard3), then (shard2,shard3) merged into shard4.
|
||||||
|
client := &kinesisClientMock{
|
||||||
|
getShardIteratorMock: func(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) {
|
||||||
|
return &kinesis.GetShardIteratorOutput{
|
||||||
|
ShardIterator: aws.String(*params.ShardId + "iter"),
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
getRecordsMock: func(ctx context.Context, params *kinesis.GetRecordsInput, optFns ...func(*kinesis.Options)) (*kinesis.GetRecordsOutput, error) {
|
||||||
|
switch *params.ShardIterator {
|
||||||
|
case "shard1iter":
|
||||||
|
return &kinesis.GetRecordsOutput{
|
||||||
|
NextShardIterator: nil,
|
||||||
|
Records: []types.Record{
|
||||||
|
{
|
||||||
|
Data: []byte("shard1data"),
|
||||||
|
SequenceNumber: aws.String("shard1num"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
case "shard2iter":
|
||||||
|
return &kinesis.GetRecordsOutput{
|
||||||
|
NextShardIterator: nil,
|
||||||
|
Records: []types.Record{},
|
||||||
|
}, nil
|
||||||
|
case "shard3iter":
|
||||||
|
return &kinesis.GetRecordsOutput{
|
||||||
|
NextShardIterator: nil,
|
||||||
|
Records: []types.Record{
|
||||||
|
{
|
||||||
|
Data: []byte("shard3data"),
|
||||||
|
SequenceNumber: aws.String("shard3num"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
case "shard4iter":
|
||||||
|
return &kinesis.GetRecordsOutput{
|
||||||
|
NextShardIterator: nil,
|
||||||
|
Records: []types.Record{
|
||||||
|
{
|
||||||
|
Data: []byte("shard4data"),
|
||||||
|
SequenceNumber: aws.String("shard4num"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
default:
|
||||||
|
panic("got unexpected iterator")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
listShardsMock: func(ctx context.Context, params *kinesis.ListShardsInput, optFns ...func(*kinesis.Options)) (*kinesis.ListShardsOutput, error) {
|
||||||
|
// Intentionally misorder these to test resiliance to ordering issues from ListShards.
|
||||||
|
return &kinesis.ListShardsOutput{
|
||||||
|
Shards: []types.Shard{
|
||||||
|
{
|
||||||
|
ShardId: aws.String("shard3"),
|
||||||
|
ParentShardId: aws.String("shard1"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ShardId: aws.String("shard1"),
|
||||||
|
ParentShardId: aws.String("shard0"), // not otherwise referenced, parent ordering should ignore this
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ShardId: aws.String("shard4"),
|
||||||
|
ParentShardId: aws.String("shard2"),
|
||||||
|
AdjacentParentShardId: aws.String("shard3"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ShardId: aws.String("shard2"),
|
||||||
|
ParentShardId: aws.String("shard1"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
var (
|
||||||
|
cp = store.New()
|
||||||
|
ctr = &fakeCounter{}
|
||||||
|
)
|
||||||
|
|
||||||
|
c, err := New("myStreamName",
|
||||||
|
WithClient(client),
|
||||||
|
WithCounter(ctr),
|
||||||
|
WithStore(cp),
|
||||||
|
WithLogger(&testLogger{t}),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("new consumer error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ctx, cancel = context.WithCancel(context.Background())
|
||||||
|
res string
|
||||||
|
)
|
||||||
|
|
||||||
|
rand.Seed(time.Now().UnixNano())
|
||||||
|
|
||||||
|
var fn = func(r *Record) error {
|
||||||
|
res += string(r.Data)
|
||||||
|
time.Sleep(time.Duration(rand.Int()%100) * time.Millisecond)
|
||||||
|
|
||||||
|
if string(r.Data) == "shard4data" {
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.Scan(ctx, fn); err != nil {
|
||||||
|
t.Errorf("scan returned unexpected error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if want := "shard1datashard3datashard4data"; res != want {
|
||||||
|
t.Errorf("callback error expected %s, got %s", want, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
if val := ctr.Get(); val != 3 {
|
||||||
|
t.Errorf("counter error expected %d, got %d", 2, val)
|
||||||
|
}
|
||||||
|
|
||||||
|
val, err := cp.GetCheckpoint("myStreamName", "shard4data")
|
||||||
|
if err != nil && val != "shard4num" {
|
||||||
|
t.Errorf("checkout error expected %s, got %s", "shard4num", val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
9
group.go
9
group.go
|
|
@ -8,7 +8,14 @@ import (
|
||||||
|
|
||||||
// Group interface used to manage which shard to process
|
// Group interface used to manage which shard to process
|
||||||
type Group interface {
|
type Group interface {
|
||||||
Start(ctx context.Context, shardc chan types.Shard)
|
Start(ctx context.Context, shardc chan types.Shard) error
|
||||||
GetCheckpoint(streamName, shardID string) (string, error)
|
GetCheckpoint(streamName, shardID string) (string, error)
|
||||||
SetCheckpoint(streamName, shardID, sequenceNumber string) error
|
SetCheckpoint(streamName, shardID, sequenceNumber string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type CloseableGroup interface {
|
||||||
|
Group
|
||||||
|
// Allows shard processors to tell the group when the shard has been
|
||||||
|
// fully processed. Should be called only once per shardID.
|
||||||
|
CloseShard(ctx context.Context, shardID string) error
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue