kinesis-consumer/consumer.go
gram-signal 6720a01733
Maintain parent/child shard ordering across shard splits/merges. (#155)
Kinesis allows clients to rely on an invariant: for a given partition key, the order in which records were added to the stream is preserved. That is, given the input `pkey=x,val=1  pkey=x,val=2  pkey=x,val=3`, the values `1,2,3` will be seen in that order when processed, so long as clients consume carefully. Kinesis achieves this by placing all records for a single partition key into a single shard and maintaining ordering within that shard.
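As a producer-side illustration of that invariant (not part of this PR; hypothetical stream name `my-stream`, error handling kept minimal), records sharing a partition key are written like this with the AWS SDK v2:

```
package main

import (
	"context"
	"log"

	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/config"
	"github.com/aws/aws-sdk-go-v2/service/kinesis"
)

func main() {
	ctx := context.Background()
	cfg, err := config.LoadDefaultConfig(ctx)
	if err != nil {
		log.Fatal(err)
	}
	client := kinesis.NewFromConfig(cfg)

	// Three records with the same partition key land in the same shard,
	// so a careful consumer sees them in this order: 1, 2, 3.
	for _, val := range []string{"1", "2", "3"} {
		_, err := client.PutRecord(ctx, &kinesis.PutRecordInput{
			StreamName:   aws.String("my-stream"), // hypothetical stream name
			PartitionKey: aws.String("x"),
			Data:         []byte(val),
		})
		if err != nil {
			log.Fatal(err)
		}
	}
}
```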

However, shards can be split and merged to distribute load and stay within per-shard throughput limits. Kinesis currently does this by splitting a single shard into two or by merging two adjacent shards into one, possibly many times. When this happens, Kinesis still makes ordering consistency possible by detailing shard parent/child relationships in its `listShards` output. Splitting shard A creates children B and C, both with `ParentShardId=A`. Merging shards A and B into C creates a new shard C with `ParentShardId=A,AdjacentParentShardId=B`. So long as clients fully process all records in the parents (including adjacent parents) before processing the new shard, ordering is maintained.
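A sketch of how a client could recover those relationships from `listShards` output (a hypothetical helper, not part of this library; it assumes the `context`, `aws`, and `kinesis` imports used by the file below, and note that follow-up pages are requested with `NextToken` alone):

```
// shardParents maps each shard ID to the IDs of the shards that must be fully
// processed before it (ParentShardId and, for merges, AdjacentParentShardId).
func shardParents(ctx context.Context, client *kinesis.Client, stream string) (map[string][]string, error) {
	parents := make(map[string][]string)
	input := &kinesis.ListShardsInput{StreamName: aws.String(stream)}
	for {
		out, err := client.ListShards(ctx, input)
		if err != nil {
			return nil, err
		}
		for _, s := range out.Shards {
			id := aws.ToString(s.ShardId)
			if s.ParentShardId != nil {
				parents[id] = append(parents[id], aws.ToString(s.ParentShardId))
			}
			if s.AdjacentParentShardId != nil {
				parents[id] = append(parents[id], aws.ToString(s.AdjacentParentShardId))
			}
		}
		if out.NextToken == nil {
			return parents, nil
		}
		// subsequent pages must be requested with the token only
		input = &kinesis.ListShardsInput{NextToken: out.NextToken}
	}
}
```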

`kinesis-consumer` currently doesn't do this. Instead, upon the initial (and each subsequent) `listShards` call, every visible shard immediately begins processing. Consider this case, where shards split and then merge, and each shard `X` contains a single record `rX`:

```
time ->
  B
 / \
A   D
 \ /
  C
```

Record `rD` should be processed only after both `rB` and `rC` have been processed, and both `rB` and `rC` should wait for `rA`. Because the original code starts a goroutine for every shard immediately, any ordering of `{rA,rB,rC,rD}` might occur.
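Expressed as a readiness rule (a hypothetical helper; in this PR the real bookkeeping lives in `AllGroup`, described below), a shard may only be handed to a goroutine once every parent it lists has been fully processed:

```
// canProcess reports whether shardID may start, given a child->parents map
// (e.g. built from listShards) and the set of fully processed shards.
func canProcess(shardID string, parents map[string][]string, finished map[string]bool) bool {
	for _, p := range parents[shardID] {
		if !finished[p] {
			return false
		}
	}
	return true
}
```

In the diagram above, `D` lists parents `B` and `C`, so it stays unreleased until both of those shards have been drained, which in turn cannot happen before `A` is drained.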

This PR uses the `AllGroup` as a bookkeeper of fully processed shards, with the `Consumer` calling `CloseShard` once it has finished a shard. `AllGroup` doesn't release a shard for processing until its parents have been fully processed, and the consumer simply processes the shards it receives, as it did before.

This PR adds a new `CloseableGroup` interface rather than appending to the existing `Group` interface, to maintain backwards compatibility with existing code that may already implement `Group` elsewhere. Other `Group` implementations don't get the ordering described above, but the default `Consumer` does.
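For readers of this file alone, the two interfaces presumably look roughly like the following; the method sets are inferred from the calls in `consumer.go`, and the actual definitions elsewhere in the package may differ in detail (assumes the `context` and `types` imports used below):

```
// Group is what the consumer already relied on: shard discovery plus checkpoints.
type Group interface {
	Start(ctx context.Context, shardc chan types.Shard)
	GetCheckpoint(streamName, shardID string) (string, error)
	SetCheckpoint(streamName, shardID, sequenceNumber string) error
}

// CloseableGroup additionally lets the consumer report that a shard has been
// fully processed, so AllGroup can release that shard's children.
type CloseableGroup interface {
	Group
	CloseShard(ctx context.Context, shardID string) error
}
```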
2024-06-06 08:37:42 -07:00


package consumer

import (
	"context"
	"errors"
	"fmt"
	"io/ioutil"
	"log"
	"sync"
	"time"

	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/config"
	"github.com/aws/aws-sdk-go-v2/service/kinesis"
	"github.com/aws/aws-sdk-go-v2/service/kinesis/types"

	"github.com/harlow/kinesis-consumer/internal/deaggregator"
)

// Record wraps the record returned from the Kinesis library and
// extends to include the shard id.
type Record struct {
	types.Record
	ShardID            string
	MillisBehindLatest *int64
}

// New creates a kinesis consumer with default settings. Use Option to override
// any of the optional attributes.
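//
// A minimal usage sketch (hypothetical stream name and handler, options and
// error handling elided), assuming the package is imported as "consumer":
//
//	c, _ := consumer.New("my-stream")
//	_ = c.Scan(context.TODO(), func(r *consumer.Record) error {
//		fmt.Println(string(r.Data)) // Data comes from the embedded types.Record
//		return nil                  // return ErrSkipCheckpoint to skip checkpointing
//	})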
func New(streamName string, opts ...Option) (*Consumer, error) {
	if streamName == "" {
		return nil, errors.New("must provide stream name")
	}

	// new consumer with noop storage, counter, and logger
	c := &Consumer{
		streamName:               streamName,
		initialShardIteratorType: types.ShardIteratorTypeLatest,
		store:                    &noopStore{},
		counter:                  &noopCounter{},
		logger: &noopLogger{
			logger: log.New(ioutil.Discard, "", log.LstdFlags),
		},
		scanInterval: 250 * time.Millisecond,
		maxRecords:   10000,
	}

	// override defaults
	for _, opt := range opts {
		opt(c)
	}

	// default client
	if c.client == nil {
		cfg, err := config.LoadDefaultConfig(context.TODO())
		if err != nil {
			// New returns an error, so surface config failures instead of exiting
			return nil, fmt.Errorf("unable to load SDK config: %w", err)
		}
		c.client = kinesis.NewFromConfig(cfg)
	}
	// default group consumes all shards
	if c.group == nil {
		c.group = NewAllGroup(c.client, c.store, streamName, c.logger)
	}

	return c, nil
}

// Consumer wraps the interaction with the Kinesis stream
type Consumer struct {
	streamName               string
	initialShardIteratorType types.ShardIteratorType
	initialTimestamp         *time.Time
	client                   kinesisClient
	counter                  Counter
	group                    Group
	logger                   Logger
	store                    Store
	scanInterval             time.Duration
	maxRecords               int64
	isAggregated             bool
	shardClosedHandler       ShardClosedHandler
}

// ScanFunc is the type of the function called for each message read
// from the stream. The record argument contains the original record
// returned from the AWS Kinesis library.
// If an error is returned, scanning stops. The sole exception is when the
// function returns the special value ErrSkipCheckpoint.
type ScanFunc func(*Record) error

// ErrSkipCheckpoint is used as a return value from ScanFunc to indicate that
// the current checkpoint should be skipped. It is not returned
// as an error by any function.
var ErrSkipCheckpoint = errors.New("skip checkpoint")

// Scan launches a goroutine to process each of the shards in the stream. The ScanFunc
// is passed through to each of the goroutines and called with each message pulled from
// the stream.
func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	var (
		errc   = make(chan error, 1)
		shardc = make(chan types.Shard, 1)
	)

	go func() {
		c.group.Start(ctx, shardc)
		<-ctx.Done()
		close(shardc)
	}()

	wg := new(sync.WaitGroup)
	// process each of the shards
	for shard := range shardc {
		wg.Add(1)
		go func(shardID string) {
			defer wg.Done()

			var err error
			if err = c.ScanShard(ctx, shardID, fn); err != nil {
				err = fmt.Errorf("shard %s error: %w", shardID, err)
			} else if closeable, ok := c.group.(CloseableGroup); !ok {
				// group doesn't support closing shards, skip calling CloseShard
			} else if err = closeable.CloseShard(ctx, shardID); err != nil {
				err = fmt.Errorf("close shard %s error: %w", shardID, err)
			}

			if err != nil {
				select {
				case errc <- err: // already wrapped with the shard ID above
					cancel()
				default:
				}
			}
		}(aws.ToString(shard.ShardId))
	}

	go func() {
		wg.Wait()
		close(errc)
	}()

	return <-errc
}

func (c *Consumer) scanSingleShard(ctx context.Context, shardID string, fn ScanFunc) error {
	return nil
}

// ScanShard loops over records on a specific shard, calls the callback func
// for each record and checkpoints the progress of scan.
func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error {
	// get last seq number from checkpoint
	lastSeqNum, err := c.group.GetCheckpoint(c.streamName, shardID)
	if err != nil {
		return fmt.Errorf("get checkpoint error: %w", err)
	}

	// get shard iterator
	shardIterator, err := c.getShardIterator(ctx, c.streamName, shardID, lastSeqNum)
	if err != nil {
		return fmt.Errorf("get shard iterator error: %w", err)
	}

	c.logger.Log("[CONSUMER] start scan:", shardID, lastSeqNum)
	defer func() {
		c.logger.Log("[CONSUMER] stop scan:", shardID)
	}()

	scanTicker := time.NewTicker(c.scanInterval)
	defer scanTicker.Stop()

	for {
		resp, err := c.client.GetRecords(ctx, &kinesis.GetRecordsInput{
			Limit:         aws.Int32(int32(c.maxRecords)),
			ShardIterator: shardIterator,
		})

		// attempt to recover from GetRecords error
		if err != nil {
			c.logger.Log("[CONSUMER] get records error:", err.Error())

			if !isRetriableError(err) {
				return fmt.Errorf("get records error: %v", err.Error())
			}

			shardIterator, err = c.getShardIterator(ctx, c.streamName, shardID, lastSeqNum)
			if err != nil {
				return fmt.Errorf("get shard iterator error: %w", err)
			}
		} else {
			// loop over records, call callback func
			var records []types.Record

			// deaggregate records
			if c.isAggregated {
				records, err = deaggregateRecords(resp.Records)
				if err != nil {
					return err
				}
			} else {
				records = resp.Records
			}

			for _, r := range records {
				select {
				case <-ctx.Done():
					return nil
				default:
					err := fn(&Record{r, shardID, resp.MillisBehindLatest})
					if err != nil && err != ErrSkipCheckpoint {
						return err
					}

					if err != ErrSkipCheckpoint {
						if err := c.group.SetCheckpoint(c.streamName, shardID, *r.SequenceNumber); err != nil {
							return err
						}
					}

					c.counter.Add("records", 1)
					lastSeqNum = *r.SequenceNumber
				}
			}

			if isShardClosed(resp.NextShardIterator, shardIterator) {
				c.logger.Log("[CONSUMER] shard closed:", shardID)

				if c.shardClosedHandler != nil {
					if err := c.shardClosedHandler(c.streamName, shardID); err != nil {
						return fmt.Errorf("shard closed handler error: %w", err)
					}
				}
				return nil
			}

			shardIterator = resp.NextShardIterator
		}

		// Wait for next scan
		select {
		case <-ctx.Done():
			return nil
		case <-scanTicker.C:
			continue
		}
	}
}

// temporary conversion func of []types.Record -> DeaggregateRecords([]*types.Record) -> []types.Record
func deaggregateRecords(in []types.Record) ([]types.Record, error) {
	var recs []*types.Record
	for i := range in {
		// take the address of the slice element rather than the loop variable,
		// so each entry points at a distinct record
		recs = append(recs, &in[i])
	}

	deagg, err := deaggregator.DeaggregateRecords(recs)
	if err != nil {
		return nil, err
	}

	var out []types.Record
	for _, rec := range deagg {
		out = append(out, *rec)
	}
	return out, nil
}

func (c *Consumer) getShardIterator(ctx context.Context, streamName, shardID, seqNum string) (*string, error) {
	params := &kinesis.GetShardIteratorInput{
		ShardId:    aws.String(shardID),
		StreamName: aws.String(streamName),
	}

	if seqNum != "" {
		params.ShardIteratorType = types.ShardIteratorTypeAfterSequenceNumber
		params.StartingSequenceNumber = aws.String(seqNum)
	} else if c.initialTimestamp != nil {
		params.ShardIteratorType = types.ShardIteratorTypeAtTimestamp
		params.Timestamp = c.initialTimestamp
	} else {
		params.ShardIteratorType = c.initialShardIteratorType
	}

	res, err := c.client.GetShardIterator(ctx, params)
	if err != nil {
		// avoid dereferencing a nil response on error
		return nil, err
	}
	return res.ShardIterator, nil
}

func isRetriableError(err error) bool {
	switch err.(type) {
	case *types.ExpiredIteratorException:
		return true
	case *types.ProvisionedThroughputExceededException:
		return true
	}
	return false
}

// isShardClosed reports whether Kinesis has stopped returning a next iterator
// for the shard, which means the shard has been closed by a split or merge.
func isShardClosed(nextShardIterator, currentShardIterator *string) bool {
	return nextShardIterator == nil || currentShardIterator == nextShardIterator
}