diff --git a/consumer.go b/consumer.go index bff3fb8..041da9c 100644 --- a/consumer.go +++ b/consumer.go @@ -43,6 +43,7 @@ func New(streamName string, opts ...Option) (*Consumer, error) { scanInterval: 250 * time.Millisecond, maxRecords: 10000, metricRegistry: nil, + numWorkers: 1, } // override defaults @@ -93,6 +94,8 @@ type Consumer struct { maxRecords int64 isAggregated bool shardClosedHandler ShardClosedHandler + numWorkers int + workerPool WorkerPool } // ScanFunc is the type of the function called for each message read @@ -115,25 +118,25 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error { defer cancel() var ( - errc = make(chan error, 1) - shardc = make(chan types.Shard, 1) + errC = make(chan error, 1) + shardC = make(chan types.Shard, 1) ) go func() { - c.group.Start(ctx, shardc) + c.group.Start(ctx, shardC) <-ctx.Done() - close(shardc) + close(shardC) }() wg := new(sync.WaitGroup) // process each of the shards - for shard := range shardc { + for shard := range shardC { wg.Add(1) go func(shardID string) { defer wg.Done() if err := c.ScanShard(ctx, shardID, fn); err != nil { select { - case errc <- fmt.Errorf("shard %s error: %w", shardID, err): + case errC <- fmt.Errorf("shard %s error: %w", shardID, err): // first error to occur cancel() default: @@ -145,15 +148,19 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error { go func() { wg.Wait() - close(errc) + close(errC) }() - return <-errc + return <-errC } // ScanShard loops over records on a specific shard, calls the callback func // for each record and checkpoints the progress of scan. 
func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error { + wp := NewWorkerPool(c.streamName, c.numWorkers, fn) + wp.Start(ctx) + defer wp.Stop() + // get last seq number from checkpoint lastSeqNum, err := c.group.GetCheckpoint(ctx, c.streamName, shardID) if err != nil { @@ -211,7 +218,14 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e case <-ctx.Done(): return nil default: - err := fn(&Record{r, shardID, resp.MillisBehindLatest}) + record := Record{r, shardID, resp.MillisBehindLatest} + wp.Submit(record) + + res := wp.Result() + var err error + if res != nil && res.Err != nil { + err = res.Err + } secondsBehindLatest := float64(time.Duration(*resp.MillisBehindLatest)*time.Millisecond) / float64(time.Second) collectorMillisBehindLatest. diff --git a/examples/consumer-redis/main.go b/examples/consumer-redis/main.go index 9eeb5d0..be48002 100644 --- a/examples/consumer-redis/main.go +++ b/examples/consumer-redis/main.go @@ -46,6 +46,7 @@ func main() { consumer.WithClient(client), consumer.WithStore(checkpointStore), consumer.WithLogger(slog.Default()), + consumer.WithParallelProcessing(2), ) if err != nil { slog.Error("consumer error", slog.String("error", err.Error()))
diff --git a/options.go b/options.go
index 444d039..d298864 100644
--- a/options.go
+++ b/options.go
@@ -91,6 +91,17 @@ func WithAggregation(a bool) Option {
 	}
 }
 
+// WithParallelProcessing sets the number of workers that process the records
+// of a shard. Values below 1 are ignored, so the default of 1 stays in effect.
+func WithParallelProcessing(numWorkers int) Option {
+	return func(c *Consumer) {
+		if numWorkers < 1 {
+			return
+		}
+		c.numWorkers = numWorkers
+	}
+}
+
 // WithShardClosedHandler defines a custom handler for closed shards.
func WithShardClosedHandler(h ShardClosedHandler) Option { return func(c *Consumer) { diff --git a/worker.go b/worker.go index b78b46c..8ce1945 100644 --- a/worker.go +++ b/worker.go @@ -1 +1,101 @@ package consumer + +import ( + "context" + "fmt" +) + +// Result is the output of the worker. It contains the ID of the worker that processed it, the record itself (mainly to +// maintain the offset that the record has and the error of processing to propagate up. +type Result struct { + Record + WorkerName string + Err error +} + +// WorkerPool allows to parallel process records +type WorkerPool struct { + name string + numWorkers int + fn ScanFunc + recordC chan Record + resultC chan Result +} + +// NewWorkerPool returns an instance of WorkerPool +func NewWorkerPool(name string, numWorkers int, fn ScanFunc) *WorkerPool { + return &WorkerPool{ + name: fmt.Sprintf("wp-%s", name), + numWorkers: numWorkers, + fn: fn, + recordC: make(chan Record, 1), + resultC: make(chan Result, 1), + } +} + +// Start spawns the amount of workers specified in numWorkers and starts them. +func (wp *WorkerPool) Start(ctx context.Context) { + // How do I reopen workers if one fails? + for i := range wp.numWorkers { + name := fmt.Sprintf("%s-worker-%d", wp.name, i) + w := newWorker(name, wp.fn, wp.recordC, wp.resultC) + w.start(ctx) + } +} + +// Stop stops the WorkerPool by closing the channels used for processing. +func (wp *WorkerPool) Stop() { + close(wp.recordC) + close(wp.resultC) +} + +// Submit a new Record for processing +func (wp *WorkerPool) Submit(r Record) { + wp.recordC <- r +} + +// Result returns the Result of the Submit-ed Record after it has been processed. 
+func (wp *WorkerPool) Result() *Result { + select { + case r := <-wp.resultC: + return &r + default: + return nil + } +} + +type worker struct { + name string + fn ScanFunc + recordC chan Record + resultC chan Result +} + +func newWorker(name string, fn ScanFunc, recordC chan Record, resultC chan Result) *worker { + return &worker{ + name: name, + fn: fn, + recordC: recordC, + resultC: resultC, + } +} + +func (w *worker) start(ctx context.Context) { + go func(ctx context.Context) { + for r := range w.recordC { + select { + case <-ctx.Done(): + return + default: + err := w.fn(&r) + res := Result{ + Record: r, + WorkerName: w.name, + Err: err, + } + + w.resultC <- res + } + } + }(ctx) +}