kinesis-consumer/consumer.go
package consumer

import (
"context"
"errors"
"fmt"
"io"
"log/slog"
"sync"
"time"

"golang.org/x/sync/errgroup"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/kinesis"
"github.com/aws/aws-sdk-go-v2/service/kinesis/types"
"github.com/awslabs/kinesis-aggregation/go/v2/deaggregator"
"github.com/prometheus/client_golang/prometheus"
)
// Record wraps the record returned from the Kinesis library and extends it
// to include the shard ID.
type Record struct {
types.Record
ShardID string
MillisBehindLatest *int64
}
// New creates a kinesis consumer with default settings. Use Option to override
// any of the optional attributes.
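//
// A minimal usage sketch (illustrative only; "my-stream" is a placeholder and
// error handling is elided):
//
//	c, err := consumer.New("my-stream")
//	if err != nil {
//		// handle error
//	}
//	err = c.Scan(ctx, func(r *consumer.Record) error {
//		fmt.Println(string(r.Data))
//		return nil // nil lets scanning continue
//	})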
func New(streamName string, opts ...Option) (*Consumer, error) {
if streamName == "" {
return nil, errors.New("must provide stream name")
}
// new consumer with no-op storage and counter, and a logger that discards output
c := &Consumer{
streamName: streamName,
initialShardIteratorType: types.ShardIteratorTypeLatest,
store: &noopStore{},
counter: &noopCounter{},
logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
scanInterval: 250 * time.Millisecond,
maxRecords: 10000,
metricRegistry: nil,
numWorkers: 1,
}
// override defaults
for _, opt := range opts {
opt(c)
}
// default client
if c.client == nil {
cfg, err := config.LoadDefaultConfig(context.TODO())
if err != nil {
return nil, fmt.Errorf("unable to load SDK config: %w", err)
}
c.client = kinesis.NewFromConfig(cfg)
}
// default group consumes all shards
if c.group == nil {
c.group = NewAllGroup(c.client, c.store, streamName, c.logger)
}
if c.metricRegistry != nil {
var errs error
errs = errors.Join(errs, c.metricRegistry.Register(collectorMillisBehindLatest))
errs = errors.Join(errs, c.metricRegistry.Register(counterEventsConsumed))
errs = errors.Join(errs, c.metricRegistry.Register(counterCheckpointsWritten))
errs = errors.Join(errs, c.metricRegistry.Register(gaugeBatchSize))
errs = errors.Join(errs, c.metricRegistry.Register(histogramBatchDuration))
errs = errors.Join(errs, c.metricRegistry.Register(histogramAverageRecordDuration))
if errs != nil {
return nil, errs
}
}
return c, nil
}
// Consumer wraps the interaction with the Kinesis stream
type Consumer struct {
streamName string
initialShardIteratorType types.ShardIteratorType
initialTimestamp *time.Time
client kinesisClient
// Deprecated: counter will be removed in favor of Prometheus metrics in a future release.
counter Counter
group Group
logger *slog.Logger
metricRegistry prometheus.Registerer
store Store
scanInterval time.Duration
maxRecords int64
isAggregated bool
shardClosedHandler ShardClosedHandler
numWorkers int
workerPool *WorkerPool
}
// ScanFunc is the type of the function called for each message read
// from the stream. The record argument contains the original record
// returned from the AWS Kinesis library.
// If an error is returned, scanning stops. The sole exception is when the
// function returns the special value ErrSkipCheckpoint.
type ScanFunc func(*Record) error
// ErrSkipCheckpoint is used as a return value from ScanFunc to indicate that
// the current checkpoint should be skipped. It is not returned
// as an error by any function.
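//
// For example (illustrative sketch; "handle" stands in for the caller's own
// logic), a ScanFunc can skip the checkpoint for records it chooses to ignore:
//
//	fn := func(r *consumer.Record) error {
//		if len(r.Data) == 0 {
//			return consumer.ErrSkipCheckpoint // keep scanning, skip the checkpoint
//		}
//		return handle(r)
//	}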
var ErrSkipCheckpoint = errors.New("skip checkpoint")
// Scan launches a goroutine to process each of the shards in the stream. The ScanFunc
// is passed through to each of the goroutines and called with each message pulled from
// the stream.
func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
var (
errC = make(chan error, 1)
shardC = make(chan types.Shard, 1)
)
go func() {
err := c.group.Start(ctx, shardC)
if err != nil {
errC <- fmt.Errorf("error starting scan: %w", err)
cancel()
}
<-ctx.Done()
close(shardC)
}()
wg := new(sync.WaitGroup)
// process each of the shards
s := newShardsInProcess()
for shard := range shardC {
shardId := aws.ToString(shard.ShardId)
if s.doesShardExist(shardId) {
// safety net: this shard is already being processed by another goroutine; skip it
continue
}
// mark the shard as in-process before launching the goroutine so that a
// duplicate delivery cannot slip past the check above in the meantime
s.addShard(shardId)
wg.Add(1)
go func(shardID string) {
defer s.deleteShard(shardID)
defer wg.Done()
var err error
if err = c.ScanShard(ctx, shardID, fn); err != nil {
err = fmt.Errorf("shard %s error: %w", shardID, err)
} else if closeable, ok := c.group.(CloseableGroup); !ok {
// the group doesn't support closing shards, so skip calling CloseShard
} else if err = closeable.CloseShard(ctx, shardID); err != nil {
err = fmt.Errorf("shard %s close error: %w", shardID, err)
}
if err != nil {
select {
case errC <- err:
// first error wins; cancel the remaining shard scanners
cancel()
default:
// an error has already been reported
}
}
}(shardId)
}
go func() {
wg.Wait()
close(errC)
}()
return <-errC
}
// ScanShard loops over records on a specific shard, calls the callback func
// for each record, and checkpoints the progress of the scan.
func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error {
c.workerPool = NewWorkerPool(c.streamName, c.numWorkers, fn)
c.workerPool.Start(ctx)
defer c.workerPool.Stop()
// get last seq number from checkpoint
lastSeqNum, err := c.group.GetCheckpoint(ctx, c.streamName, shardID)
if err != nil {
return fmt.Errorf("get checkpoint error: %w", err)
}
// get shard iterator
shardIterator, err := c.getShardIterator(ctx, c.streamName, shardID, lastSeqNum)
if err != nil {
return fmt.Errorf("get shard iterator error: %w", err)
}
c.logger.DebugContext(ctx, "start scan", slog.String("shard-id", shardID), slog.String("last-sequence-number", lastSeqNum))
defer func() {
c.logger.DebugContext(ctx, "stop scan", slog.String("shard-id", shardID))
}()
scanTicker := time.NewTicker(c.scanInterval)
defer scanTicker.Stop()
for {
resp, err := c.client.GetRecords(ctx, &kinesis.GetRecordsInput{
Limit: aws.Int32(int32(c.maxRecords)),
ShardIterator: shardIterator,
})
// attempt to recover from GetRecords error
if err != nil {
if !isRetriableError(err) {
return fmt.Errorf("get records error: %v", err.Error())
}
c.logger.WarnContext(ctx, "get records", slog.String("error", err.Error()))
shardIterator, err = c.getShardIterator(ctx, c.streamName, shardID, lastSeqNum)
if err != nil {
return fmt.Errorf("get shard iterator error: %w", err)
}
} else {
seqNum, err := c.processRecords(ctx, shardID, resp, fn)
if err != nil {
return err
}
// only advance the remembered sequence number when the batch contained records
if seqNum != "" {
lastSeqNum = seqNum
}
if isShardClosed(resp.NextShardIterator, shardIterator) {
c.logger.DebugContext(ctx, "shard closed", slog.String("shard-id", shardID))
if c.shardClosedHandler != nil {
err := c.shardClosedHandler(c.streamName, shardID)
if err != nil {
return fmt.Errorf("shard closed handler error: %w", err)
}
}
return nil
}
shardIterator = resp.NextShardIterator
}
// Wait for next scan
select {
case <-ctx.Done():
return nil
case <-scanTicker.C:
continue
}
}
}
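// processRecords records batch metrics, optionally disaggregates KPL-aggregated
// records, runs the callback over every record in the batch, and then checkpoints
// the last sequence number. It returns the checkpointed sequence number, or an
// empty string when the batch contained no records.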
func (c *Consumer) processRecords(ctx context.Context, shardID string, resp *kinesis.GetRecordsOutput, fn ScanFunc) (string, error) {
if len(resp.Records) == 0 {
return "", nil
}
startedAt := time.Now()
batchSize := float64(len(resp.Records))
gaugeBatchSize.
With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}).
Set(batchSize)
secondsBehindLatest := float64(time.Duration(*resp.MillisBehindLatest)*time.Millisecond) / float64(time.Second)
collectorMillisBehindLatest.
With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}).
Observe(secondsBehindLatest)
// loop over records, call callback func
var records []types.Record
// disaggregate records
var err error
if c.isAggregated {
records, err = disaggregateRecords(resp.Records)
if err != nil {
return "", err
}
} else {
records = resp.Records
}
if len(records) == 0 {
// nothing to do here
return "", nil
}
err = c.runWorkers(ctx, shardID, resp, fn, records)
if err != nil {
return "", err
}
// we MUST only reach this point if everything is processed
lastSeqNum := *records[len(records)-1].SequenceNumber
if err := c.group.SetCheckpoint(ctx, c.streamName, shardID, lastSeqNum); err != nil {
return "", fmt.Errorf("set checkpoint error: %w", err)
}
numberOfProcessedTasks := len(records)
c.counter.Add("checkpoint", int64(numberOfProcessedTasks))
counterCheckpointsWritten.
With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}).
Add(float64(numberOfProcessedTasks))
duration := time.Since(startedAt).Seconds()
histogramBatchDuration.
With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}).
Observe(duration)
histogramAverageRecordDuration.
With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}).
Observe(duration / batchSize)
return lastSeqNum, nil
}
// runWorkers processes the records concurrently, bounded by numWorkers, using an errgroup.
func (c *Consumer) runWorkers(ctx context.Context, shardID string, resp *kinesis.GetRecordsOutput, fn ScanFunc, records []types.Record) error {
errGroup, ctx := errgroup.WithContext(ctx)
errGroup.SetLimit(c.numWorkers)
for _, r := range records {
r := r // capture the loop variable for the goroutine (required before Go 1.22)
errGroup.Go(func() error {
err := fn(&Record{Record: r, ShardID: shardID, MillisBehindLatest: resp.MillisBehindLatest})
if errors.Is(err, ErrSkipCheckpoint) {
// the callback opted out of checkpointing this record; not a failure
return nil
}
return err
})
}
return errGroup.Wait()
}
// disaggregateRecords is a temporary conversion helper: it copies the input records,
// runs them through DeaggregateRecords, and returns the expanded user records.
func disaggregateRecords(in []types.Record) ([]types.Record, error) {
var recs []types.Record
recs = append(recs, in...)
deagg, err := deaggregator.DeaggregateRecords(recs)
if err != nil {
return nil, err
}
var out []types.Record
out = append(out, deagg...)
return out, nil
}
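// getShardIterator resolves where to start reading the shard: after the
// checkpointed sequence number when one exists, at the configured timestamp
// when one was provided, or at the configured initial shard iterator type.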
func (c *Consumer) getShardIterator(ctx context.Context, streamName, shardID, seqNum string) (*string, error) {
params := &kinesis.GetShardIteratorInput{
ShardId: aws.String(shardID),
StreamName: aws.String(streamName),
}
if seqNum != "" {
params.ShardIteratorType = types.ShardIteratorTypeAfterSequenceNumber
params.StartingSequenceNumber = aws.String(seqNum)
} else if c.initialTimestamp != nil {
params.ShardIteratorType = types.ShardIteratorTypeAtTimestamp
params.Timestamp = c.initialTimestamp
} else {
params.ShardIteratorType = c.initialShardIteratorType
}
res, err := c.client.GetShardIterator(ctx, params)
if err != nil {
return nil, err
}
return res.ShardIterator, nil
}
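// isRetriableError reports whether a GetRecords failure can be recovered from
// by requesting a fresh shard iterator (expired iterator or provisioned
// throughput exceeded).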
func isRetriableError(err error) bool {
if oe := (*types.ExpiredIteratorException)(nil); errors.As(err, &oe) {
return true
}
if oe := (*types.ProvisionedThroughputExceededException)(nil); errors.As(err, &oe) {
return true
}
return false
}
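// isShardClosed reports whether the shard has been fully consumed; Kinesis
// stops returning a next shard iterator once a closed shard has been read to
// the end.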
func isShardClosed(nextShardIterator, currentShardIterator *string) bool {
return nextShardIterator == nil || currentShardIterator == nextShardIterator
}
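// shards tracks which shard IDs are currently being scanned, guarded by a
// mutex so the goroutines launched by Scan can safely check and update the set.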
type shards struct {
*sync.RWMutex
shardsInProcess map[string]struct{}
}
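// newShardsInProcess returns an empty, ready-to-use shard tracker.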
func newShardsInProcess() *shards {
return &shards{
RWMutex: &sync.RWMutex{},
shardsInProcess: make(map[string]struct{}),
}
}
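// addShard marks a shard as currently being processed.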
func (s *shards) addShard(shardId string) {
s.Lock()
defer s.Unlock()
s.shardsInProcess[shardId] = struct{}{}
}
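// doesShardExist reports whether a shard is already being processed.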
func (s *shards) doesShardExist(shardId string) bool {
s.RLock()
defer s.RUnlock()
_, ok := s.shardsInProcess[shardId]
return ok
}
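// deleteShard marks a shard as no longer being processed.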
func (s *shards) deleteShard(shardId string) {
s.Lock()
defer s.Unlock()
delete(s.shardsInProcess, shardId)
}