#202 parallelizes batch processing

Alex Senger 2024-09-20 11:34:58 +02:00
parent 138b7de381
commit 189f0ff473
4 changed files with 152 additions and 62 deletions
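
Record handling in ScanShard moves into a worker pool: each GetRecords batch is handed to a new processRecords helper, which submits the records to the pool from a goroutine, drains the results under a five-second inactivity timeout, and checkpoints once per batch instead of once per record. Three Prometheus collectors are added to track batch size, batch processing duration, and average per-record processing duration.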

View file

@@ -39,4 +39,37 @@ var (
 		labelStreamName,
 		labelShardID,
 	})
+
+	gaugeBatchSize = prometheus.NewGaugeVec(prometheus.GaugeOpts{
+		Namespace:   "net",
+		Subsystem:   "kinesis",
+		Name:        "get_records_result_size",
+		Help:        "number of records received from a GetRecords call",
+		ConstLabels: nil,
+	}, []string{
+		labelStreamName,
+		labelShardID,
+	})
+
+	histogramBatchDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{
+		Namespace: "net",
+		Subsystem: "kinesis",
+		Name:      "records_processing_duration",
+		Help:      "time in seconds it takes to process all of the records returned by a GetRecords call",
+		Buckets:   []float64{0.1, 0.5, 1, 3, 5, 10, 30, 60},
+	}, []string{
+		labelStreamName,
+		labelShardID,
+	})
+
+	histogramAverageRecordDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{
+		Namespace: "net",
+		Subsystem: "kinesis",
+		Name:      "average_record_processing_duration",
+		Help:      "average time in seconds it takes to process a single record in a batch",
+		Buckets:   []float64{0.003, 0.005, 0.01, 0.025, 0.05, 0.1, 1, 3},
+	}, []string{
+		labelStreamName,
+		labelShardID,
+	})
 )
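
The three new collectors join the existing ones on the consumer's metric registry. How that registry is exposed lies outside this commit; a minimal sketch, assuming the application owns a *prometheus.Registry and hands it to the consumer (the WithMetricRegistry option is hypothetical, not shown in this diff):

package main

import (
	"net/http"

	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/promhttp"
)

func main() {
	// The registry the consumer would register its collectors on.
	reg := prometheus.NewRegistry()

	// Hypothetical wiring, not shown in this diff:
	// c, err := consumer.New("myStreamName", consumer.WithMetricRegistry(reg))

	// Expose everything registered on reg, including the new batch metrics.
	http.Handle("/metrics", promhttp.HandlerFor(reg, promhttp.HandlerOpts{}))
	_ = http.ListenAndServe(":9102", nil)
}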

View file

@@ -70,6 +70,9 @@ func New(streamName string, opts ...Option) (*Consumer, error) {
 	errs = errors.Join(errs, c.metricRegistry.Register(collectorMillisBehindLatest))
 	errs = errors.Join(errs, c.metricRegistry.Register(counterEventsConsumed))
 	errs = errors.Join(errs, c.metricRegistry.Register(counterCheckpointsWritten))
+	errs = errors.Join(errs, c.metricRegistry.Register(gaugeBatchSize))
+	errs = errors.Join(errs, c.metricRegistry.Register(histogramBatchDuration))
+	errs = errors.Join(errs, c.metricRegistry.Register(histogramAverageRecordDuration))
 	if errs != nil {
 		return nil, errs
 	}
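
Since the collectors are package-level variables registered inside New, a second Consumer sharing the same registry would get prometheus.AlreadyRegisteredError from these Register calls. A sketch of a helper a caller (or a later revision) could use to treat that case as benign; register is hypothetical and assumes the errors and prometheus imports already present in this file:

// register ignores duplicate registration, which is expected when more
// than one Consumer shares a registry; any other error is reported.
func register(reg prometheus.Registerer, c prometheus.Collector) error {
	if err := reg.Register(c); err != nil {
		var are prometheus.AlreadyRegisteredError
		if errors.As(err, &are) {
			return nil // already registered by an earlier Consumer
		}
		return err
	}
	return nil
}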
@@ -95,7 +98,7 @@ type Consumer struct {
 	isAggregated       bool
 	shardClosedHandler ShardClosedHandler
 	numWorkers         int
-	workerPool         WorkerPool
+	workerPool         *WorkerPool
 }

 // ScanFunc is the type of the function called for each message read
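
Promoting the pool from a ScanShard local to a *WorkerPool field makes it reachable from the new processRecords method below.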
@@ -157,9 +160,9 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
 // ScanShard loops over records on a specific shard, calls the callback func
 // for each record and checkpoints the progress of scan.
 func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error {
-	wp := NewWorkerPool(c.streamName, c.numWorkers, fn)
-	wp.Start(ctx)
-	defer wp.Stop()
+	c.workerPool = NewWorkerPool(c.streamName, c.numWorkers, fn)
+	c.workerPool.Start(ctx)
+	defer c.workerPool.Stop()

 	// get last seq number from checkpoint
 	lastSeqNum, err := c.group.GetCheckpoint(ctx, c.streamName, shardID)
@@ -200,55 +203,10 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error {
 				return fmt.Errorf("get shard iterator error: %w", err)
 			}
 		} else {
-			// loop over records, call callback func
-			var records []types.Record
-
-			// desegregate records
-			if c.isAggregated {
-				records, err = disaggregateRecords(resp.Records)
-				if err != nil {
-					return err
-				}
-			} else {
-				records = resp.Records
-			}
-
-			for _, r := range records {
-				select {
-				case <-ctx.Done():
-					return nil
-				default:
-					record := Record{r, shardID, resp.MillisBehindLatest}
-					wp.Submit(record)
-					res := wp.Result()
-
-					var err error
-					if res != nil && res.Err != nil {
-						err = res.Err
-					}
-
-					secondsBehindLatest := float64(time.Duration(*resp.MillisBehindLatest)*time.Millisecond) / float64(time.Second)
-					collectorMillisBehindLatest.
-						With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}).
-						Observe(secondsBehindLatest)
-
-					if err != nil && !errors.Is(err, ErrSkipCheckpoint) {
-						return err
-					}
-
-					if !errors.Is(err, ErrSkipCheckpoint) {
-						if err := c.group.SetCheckpoint(ctx, c.streamName, shardID, *r.SequenceNumber); err != nil {
-							return err
-						}
-						c.counter.Add("checkpoint", 1)
-						counterCheckpointsWritten.With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}).Inc()
-					}
-
-					counterEventsConsumed.With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}).Inc()
-					c.counter.Add("records", 1)
-					lastSeqNum = *r.SequenceNumber
-				}
-			}
+			lastSeqNum, err = c.processRecords(ctx, shardID, resp)
+			if err != nil {
+				return err
+			}
 		}

 		if isShardClosed(resp.NextShardIterator, shardIterator) {
 			c.logger.DebugContext(ctx, "shard closed", slog.String("shard-id", shardID))
@@ -276,6 +234,104 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error {
 	}
 }
+
+func (c *Consumer) processRecords(ctx context.Context, shardID string, resp *kinesis.GetRecordsOutput) (string, error) {
+	if len(resp.Records) == 0 {
+		return "", nil
+	}
+
+	startedAt := time.Now()
+
+	batchSize := float64(len(resp.Records))
+	gaugeBatchSize.
+		With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}).
+		Set(batchSize)
+
+	secondsBehindLatest := float64(time.Duration(*resp.MillisBehindLatest)*time.Millisecond) / float64(time.Second)
+	collectorMillisBehindLatest.
+		With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}).
+		Observe(secondsBehindLatest)
+
+	// loop over records, call callback func
+	var records []types.Record
+
+	// disaggregate records
+	var err error
+	if c.isAggregated {
+		records, err = disaggregateRecords(resp.Records)
+		if err != nil {
+			return "", err
+		}
+	} else {
+		records = resp.Records
+	}
+
+	if len(records) == 0 {
+		// nothing to do here
+		return "", nil
+	}
+
+	// submit in a goroutine
+	go func() {
+		for _, r := range records {
+			select {
+			case <-ctx.Done():
+				return
+			default:
+				record := Record{r, shardID, resp.MillisBehindLatest}
+				// blocks until a worker is ready to pick it up
+				c.workerPool.Submit(record)
+			}
+		}
+	}()
+
+	// wait for all tasks to be processed
+	numberOfProcessedTasks := 0
+	timeout := 5 * time.Second
+	countDownTimer := time.NewTimer(timeout)
+	for {
+		if numberOfProcessedTasks == len(records) {
+			break
+		}
+
+		select {
+		case <-ctx.Done():
+			return "", nil
+		case <-countDownTimer.C:
+			return "", fmt.Errorf("timeout exceeded while awaiting result from workers")
+		default:
+			res, err := c.workerPool.Result()
+			if err != nil && !errors.Is(err, ErrSkipCheckpoint) {
+				return "", err // TODO: make this smarter at some point :)
+			}
+			// a delivered result always carries its worker's name
+			if errors.Is(err, ErrSkipCheckpoint) || res.WorkerName != "" {
+				numberOfProcessedTasks++
+				countDownTimer.Reset(timeout)
+				counterEventsConsumed.With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}).Inc()
+				c.counter.Add("records", 1)
+			}
+		}
+	}
+
+	// we must only reach this point once everything has been processed
+	lastSeqNum := *records[len(records)-1].SequenceNumber
+	if err := c.group.SetCheckpoint(ctx, c.streamName, shardID, lastSeqNum); err != nil {
+		return "", fmt.Errorf("set checkpoint error: %w", err)
+	}
+
+	c.counter.Add("checkpoint", int64(numberOfProcessedTasks))
+	counterCheckpointsWritten.
+		With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}).
+		Add(float64(numberOfProcessedTasks))
+
+	duration := time.Since(startedAt).Seconds()
+	histogramBatchDuration.
+		With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}).
+		Observe(duration)
+	histogramAverageRecordDuration.
+		With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}).
+		Observe(duration / batchSize)
+
+	return lastSeqNum, nil
+}

 // temporary conversion func of []types.Record -> DesegregateRecords([]*types.Record) -> []types.Record
 func disaggregateRecords(in []types.Record) ([]types.Record, error) {
 	var recs []types.Record
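
The heart of the change is the handshake between the submitting goroutine and the draining loop: records are fed to the pool from a goroutine (Submit blocks until a worker picks the record up), while the caller counts completions and resets a count-down timer on each one, so only a stalled batch trips the timeout. A stripped-down, self-contained sketch of that pattern with hypothetical task and result channels (not the library's API):

package main

import (
	"context"
	"fmt"
	"time"
)

func main() {
	ctx := context.Background()

	tasks := []int{1, 2, 3, 4, 5}
	taskC := make(chan int)     // unbuffered: submitting blocks until a worker is free
	resultC := make(chan error) // one result per task

	// two workers standing in for the pool
	for i := 0; i < 2; i++ {
		go func() {
			for range taskC {
				time.Sleep(10 * time.Millisecond) // simulate per-record work
				resultC <- nil
			}
		}()
	}

	// producer: submit every task, bail out early if the context ends
	go func() {
		for _, t := range tasks {
			select {
			case <-ctx.Done():
				return
			case taskC <- t:
			}
		}
	}()

	// drain: count completions and reset the count-down timer on each one,
	// so the deadline only fires when the batch stalls completely
	processed := 0
	timeout := 5 * time.Second
	countDown := time.NewTimer(timeout)
	for processed < len(tasks) {
		select {
		case <-ctx.Done():
			return
		case <-countDown.C:
			fmt.Println("timeout exceeded while awaiting results")
			return
		case err := <-resultC:
			if err != nil {
				fmt.Println("task failed:", err)
				return
			}
			processed++
			countDown.Reset(timeout)
		}
	}
	fmt.Println("batch done, processed:", processed)
}

Unlike the diff, the sketch blocks on the result channel inside the select rather than polling with a default case; both converge once every task has produced a result.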

View file

@@ -41,6 +41,7 @@ func TestScan(t *testing.T) {
 			return &kinesis.GetRecordsOutput{
 				NextShardIterator:  nil,
 				Records:            records,
+				MillisBehindLatest: aws.Int64(int64(1000)),
 			}, nil
 		},
 		listShardsMock: func(_ context.Context, _ *kinesis.ListShardsInput, _ ...func(*kinesis.Options)) (*kinesis.ListShardsOutput, error) {
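
The mocked GetRecordsOutput now has to set MillisBehindLatest: once a batch is non-empty, processRecords dereferences resp.MillisBehindLatest for the seconds-behind metric, so a mock that omits the field would panic.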
@@ -92,7 +93,7 @@ func TestScan(t *testing.T) {
 		t.Errorf("counter error expected %d, got %d", 2, val)
 	}

-	val, err := cp.GetCheckpoint("myStreamName", "myShard")
+	val, err := cp.GetCheckpoint(ctx, "myStreamName", "myShard")
 	if err != nil && val != "lastSeqNum" {
 		t.Errorf("checkpoint error expected %s, got %s", "lastSeqNum", val)
 	}
@@ -158,7 +159,7 @@ func TestScanShard(t *testing.T) {
 	}

 	// sets checkpoint
-	val, err := cp.GetCheckpoint("myStreamName", "myShard")
+	val, err := cp.GetCheckpoint(ctx, "myStreamName", "myShard")
 	if err != nil && val != "lastSeqNum" {
 		t.Fatalf("checkpoint error expected %s, got %s", "lastSeqNum", val)
 	}
@@ -242,7 +243,7 @@ func TestScanShard_SkipCheckpoint(t *testing.T) {
 		t.Fatalf("scan shard error: %v", err)
 	}

-	val, err := cp.GetCheckpoint("myStreamName", "myShard")
+	val, err := cp.GetCheckpoint(ctx, "myStreamName", "myShard")
 	if err != nil && val != "firstSeqNum" {
 		t.Fatalf("checkpoint error expected %s, got %s", "firstSeqNum", val)
 	}
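
The checkpoint fakes gain a ctx argument across all three tests, matching the context-aware GetCheckpoint signature that ScanShard already uses via c.group.GetCheckpoint(ctx, ...).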

View file

@@ -10,7 +10,7 @@ import (
 type Result struct {
 	Record
 	WorkerName string
-	Err        error
+	err        error
 }

 // WorkerPool allows records to be processed in parallel
@@ -55,12 +55,12 @@ func (wp *WorkerPool) Submit(r Record) {
 }

 // Result returns the Result of a submitted Record after it has been processed.
-func (wp *WorkerPool) Result() *Result {
+func (wp *WorkerPool) Result() (Result, error) {
 	select {
 	case r := <-wp.resultC:
-		return &r
+		return r, r.err
 	default:
-		return nil
+		return Result{}, nil
 	}
 }
@@ -91,7 +91,7 @@ func (w *worker) start(ctx context.Context) {
 			res := Result{
 				Record:     r,
 				WorkerName: w.name,
-				Err:        err,
+				err:        err,
 			}

 			w.resultC <- res
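
With Err unexported, a failure now only surfaces through the second return value of Result, and a zero Result means "nothing ready yet". A hypothetical caller-side helper against the new signature (drainOne and its retry interval are illustrative, not part of the package; context, errors and time are assumed imported):

// drainOne polls the pool until one result arrives or the context ends.
func drainOne(ctx context.Context, wp *WorkerPool) error {
	ticker := time.NewTicker(5 * time.Millisecond) // arbitrary retry interval
	defer ticker.Stop()
	for {
		res, err := wp.Result()
		switch {
		case errors.Is(err, ErrSkipCheckpoint):
			return nil // processed, but the caller must not checkpoint this record
		case err != nil:
			return err
		case res.WorkerName != "":
			return nil // a delivered Result always names the worker that produced it
		}
		// nothing ready yet; wait a beat instead of spinning
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-ticker.C:
		}
	}
}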