#202 parallelizes batch processing

Alex Senger 2024-09-20 11:34:58 +02:00
parent 138b7de381
commit 189f0ff473
GPG key ID: 0B4A96F8AF6934CF
4 changed files with 152 additions and 62 deletions
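
The change replaces the sequential per-record loop in ScanShard with a worker pool: the records of a GetRecords batch are submitted to the pool from a separate goroutine, the consumer waits until one result per record has come back, and only then writes a single checkpoint for the whole batch. The following standalone sketch illustrates that submit/collect pattern; `task`, `result`, and `pool` are illustrative stand-ins, not the package's API.

```go
package main

import (
	"context"
	"fmt"
	"time"
)

type task struct{ id int }

type result struct {
	id  int
	err error
}

// pool fans tasks out to a fixed number of workers over an unbuffered channel
// and sends one result per task back on a shared result channel.
type pool struct {
	tasks   chan task
	results chan result
}

func newPool(ctx context.Context, workers int) *pool {
	p := &pool{tasks: make(chan task), results: make(chan result)}
	for i := 0; i < workers; i++ {
		go func() {
			for {
				select {
				case <-ctx.Done():
					return
				case t := <-p.tasks:
					time.Sleep(10 * time.Millisecond) // simulate per-record work
					p.results <- result{id: t.id}
				}
			}
		}()
	}
	return p
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	p := newPool(ctx, 4)
	const batch = 8

	// Submit from a separate goroutine so that blocking submits cannot
	// deadlock against result collection.
	go func() {
		for i := 0; i < batch; i++ {
			p.tasks <- task{id: i}
		}
	}()

	// Collect exactly one result per submitted task before checkpointing.
	for done := 0; done < batch; done++ {
		if r := <-p.results; r.err != nil {
			fmt.Println("record failed:", r.err)
		}
	}
	fmt.Println("batch processed, checkpoint once for the whole batch")
}
```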


@@ -39,4 +39,37 @@ var (
labelStreamName,
labelShardID,
})
gaugeBatchSize = prometheus.NewGaugeVec(prometheus.GaugeOpts{
Namespace: "net",
Subsystem: "kinesis",
Name: "get_records_result_size",
Help: "number of records received from a call to get results",
ConstLabels: nil,
}, []string{
labelStreamName,
labelShardID,
})
histogramBatchDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{
Namespace: "net",
Subsystem: "kinesis",
Name: "records_processing_duration",
Help: "time in seconds it takes to process all of the records that were returned from a get records call",
Buckets: []float64{0.1, 0.5, 1, 3, 5, 10, 30, 60},
}, []string{
labelStreamName,
labelShardID,
})
histogramAverageRecordDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{
Namespace: "net",
Subsystem: "kinesis",
Name: "average_record_processing_duration",
Help: "average time in seconds it takes to process a single record in a batch",
Buckets: []float64{0.003, 0.005, 0.01, 0.025, 0.05, 0.1, 1, 3},
}, []string{
labelStreamName,
labelShardID,
})
)


@@ -70,6 +70,9 @@ func New(streamName string, opts ...Option) (*Consumer, error) {
errs = errors.Join(errs, c.metricRegistry.Register(collectorMillisBehindLatest))
errs = errors.Join(errs, c.metricRegistry.Register(counterEventsConsumed))
errs = errors.Join(errs, c.metricRegistry.Register(counterCheckpointsWritten))
errs = errors.Join(errs, c.metricRegistry.Register(gaugeBatchSize))
errs = errors.Join(errs, c.metricRegistry.Register(histogramBatchDuration))
errs = errors.Join(errs, c.metricRegistry.Register(histogramAverageRecordDuration))
if errs != nil {
return nil, errs
}
@@ -95,7 +98,7 @@ type Consumer struct {
isAggregated bool
shardClosedHandler ShardClosedHandler
numWorkers int
workerPool WorkerPool
workerPool *WorkerPool
}
// ScanFunc is the type of the function called for each message read
@@ -157,9 +160,9 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
// ScanShard loops over records on a specific shard, calls the callback func
// for each record and checkpoints the progress of scan.
func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error {
wp := NewWorkerPool(c.streamName, c.numWorkers, fn)
wp.Start(ctx)
defer wp.Stop()
c.workerPool = NewWorkerPool(c.streamName, c.numWorkers, fn)
c.workerPool.Start(ctx)
defer c.workerPool.Stop()
// get last seq number from checkpoint
lastSeqNum, err := c.group.GetCheckpoint(ctx, c.streamName, shardID)
@@ -200,54 +203,9 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error {
return fmt.Errorf("get shard iterator error: %w", err)
}
} else {
// loop over records, call callback func
var records []types.Record
// desegregate records
if c.isAggregated {
records, err = disaggregateRecords(resp.Records)
if err != nil {
return err
}
} else {
records = resp.Records
}
for _, r := range records {
select {
case <-ctx.Done():
return nil
default:
record := Record{r, shardID, resp.MillisBehindLatest}
wp.Submit(record)
res := wp.Result()
var err error
if res != nil && res.Err != nil {
err = res.Err
}
secondsBehindLatest := float64(time.Duration(*resp.MillisBehindLatest)*time.Millisecond) / float64(time.Second)
collectorMillisBehindLatest.
With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}).
Observe(secondsBehindLatest)
if err != nil && !errors.Is(err, ErrSkipCheckpoint) {
return err
}
if !errors.Is(err, ErrSkipCheckpoint) {
if err := c.group.SetCheckpoint(ctx, c.streamName, shardID, *r.SequenceNumber); err != nil {
return err
}
c.counter.Add("checkpoint", 1)
counterCheckpointsWritten.With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}).Inc()
}
counterEventsConsumed.With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}).Inc()
c.counter.Add("records", 1)
lastSeqNum = *r.SequenceNumber
}
lastSeqNum, err = c.processRecords(ctx, shardID, resp)
if err != nil {
return err
}
if isShardClosed(resp.NextShardIterator, shardIterator) {
@@ -276,6 +234,104 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error {
}
}
func (c *Consumer) processRecords(ctx context.Context, shardID string, resp *kinesis.GetRecordsOutput) (string, error) {
if len(resp.Records) == 0 {
return "", nil
}
startedAt := time.Now()
batchSize := float64(len(resp.Records))
gaugeBatchSize.
With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}).
Set(batchSize)
secondsBehindLatest := float64(time.Duration(*resp.MillisBehindLatest)*time.Millisecond) / float64(time.Second)
collectorMillisBehindLatest.
With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}).
Observe(secondsBehindLatest)
// loop over records, call callback func
var records []types.Record
// disaggregate records
var err error
if c.isAggregated {
records, err = disaggregateRecords(resp.Records)
if err != nil {
return "", err
}
} else {
records = resp.Records
}
if len(records) == 0 {
// nothing to do here
return "", nil
}
// submit in goroutine
go func() {
for _, r := range records {
select {
case <-ctx.Done():
return
default:
record := Record{r, shardID, resp.MillisBehindLatest}
// blocks until someone is ready to pick it up
c.workerPool.Submit(record)
}
}
}()
// wait for all tasks to be processed
numberOfProcessedTasks := 0
timeout := 5 * time.Second
countDownTimer := time.NewTimer(timeout)
for {
if numberOfProcessedTasks == len(records) {
break
}
select {
case <-ctx.Done():
return "", nil
case <-countDownTimer.C:
return "", fmt.Errorf("timeline exceeded while awaiting result from workers")
default:
res, err := c.workerPool.Result()
if err != nil && !errors.Is(err, ErrSkipCheckpoint) {
return "", err // TODO make it more clever once :)
}
// a worker always sets WorkerName, so an empty name means no result was ready yet
if errors.Is(err, ErrSkipCheckpoint) || res.WorkerName != "" {
numberOfProcessedTasks++
countDownTimer.Reset(timeout)
counterEventsConsumed.With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}).Inc()
c.counter.Add("records", 1)
}
}
}
// we MUST only reach this point if everything is processed
lastSeqNum := *records[len(records)-1].SequenceNumber
if err := c.group.SetCheckpoint(ctx, c.streamName, shardID, lastSeqNum); err != nil {
return "", fmt.Errorf("set checkpoint error: %w", err)
}
c.counter.Add("checkpoint", int64(numberOfProcessedTasks))
counterCheckpointsWritten.
With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}).
Add(float64(numberOfProcessedTasks))
duration := time.Since(startedAt).Seconds()
histogramBatchDuration.
With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}).
Observe(duration)
histogramAverageRecordDuration.
With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}).
Observe(duration / batchSize)
return lastSeqNum, nil
}
// temporary conversion func of []types.Record -> DesegregateRecords([]*types.Record) -> []types.Record
func disaggregateRecords(in []types.Record) ([]types.Record, error) {
var recs []types.Record
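
The waiting loop in processRecords above gives the workers a per-result deadline rather than a whole-batch deadline: the timer is reset each time a result arrives. Below is a simplified standalone sketch of that reset-on-progress timeout; it uses a plain channel instead of the pool's non-blocking Result, so it is an approximation of the idea, not the package's code, and `ErrSkip` stands in for ErrSkipCheckpoint.

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// ErrSkip stands in for the package's ErrSkipCheckpoint.
var ErrSkip = errors.New("skip checkpoint")

// awaitResults waits for `want` results and resets the timeout after every
// one, so the deadline applies to each record rather than to the whole batch.
func awaitResults(ctx context.Context, resultC <-chan error, want int, timeout time.Duration) error {
	timer := time.NewTimer(timeout)
	defer timer.Stop()

	for done := 0; done < want; {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-timer.C:
			return fmt.Errorf("timed out waiting for worker results (%d/%d done)", done, want)
		case err := <-resultC:
			if err != nil && !errors.Is(err, ErrSkip) {
				return err
			}
			done++
			// Restart the per-result countdown (stop and drain before Reset).
			if !timer.Stop() {
				select {
				case <-timer.C:
				default:
				}
			}
			timer.Reset(timeout)
		}
	}
	return nil
}

func main() {
	resultC := make(chan error, 3)
	resultC <- nil
	resultC <- ErrSkip
	resultC <- nil

	fmt.Println(awaitResults(context.Background(), resultC, 3, time.Second))
}
```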


@@ -39,8 +39,9 @@ func TestScan(t *testing.T) {
},
getRecordsMock: func(_ context.Context, _ *kinesis.GetRecordsInput, _ ...func(*kinesis.Options)) (*kinesis.GetRecordsOutput, error) {
return &kinesis.GetRecordsOutput{
NextShardIterator: nil,
Records: records,
NextShardIterator: nil,
Records: records,
MillisBehindLatest: aws.Int64(int64(1000)),
}, nil
},
listShardsMock: func(_ context.Context, _ *kinesis.ListShardsInput, _ ...func(*kinesis.Options)) (*kinesis.ListShardsOutput, error) {
@@ -92,7 +93,7 @@ func TestScan(t *testing.T) {
t.Errorf("counter error expected %d, got %d", 2, val)
}
val, err := cp.GetCheckpoint("myStreamName", "myShard")
val, err := cp.GetCheckpoint(ctx, "myStreamName", "myShard")
if err != nil && val != "lastSeqNum" {
t.Errorf("checkout error expected %s, got %s", "lastSeqNum", val)
}
@@ -158,7 +159,7 @@ func TestScanShard(t *testing.T) {
}
// sets checkpoint
val, err := cp.GetCheckpoint("myStreamName", "myShard")
val, err := cp.GetCheckpoint(ctx, "myStreamName", "myShard")
if err != nil && val != "lastSeqNum" {
t.Fatalf("checkout error expected %s, got %s", "lastSeqNum", val)
}
@@ -242,7 +243,7 @@ func TestScanShard_SkipCheckpoint(t *testing.T) {
t.Fatalf("scan shard error: %v", err)
}
val, err := cp.GetCheckpoint("myStreamName", "myShard")
val, err := cp.GetCheckpoint(ctx, "myStreamName", "myShard")
if err != nil && val != "firstSeqNum" {
t.Fatalf("checkout error expected %s, got %s", "firstSeqNum", val)
}


@@ -10,7 +10,7 @@ import (
type Result struct {
Record
WorkerName string
Err error
err error
}
// WorkerPool allows records to be processed in parallel
@@ -55,12 +55,12 @@ func (wp *WorkerPool) Submit(r Record) {
}
// Result returns the Result of the Submit-ed Record after it has been processed.
func (wp *WorkerPool) Result() *Result {
func (wp *WorkerPool) Result() (Result, error) {
select {
case r := <-wp.resultC:
return &r
return r, r.err
default:
return nil
return Result{}, nil
}
}
@@ -91,7 +91,7 @@ func (w *worker) start(ctx context.Context) {
res := Result{
Record: r,
WorkerName: w.name,
Err: err,
err: err,
}
w.resultC <- res
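
For callers, the Result signature change means "nothing ready yet" is no longer signalled by a nil pointer: Result now returns a zero Result together with a nil error when no result is waiting, and the worker's error travels as the second return value. A hypothetical drain loop against the new signature might look like the snippet below; `poll` is an illustrative helper, not part of the package, and it relies on workers always setting WorkerName, as they do above.

```go
// poll drains whatever results are currently available without blocking.
// A zero Result together with a nil error means nothing was waiting.
func poll(wp *WorkerPool) (processed int, firstErr error) {
	for {
		res, err := wp.Result()
		if err != nil {
			if firstErr == nil {
				firstErr = err
			}
			processed++
			continue
		}
		if res.WorkerName == "" {
			// zero value: the result channel was empty
			return processed, firstErr
		}
		processed++
	}
}
```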