#202 adds parallel processing option

This commit is contained in:
Alex Senger 2024-09-19 12:11:59 +02:00
parent 56940a5e07
commit dc16271ed1
No known key found for this signature in database
GPG key ID: 0B4A96F8AF6934CF
4 changed files with 131 additions and 9 deletions

View file

@ -43,6 +43,7 @@ func New(streamName string, opts ...Option) (*Consumer, error) {
scanInterval: 250 * time.Millisecond, scanInterval: 250 * time.Millisecond,
maxRecords: 10000, maxRecords: 10000,
metricRegistry: nil, metricRegistry: nil,
numWorkers: 1,
} }
// override defaults // override defaults
@ -93,6 +94,8 @@ type Consumer struct {
maxRecords int64 maxRecords int64
isAggregated bool isAggregated bool
shardClosedHandler ShardClosedHandler shardClosedHandler ShardClosedHandler
numWorkers int
workerPool WorkerPool
} }
// ScanFunc is the type of the function called for each message read // ScanFunc is the type of the function called for each message read
@ -115,25 +118,25 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
defer cancel() defer cancel()
var ( var (
errc = make(chan error, 1) errC = make(chan error, 1)
shardc = make(chan types.Shard, 1) shardC = make(chan types.Shard, 1)
) )
go func() { go func() {
c.group.Start(ctx, shardc) c.group.Start(ctx, shardC)
<-ctx.Done() <-ctx.Done()
close(shardc) close(shardC)
}() }()
wg := new(sync.WaitGroup) wg := new(sync.WaitGroup)
// process each of the shards // process each of the shards
for shard := range shardc { for shard := range shardC {
wg.Add(1) wg.Add(1)
go func(shardID string) { go func(shardID string) {
defer wg.Done() defer wg.Done()
if err := c.ScanShard(ctx, shardID, fn); err != nil { if err := c.ScanShard(ctx, shardID, fn); err != nil {
select { select {
case errc <- fmt.Errorf("shard %s error: %w", shardID, err): case errC <- fmt.Errorf("shard %s error: %w", shardID, err):
// first error to occur // first error to occur
cancel() cancel()
default: default:
@ -145,15 +148,19 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
go func() { go func() {
wg.Wait() wg.Wait()
close(errc) close(errC)
}() }()
return <-errc return <-errC
} }
// ScanShard loops over records on a specific shard, calls the callback func // ScanShard loops over records on a specific shard, calls the callback func
// for each record and checkpoints the progress of scan. // for each record and checkpoints the progress of scan.
func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error { func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error {
wp := NewWorkerPool(c.streamName, c.numWorkers, fn)
wp.Start(ctx)
defer wp.Stop()
// get last seq number from checkpoint // get last seq number from checkpoint
lastSeqNum, err := c.group.GetCheckpoint(ctx, c.streamName, shardID) lastSeqNum, err := c.group.GetCheckpoint(ctx, c.streamName, shardID)
if err != nil { if err != nil {
@ -211,7 +218,14 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
case <-ctx.Done(): case <-ctx.Done():
return nil return nil
default: default:
err := fn(&Record{r, shardID, resp.MillisBehindLatest}) record := Record{r, shardID, resp.MillisBehindLatest}
wp.Submit(record)
res := wp.Result()
var err error
if res != nil && res.Err != nil {
err = res.Err
}
secondsBehindLatest := float64(time.Duration(*resp.MillisBehindLatest)*time.Millisecond) / float64(time.Second) secondsBehindLatest := float64(time.Duration(*resp.MillisBehindLatest)*time.Millisecond) / float64(time.Second)
collectorMillisBehindLatest. collectorMillisBehindLatest.

View file

@ -46,6 +46,7 @@ func main() {
consumer.WithClient(client), consumer.WithClient(client),
consumer.WithStore(checkpointStore), consumer.WithStore(checkpointStore),
consumer.WithLogger(slog.Default()), consumer.WithLogger(slog.Default()),
consumer.WithParallelProcessing(2),
) )
if err != nil { if err != nil {
slog.Error("consumer error", slog.String("error", err.Error())) slog.Error("consumer error", slog.String("error", err.Error()))

View file

@ -91,6 +91,13 @@ func WithAggregation(a bool) Option {
} }
} }
// WithParallelProcessing sets the number of workers in the worker pool that
// processes incoming records. Defaults to 1.
//
// Values below 1 are treated as 1: a pool without any worker would accept
// records but never process them.
func WithParallelProcessing(numWorkers int) Option {
	// Normalize once, outside the closure, so applying the option stays trivial.
	if numWorkers < 1 {
		numWorkers = 1
	}
	return func(c *Consumer) {
		c.numWorkers = numWorkers
	}
}
// WithShardClosedHandler defines a custom handler for closed shards. // WithShardClosedHandler defines a custom handler for closed shards.
func WithShardClosedHandler(h ShardClosedHandler) Option { func WithShardClosedHandler(h ShardClosedHandler) Option {
return func(c *Consumer) { return func(c *Consumer) {

100
worker.go
View file

@ -1 +1,101 @@
package consumer package consumer
import (
"context"
"fmt"
)
// Result is what a worker emits after processing a Record. It embeds the
// processed Record (mainly to preserve the offset the record carries), names
// the worker that handled it, and holds the processing error so that it can be
// propagated up.
type Result struct {
	// Record is the record that was processed.
	Record
	// WorkerName is the name of the worker that processed the record.
	WorkerName string
	// Err is the error returned while processing the record, if any.
	Err error
}
// WorkerPool distributes records over a fixed number of workers so that they
// can be processed in parallel.
type WorkerPool struct {
	// name identifies the pool and prefixes the name of every worker.
	name string
	// numWorkers is the number of workers spawned by Start.
	numWorkers int
	// fn is the callback handed to every worker.
	fn ScanFunc
	// recordC carries submitted records to the workers.
	recordC chan Record
	// resultC carries processing results back from the workers.
	resultC chan Result
}
// NewWorkerPool returns a WorkerPool named "wp-<name>" that processes records
// with fn on numWorkers workers. No worker runs until Start is called.
//
// A numWorkers below 1 is raised to 1: Start spawns exactly numWorkers
// workers, and with none of them running Submit would block forever once the
// record channel's buffer is full.
func NewWorkerPool(name string, numWorkers int, fn ScanFunc) *WorkerPool {
	if numWorkers < 1 {
		numWorkers = 1
	}
	return &WorkerPool{
		name:       fmt.Sprintf("wp-%s", name),
		numWorkers: numWorkers,
		fn:         fn,
		recordC:    make(chan Record, 1),
		resultC:    make(chan Result, 1),
	}
}
// Start creates and starts numWorkers workers. Each worker is named
// "<pool name>-worker-<index>" and shares the pool's callback, record channel
// and result channel; ctx is passed on to every worker.
func (wp *WorkerPool) Start(ctx context.Context) {
	// TODO: decide how a worker should be respawned if it fails.
	for idx := 0; idx < wp.numWorkers; idx++ {
		workerName := fmt.Sprintf("%s-worker-%d", wp.name, idx)
		newWorker(workerName, wp.fn, wp.recordC, wp.resultC).start(ctx)
	}
}
// Stop shuts the pool down by closing the record channel, which ends each
// worker's receive loop once the records still buffered have been taken.
//
// The result channel is deliberately left open: the workers are its senders,
// and one of them may still be delivering a result when Stop runs. Closing the
// channel here would make that send panic. Result never blocks, so nothing
// waits for the result channel to be closed.
//
// Stop must be called at most once, and Submit must not be called afterwards;
// both would operate on the already closed record channel and panic.
func (wp *WorkerPool) Stop() {
	close(wp.recordC)
}
// Submit hands r to the pool for processing. The call blocks while the record
// channel has no free buffer slot and no worker is ready to receive.
func (wp *WorkerPool) Submit(r Record) {
	queue := wp.recordC
	queue <- r
}
// Result returns the next available processing result without blocking.
//
// It returns nil when no result is ready yet, and also when the result
// channel has been closed: a receive from a closed channel yields the zero
// Result, which does not describe any processed record and must not be
// reported as one.
//
// The returned Result belongs to whichever record a worker finished first; it
// is not necessarily the record passed to the most recent Submit call.
func (wp *WorkerPool) Result() *Result {
	select {
	case res, ok := <-wp.resultC:
		if !ok {
			return nil
		}
		return &res
	default:
		return nil
	}
}
// worker processes records from recordC with fn and reports every outcome on
// resultC.
type worker struct {
	// name is reported as WorkerName in each Result the worker produces.
	name string
	// fn is the callback invoked for every received record.
	fn ScanFunc
	// recordC is the channel the worker receives records from.
	recordC chan Record
	// resultC is the channel the worker sends results to.
	resultC chan Result
}
// newWorker builds a worker called name that will run fn on records received
// from recordC and send the outcomes to resultC. The worker does nothing until
// its start method is called.
func newWorker(name string, fn ScanFunc, recordC chan Record, resultC chan Result) *worker {
	w := new(worker)
	w.name, w.fn = name, fn
	w.recordC, w.resultC = recordC, resultC
	return w
}
// start runs the worker loop in a new goroutine. For every record received
// from recordC the worker calls fn and sends a Result holding the record, the
// worker's name and fn's error to resultC.
//
// The goroutine ends when recordC is closed or when ctx is cancelled.
// Cancellation is observed in both blocking spots — while waiting for the next
// record and while delivering a result — so the goroutine does not stay behind
// when nobody submits records or reads results any more. A record received
// after ctx has been cancelled is dropped without calling fn.
func (w *worker) start(ctx context.Context) {
	go func() {
		for {
			select {
			case <-ctx.Done():
				return
			case r, ok := <-w.recordC:
				if !ok {
					// recordC was closed: no more work will arrive.
					return
				}
				// select picks at random when both cases are ready, so
				// re-check to make sure cancellation takes priority.
				if ctx.Err() != nil {
					return
				}
				res := Result{
					Record:     r,
					WorkerName: w.name,
					Err:        w.fn(&r),
				}
				select {
				case w.resultC <- res:
				case <-ctx.Done():
					return
				}
			}
		}
	}()
}