From 189f0ff4738b346d8700b68f0c7d86473cf7487f Mon Sep 17 00:00:00 2001 From: Alex Senger Date: Fri, 20 Sep 2024 11:34:58 +0200 Subject: [PATCH] #202 parallelizes batch rocessing --- collector.go | 33 ++++++++++ consumer.go | 160 ++++++++++++++++++++++++++++++++--------------- consumer_test.go | 11 ++-- worker.go | 10 +-- 4 files changed, 152 insertions(+), 62 deletions(-) diff --git a/collector.go b/collector.go index 6415de1..f626024 100644 --- a/collector.go +++ b/collector.go @@ -39,4 +39,37 @@ var ( labelStreamName, labelShardID, }) + + gaugeBatchSize = prometheus.NewGaugeVec(prometheus.GaugeOpts{ + Namespace: "net", + Subsystem: "kinesis", + Name: "get_records_result_size", + Help: "number of records received from a call to get results", + ConstLabels: nil, + }, []string{ + labelStreamName, + labelShardID, + }) + + histogramBatchDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: "net", + Subsystem: "kinesis", + Name: "records_processing_duration", + Help: "time in seconds it takes to process all of the records that were returned from a get records call", + Buckets: []float64{0.1, 0.5, 1, 3, 5, 10, 30, 60}, + }, []string{ + labelStreamName, + labelShardID, + }) + + histogramAverageRecordDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: "net", + Subsystem: "kinesis", + Name: "average_record_processing_duration", + Help: "average time in seconds it takes to process a single record in a batch", + Buckets: []float64{0.003, 0.005, 0.01, 0.025, 0.05, 0.1, 1, 3}, + }, []string{ + labelStreamName, + labelShardID, + }) ) diff --git a/consumer.go b/consumer.go index 7c6f6c9..c4c118d 100644 --- a/consumer.go +++ b/consumer.go @@ -70,6 +70,9 @@ func New(streamName string, opts ...Option) (*Consumer, error) { errs = errors.Join(errs, c.metricRegistry.Register(collectorMillisBehindLatest)) errs = errors.Join(errs, c.metricRegistry.Register(counterEventsConsumed)) errs = errors.Join(errs, c.metricRegistry.Register(counterCheckpointsWritten)) + errs = errors.Join(errs, c.metricRegistry.Register(gaugeBatchSize)) + errs = errors.Join(errs, c.metricRegistry.Register(histogramBatchDuration)) + errs = errors.Join(errs, c.metricRegistry.Register(histogramAverageRecordDuration)) if errs != nil { return nil, errs } @@ -95,7 +98,7 @@ type Consumer struct { isAggregated bool shardClosedHandler ShardClosedHandler numWorkers int - workerPool WorkerPool + workerPool *WorkerPool } // ScanFunc is the type of the function called for each message read @@ -157,9 +160,9 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error { // ScanShard loops over records on a specific shard, calls the callback func // for each record and checkpoints the progress of scan. func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error { - wp := NewWorkerPool(c.streamName, c.numWorkers, fn) - wp.Start(ctx) - defer wp.Stop() + c.workerPool = NewWorkerPool(c.streamName, c.numWorkers, fn) + c.workerPool.Start(ctx) + defer c.workerPool.Stop() // get last seq number from checkpoint lastSeqNum, err := c.group.GetCheckpoint(ctx, c.streamName, shardID) @@ -200,54 +203,9 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e return fmt.Errorf("get shard iterator error: %w", err) } } else { - // loop over records, call callback func - var records []types.Record - - // desegregate records - if c.isAggregated { - records, err = disaggregateRecords(resp.Records) - if err != nil { - return err - } - } else { - records = resp.Records - } - - for _, r := range records { - select { - case <-ctx.Done(): - return nil - default: - record := Record{r, shardID, resp.MillisBehindLatest} - wp.Submit(record) - - res := wp.Result() - var err error - if res != nil && res.Err != nil { - err = res.Err - } - - secondsBehindLatest := float64(time.Duration(*resp.MillisBehindLatest)*time.Millisecond) / float64(time.Second) - collectorMillisBehindLatest. - With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}). - Observe(secondsBehindLatest) - - if err != nil && !errors.Is(err, ErrSkipCheckpoint) { - return err - } - - if !errors.Is(err, ErrSkipCheckpoint) { - if err := c.group.SetCheckpoint(ctx, c.streamName, shardID, *r.SequenceNumber); err != nil { - return err - } - c.counter.Add("checkpoint", 1) - counterCheckpointsWritten.With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}).Inc() - } - - counterEventsConsumed.With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}).Inc() - c.counter.Add("records", 1) - lastSeqNum = *r.SequenceNumber - } + lastSeqNum, err = c.processRecords(ctx, shardID, resp) + if err != nil { + return err } if isShardClosed(resp.NextShardIterator, shardIterator) { @@ -276,6 +234,104 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e } } +func (c *Consumer) processRecords(ctx context.Context, shardID string, resp *kinesis.GetRecordsOutput) (string, error) { + if len(resp.Records) == 0 { + return "", nil + } + + startedAt := time.Now() + batchSize := float64(len(resp.Records)) + gaugeBatchSize. + With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}). + Set(batchSize) + + secondsBehindLatest := float64(time.Duration(*resp.MillisBehindLatest)*time.Millisecond) / float64(time.Second) + collectorMillisBehindLatest. + With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}). + Observe(secondsBehindLatest) + + // loop over records, call callback func + var records []types.Record + + // disaggregate records + var err error + if c.isAggregated { + records, err = disaggregateRecords(resp.Records) + if err != nil { + return "", err + } + } else { + records = resp.Records + } + + if len(records) == 0 { + // nothing to do here + return "", nil + } + + // submit in goroutine + go func() { + for _, r := range records { + select { + case <-ctx.Done(): + return + default: + record := Record{r, shardID, resp.MillisBehindLatest} + // blocks until someone is ready to pick it up + c.workerPool.Submit(record) + } + } + }() + + // wait for all tasks to be processed + numberOfProcessedTasks := 0 + timeout := 5 * time.Second + countDownTimer := time.NewTimer(timeout) + for { + if numberOfProcessedTasks == len(records) { + break + } + select { + case <-ctx.Done(): + return "", nil + case <-countDownTimer.C: + return "", fmt.Errorf("timeline exceeded while awaiting result from workers") + default: + res, err := c.workerPool.Result() + if err != nil && !errors.Is(err, ErrSkipCheckpoint) { + return "", err // TODO make it more clever once :) + } + if errors.Is(err, ErrSkipCheckpoint) || res != nil { + numberOfProcessedTasks++ + countDownTimer.Reset(timeout) + + counterEventsConsumed.With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}).Inc() + c.counter.Add("records", 1) + } + } + } + + // we MUST only reach this point if everything is processed + lastSeqNum := *records[len(records)-1].SequenceNumber + + if err := c.group.SetCheckpoint(ctx, c.streamName, shardID, lastSeqNum); err != nil { + return "", fmt.Errorf("set checkpoint error: %w", err) + } + c.counter.Add("checkpoint", int64(numberOfProcessedTasks)) + counterCheckpointsWritten. + With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}). + Add(float64(numberOfProcessedTasks)) + + duration := time.Since(startedAt).Seconds() + histogramBatchDuration. + With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}). + Observe(duration) + histogramAverageRecordDuration. + With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}). + Observe(duration / batchSize) + return lastSeqNum, nil +} + // temporary conversion func of []types.Record -> DesegregateRecords([]*types.Record) -> []types.Record func disaggregateRecords(in []types.Record) ([]types.Record, error) { var recs []types.Record diff --git a/consumer_test.go b/consumer_test.go index 81e6d48..0f76642 100644 --- a/consumer_test.go +++ b/consumer_test.go @@ -39,8 +39,9 @@ func TestScan(t *testing.T) { }, getRecordsMock: func(_ context.Context, _ *kinesis.GetRecordsInput, _ ...func(*kinesis.Options)) (*kinesis.GetRecordsOutput, error) { return &kinesis.GetRecordsOutput{ - NextShardIterator: nil, - Records: records, + NextShardIterator: nil, + Records: records, + MillisBehindLatest: aws.Int64(int64(1000)), }, nil }, listShardsMock: func(_ context.Context, _ *kinesis.ListShardsInput, _ ...func(*kinesis.Options)) (*kinesis.ListShardsOutput, error) { @@ -92,7 +93,7 @@ func TestScan(t *testing.T) { t.Errorf("counter error expected %d, got %d", 2, val) } - val, err := cp.GetCheckpoint("myStreamName", "myShard") + val, err := cp.GetCheckpoint(ctx, "myStreamName", "myShard") if err != nil && val != "lastSeqNum" { t.Errorf("checkout error expected %s, got %s", "lastSeqNum", val) } @@ -158,7 +159,7 @@ func TestScanShard(t *testing.T) { } // sets checkpoint - val, err := cp.GetCheckpoint("myStreamName", "myShard") + val, err := cp.GetCheckpoint(ctx, "myStreamName", "myShard") if err != nil && val != "lastSeqNum" { t.Fatalf("checkout error expected %s, got %s", "lastSeqNum", val) } @@ -242,7 +243,7 @@ func TestScanShard_SkipCheckpoint(t *testing.T) { t.Fatalf("scan shard error: %v", err) } - val, err := cp.GetCheckpoint("myStreamName", "myShard") + val, err := cp.GetCheckpoint(ctx, "myStreamName", "myShard") if err != nil && val != "firstSeqNum" { t.Fatalf("checkout error expected %s, got %s", "firstSeqNum", val) } diff --git a/worker.go b/worker.go index 8ce1945..f008956 100644 --- a/worker.go +++ b/worker.go @@ -10,7 +10,7 @@ import ( type Result struct { Record WorkerName string - Err error + err error } // WorkerPool allows to parallel process records @@ -55,12 +55,12 @@ func (wp *WorkerPool) Submit(r Record) { } // Result returns the Result of the Submit-ed Record after it has been processed. -func (wp *WorkerPool) Result() *Result { +func (wp *WorkerPool) Result() (Result, error) { select { case r := <-wp.resultC: - return &r + return r, r.err default: - return nil + return Result{}, nil } } @@ -91,7 +91,7 @@ func (w *worker) start(ctx context.Context) { res := Result{ Record: r, WorkerName: w.name, - Err: err, + err: err, } w.resultC <- res