#202 parallelizes batch processing
This commit is contained in:
parent
138b7de381
commit
189f0ff473
4 changed files with 152 additions and 62 deletions
33
collector.go
33
collector.go
|
|
@ -39,4 +39,37 @@ var (
|
|||
labelStreamName,
|
||||
labelShardID,
|
||||
})
|
||||
|
||||
gaugeBatchSize = prometheus.NewGaugeVec(prometheus.GaugeOpts{
|
||||
Namespace: "net",
|
||||
Subsystem: "kinesis",
|
||||
Name: "get_records_result_size",
|
||||
Help: "number of records received from a call to get results",
|
||||
ConstLabels: nil,
|
||||
}, []string{
|
||||
labelStreamName,
|
||||
labelShardID,
|
||||
})
|
||||
|
||||
histogramBatchDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{
|
||||
Namespace: "net",
|
||||
Subsystem: "kinesis",
|
||||
Name: "records_processing_duration",
|
||||
Help: "time in seconds it takes to process all of the records that were returned from a get records call",
|
||||
Buckets: []float64{0.1, 0.5, 1, 3, 5, 10, 30, 60},
|
||||
}, []string{
|
||||
labelStreamName,
|
||||
labelShardID,
|
||||
})
|
||||
|
||||
histogramAverageRecordDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{
|
||||
Namespace: "net",
|
||||
Subsystem: "kinesis",
|
||||
Name: "average_record_processing_duration",
|
||||
Help: "average time in seconds it takes to process a single record in a batch",
|
||||
Buckets: []float64{0.003, 0.005, 0.01, 0.025, 0.05, 0.1, 1, 3},
|
||||
}, []string{
|
||||
labelStreamName,
|
||||
labelShardID,
|
||||
})
|
||||
)
|
||||
|
|
|
|||
160
consumer.go
160
consumer.go
|
|
@ -70,6 +70,9 @@ func New(streamName string, opts ...Option) (*Consumer, error) {
|
|||
errs = errors.Join(errs, c.metricRegistry.Register(collectorMillisBehindLatest))
|
||||
errs = errors.Join(errs, c.metricRegistry.Register(counterEventsConsumed))
|
||||
errs = errors.Join(errs, c.metricRegistry.Register(counterCheckpointsWritten))
|
||||
errs = errors.Join(errs, c.metricRegistry.Register(gaugeBatchSize))
|
||||
errs = errors.Join(errs, c.metricRegistry.Register(histogramBatchDuration))
|
||||
errs = errors.Join(errs, c.metricRegistry.Register(histogramAverageRecordDuration))
|
||||
if errs != nil {
|
||||
return nil, errs
|
||||
}
|
||||
|
|
@ -95,7 +98,7 @@ type Consumer struct {
|
|||
isAggregated bool
|
||||
shardClosedHandler ShardClosedHandler
|
||||
numWorkers int
|
||||
workerPool WorkerPool
|
||||
workerPool *WorkerPool
|
||||
}
|
||||
|
||||
// ScanFunc is the type of the function called for each message read
|
||||
|
|
@ -157,9 +160,9 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
|
|||
// ScanShard loops over records on a specific shard, calls the callback func
|
||||
// for each record and checkpoints the progress of scan.
|
||||
func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error {
|
||||
wp := NewWorkerPool(c.streamName, c.numWorkers, fn)
|
||||
wp.Start(ctx)
|
||||
defer wp.Stop()
|
||||
c.workerPool = NewWorkerPool(c.streamName, c.numWorkers, fn)
|
||||
c.workerPool.Start(ctx)
|
||||
defer c.workerPool.Stop()
|
||||
|
||||
// get last seq number from checkpoint
|
||||
lastSeqNum, err := c.group.GetCheckpoint(ctx, c.streamName, shardID)
|
||||
|
|
@ -200,54 +203,9 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
|
|||
return fmt.Errorf("get shard iterator error: %w", err)
|
||||
}
|
||||
} else {
|
||||
// loop over records, call callback func
|
||||
var records []types.Record
|
||||
|
||||
// desegregate records
|
||||
if c.isAggregated {
|
||||
records, err = disaggregateRecords(resp.Records)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
records = resp.Records
|
||||
}
|
||||
|
||||
for _, r := range records {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
default:
|
||||
record := Record{r, shardID, resp.MillisBehindLatest}
|
||||
wp.Submit(record)
|
||||
|
||||
res := wp.Result()
|
||||
var err error
|
||||
if res != nil && res.Err != nil {
|
||||
err = res.Err
|
||||
}
|
||||
|
||||
secondsBehindLatest := float64(time.Duration(*resp.MillisBehindLatest)*time.Millisecond) / float64(time.Second)
|
||||
collectorMillisBehindLatest.
|
||||
With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}).
|
||||
Observe(secondsBehindLatest)
|
||||
|
||||
if err != nil && !errors.Is(err, ErrSkipCheckpoint) {
|
||||
return err
|
||||
}
|
||||
|
||||
if !errors.Is(err, ErrSkipCheckpoint) {
|
||||
if err := c.group.SetCheckpoint(ctx, c.streamName, shardID, *r.SequenceNumber); err != nil {
|
||||
return err
|
||||
}
|
||||
c.counter.Add("checkpoint", 1)
|
||||
counterCheckpointsWritten.With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}).Inc()
|
||||
}
|
||||
|
||||
counterEventsConsumed.With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}).Inc()
|
||||
c.counter.Add("records", 1)
|
||||
lastSeqNum = *r.SequenceNumber
|
||||
}
|
||||
lastSeqNum, err = c.processRecords(ctx, shardID, resp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if isShardClosed(resp.NextShardIterator, shardIterator) {
|
||||
|
|
@ -276,6 +234,104 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
|
|||
}
|
||||
}
|
||||
|
||||
func (c *Consumer) processRecords(ctx context.Context, shardID string, resp *kinesis.GetRecordsOutput) (string, error) {
|
||||
if len(resp.Records) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
startedAt := time.Now()
|
||||
batchSize := float64(len(resp.Records))
|
||||
gaugeBatchSize.
|
||||
With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}).
|
||||
Set(batchSize)
|
||||
|
||||
secondsBehindLatest := float64(time.Duration(*resp.MillisBehindLatest)*time.Millisecond) / float64(time.Second)
|
||||
collectorMillisBehindLatest.
|
||||
With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}).
|
||||
Observe(secondsBehindLatest)
|
||||
|
||||
// loop over records, call callback func
|
||||
var records []types.Record
|
||||
|
||||
// disaggregate records
|
||||
var err error
|
||||
if c.isAggregated {
|
||||
records, err = disaggregateRecords(resp.Records)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
} else {
|
||||
records = resp.Records
|
||||
}
|
||||
|
||||
if len(records) == 0 {
|
||||
// nothing to do here
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// submit in goroutine
|
||||
go func() {
|
||||
for _, r := range records {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
record := Record{r, shardID, resp.MillisBehindLatest}
|
||||
// blocks until someone is ready to pick it up
|
||||
c.workerPool.Submit(record)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// wait for all tasks to be processed
|
||||
numberOfProcessedTasks := 0
|
||||
timeout := 5 * time.Second
|
||||
countDownTimer := time.NewTimer(timeout)
|
||||
for {
|
||||
if numberOfProcessedTasks == len(records) {
|
||||
break
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return "", nil
|
||||
case <-countDownTimer.C:
|
||||
return "", fmt.Errorf("timeline exceeded while awaiting result from workers")
|
||||
default:
|
||||
res, err := c.workerPool.Result()
|
||||
if err != nil && !errors.Is(err, ErrSkipCheckpoint) {
|
||||
return "", err // TODO make it more clever once :)
|
||||
}
|
||||
if errors.Is(err, ErrSkipCheckpoint) || res != nil {
|
||||
numberOfProcessedTasks++
|
||||
countDownTimer.Reset(timeout)
|
||||
|
||||
counterEventsConsumed.With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}).Inc()
|
||||
c.counter.Add("records", 1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// we MUST only reach this point if everything is processed
|
||||
lastSeqNum := *records[len(records)-1].SequenceNumber
|
||||
|
||||
if err := c.group.SetCheckpoint(ctx, c.streamName, shardID, lastSeqNum); err != nil {
|
||||
return "", fmt.Errorf("set checkpoint error: %w", err)
|
||||
}
|
||||
c.counter.Add("checkpoint", int64(numberOfProcessedTasks))
|
||||
counterCheckpointsWritten.
|
||||
With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}).
|
||||
Add(float64(numberOfProcessedTasks))
|
||||
|
||||
duration := time.Since(startedAt).Seconds()
|
||||
histogramBatchDuration.
|
||||
With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}).
|
||||
Observe(duration)
|
||||
histogramAverageRecordDuration.
|
||||
With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}).
|
||||
Observe(duration / batchSize)
|
||||
return lastSeqNum, nil
|
||||
}
|
||||
|
||||
// temporary conversion func of []types.Record -> DesegregateRecords([]*types.Record) -> []types.Record
|
||||
func disaggregateRecords(in []types.Record) ([]types.Record, error) {
|
||||
var recs []types.Record
|
||||
|
|
|
|||
|
|
@ -39,8 +39,9 @@ func TestScan(t *testing.T) {
|
|||
},
|
||||
getRecordsMock: func(_ context.Context, _ *kinesis.GetRecordsInput, _ ...func(*kinesis.Options)) (*kinesis.GetRecordsOutput, error) {
|
||||
return &kinesis.GetRecordsOutput{
|
||||
NextShardIterator: nil,
|
||||
Records: records,
|
||||
NextShardIterator: nil,
|
||||
Records: records,
|
||||
MillisBehindLatest: aws.Int64(int64(1000)),
|
||||
}, nil
|
||||
},
|
||||
listShardsMock: func(_ context.Context, _ *kinesis.ListShardsInput, _ ...func(*kinesis.Options)) (*kinesis.ListShardsOutput, error) {
|
||||
|
|
@ -92,7 +93,7 @@ func TestScan(t *testing.T) {
|
|||
t.Errorf("counter error expected %d, got %d", 2, val)
|
||||
}
|
||||
|
||||
val, err := cp.GetCheckpoint("myStreamName", "myShard")
|
||||
val, err := cp.GetCheckpoint(ctx, "myStreamName", "myShard")
|
||||
if err != nil && val != "lastSeqNum" {
|
||||
t.Errorf("checkout error expected %s, got %s", "lastSeqNum", val)
|
||||
}
|
||||
|
|
@ -158,7 +159,7 @@ func TestScanShard(t *testing.T) {
|
|||
}
|
||||
|
||||
// sets checkpoint
|
||||
val, err := cp.GetCheckpoint("myStreamName", "myShard")
|
||||
val, err := cp.GetCheckpoint(ctx, "myStreamName", "myShard")
|
||||
if err != nil && val != "lastSeqNum" {
|
||||
t.Fatalf("checkout error expected %s, got %s", "lastSeqNum", val)
|
||||
}
|
||||
|
|
@ -242,7 +243,7 @@ func TestScanShard_SkipCheckpoint(t *testing.T) {
|
|||
t.Fatalf("scan shard error: %v", err)
|
||||
}
|
||||
|
||||
val, err := cp.GetCheckpoint("myStreamName", "myShard")
|
||||
val, err := cp.GetCheckpoint(ctx, "myStreamName", "myShard")
|
||||
if err != nil && val != "firstSeqNum" {
|
||||
t.Fatalf("checkout error expected %s, got %s", "firstSeqNum", val)
|
||||
}
|
||||
|
|
|
|||
10
worker.go
10
worker.go
|
|
@ -10,7 +10,7 @@ import (
|
|||
type Result struct {
|
||||
Record
|
||||
WorkerName string
|
||||
Err error
|
||||
err error
|
||||
}
|
||||
|
||||
// WorkerPool allows to parallel process records
|
||||
|
|
@ -55,12 +55,12 @@ func (wp *WorkerPool) Submit(r Record) {
|
|||
}
|
||||
|
||||
// Result returns the Result of the Submit-ed Record after it has been processed.
|
||||
func (wp *WorkerPool) Result() *Result {
|
||||
func (wp *WorkerPool) Result() (Result, error) {
|
||||
select {
|
||||
case r := <-wp.resultC:
|
||||
return &r
|
||||
return r, r.err
|
||||
default:
|
||||
return nil
|
||||
return Result{}, nil
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -91,7 +91,7 @@ func (w *worker) start(ctx context.Context) {
|
|||
res := Result{
|
||||
Record: r,
|
||||
WorkerName: w.name,
|
||||
Err: err,
|
||||
err: err,
|
||||
}
|
||||
|
||||
w.resultC <- res
|
||||
|
|
|
|||
Loading…
Reference in a new issue