Merge pull request #206 from alexgridx/202-process-results-correctly
#202 parallelizes batch processing
This commit is contained in:
commit 4d112e7a3f
4 changed files with 152 additions and 62 deletions
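ScanShard no longer processes and checkpoints record by record: each GetRecords batch is handed to a new processRecords method, which fans the records out to a WorkerPool from a goroutine, drains the results until the whole batch is accounted for (or a 5-second countdown expires), and checkpoints once with the last sequence number. The following is a minimal, self-contained sketch of that submit-and-collect pattern; the names (task, result, numWorkers) are illustrative assumptions, not the library's API, and the real code polls WorkerPool.Result non-blockingly instead of receiving directly from a channel.

// Sketch only: a generic submit-and-collect worker pool, not the library's
// WorkerPool implementation.
package main

import (
	"context"
	"fmt"
	"time"
)

type task struct{ id int }

type result struct {
	id  int
	err error
}

func main() {
	ctx := context.Background()
	batch := []task{{1}, {2}, {3}, {4}}

	taskC := make(chan task)
	resultC := make(chan result)

	// Workers process submitted tasks in parallel; in the PR this is where
	// the user-supplied ScanFunc runs.
	const numWorkers = 2
	for i := 0; i < numWorkers; i++ {
		go func() {
			for t := range taskC {
				resultC <- result{id: t.id}
			}
		}()
	}

	// Submit the whole batch from a goroutine so the collector below can
	// start draining results immediately (mirrors processRecords).
	go func() {
		for _, t := range batch {
			select {
			case <-ctx.Done():
				return
			case taskC <- t:
			}
		}
	}()

	// Collect until every task in the batch is accounted for, resetting a
	// countdown timer after each result, then "checkpoint" once per batch.
	timeout := 5 * time.Second
	countDown := time.NewTimer(timeout)
	processed := 0
	for processed < len(batch) {
		select {
		case <-countDown.C:
			fmt.Println("timed out waiting for workers")
			return
		case r := <-resultC:
			if r.err != nil {
				fmt.Println("worker error:", r.err)
				return
			}
			processed++
			countDown.Reset(timeout)
		}
	}
	fmt.Printf("processed %d records; checkpoint once for the batch\n", processed)
}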
collector.go (33)
@@ -39,4 +39,37 @@ var (
 		labelStreamName,
 		labelShardID,
 	})
+
+	gaugeBatchSize = prometheus.NewGaugeVec(prometheus.GaugeOpts{
+		Namespace:   "net",
+		Subsystem:   "kinesis",
+		Name:        "get_records_result_size",
+		Help:        "number of records received from a call to get results",
+		ConstLabels: nil,
+	}, []string{
+		labelStreamName,
+		labelShardID,
+	})
+
+	histogramBatchDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{
+		Namespace: "net",
+		Subsystem: "kinesis",
+		Name:      "records_processing_duration",
+		Help:      "time in seconds it takes to process all of the records that were returned from a get records call",
+		Buckets:   []float64{0.1, 0.5, 1, 3, 5, 10, 30, 60},
+	}, []string{
+		labelStreamName,
+		labelShardID,
+	})
+
+	histogramAverageRecordDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{
+		Namespace: "net",
+		Subsystem: "kinesis",
+		Name:      "average_record_processing_duration",
+		Help:      "average time in seconds it takes to process a single record in a batch",
+		Buckets:   []float64{0.003, 0.005, 0.01, 0.025, 0.05, 0.1, 1, 3},
+	}, []string{
+		labelStreamName,
+		labelShardID,
+	})
 )
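The three new collectors cover a batch from different angles: the gauge holds the size of the last GetRecords result, one histogram observes how long the whole batch took, and the other the derived per-record average (batch duration divided by batch size). Below is a standalone illustration of feeding such a histogram; the metric options mirror the diff, but the label names, label values, registration, and sleep are assumptions for the example only, not code from this PR.

// Standalone illustration (assumed setup) of observing a whole-batch duration
// and the derived per-record average.
package main

import (
	"fmt"
	"time"

	"github.com/prometheus/client_golang/prometheus"
)

var batchDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{
	Namespace: "net",
	Subsystem: "kinesis",
	Name:      "records_processing_duration",
	Help:      "time in seconds it takes to process all of the records that were returned from a get records call",
	Buckets:   []float64{0.1, 0.5, 1, 3, 5, 10, 30, 60},
}, []string{"stream_name", "shard_id"}) // label names are assumed for this sketch

func main() {
	prometheus.MustRegister(batchDuration)

	startedAt := time.Now()
	batchSize := 250.0
	time.Sleep(20 * time.Millisecond) // stand-in for processing the batch

	duration := time.Since(startedAt).Seconds()
	batchDuration.
		With(prometheus.Labels{"stream_name": "myStreamName", "shard_id": "myShard"}).
		Observe(duration)
	fmt.Printf("batch of %.0f records took %.3fs (%.6fs per record on average)\n",
		batchSize, duration, duration/batchSize)
}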
consumer.go (156)
@@ -70,6 +70,9 @@ func New(streamName string, opts ...Option) (*Consumer, error) {
 	errs = errors.Join(errs, c.metricRegistry.Register(collectorMillisBehindLatest))
 	errs = errors.Join(errs, c.metricRegistry.Register(counterEventsConsumed))
 	errs = errors.Join(errs, c.metricRegistry.Register(counterCheckpointsWritten))
+	errs = errors.Join(errs, c.metricRegistry.Register(gaugeBatchSize))
+	errs = errors.Join(errs, c.metricRegistry.Register(histogramBatchDuration))
+	errs = errors.Join(errs, c.metricRegistry.Register(histogramAverageRecordDuration))
 	if errs != nil {
 		return nil, errs
 	}
@@ -95,7 +98,7 @@ type Consumer struct {
 	isAggregated       bool
 	shardClosedHandler ShardClosedHandler
 	numWorkers         int
-	workerPool         WorkerPool
+	workerPool         *WorkerPool
 }

 // ScanFunc is the type of the function called for each message read
@@ -157,9 +160,9 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
 // ScanShard loops over records on a specific shard, calls the callback func
 // for each record and checkpoints the progress of scan.
 func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error {
-	wp := NewWorkerPool(c.streamName, c.numWorkers, fn)
-	wp.Start(ctx)
-	defer wp.Stop()
+	c.workerPool = NewWorkerPool(c.streamName, c.numWorkers, fn)
+	c.workerPool.Start(ctx)
+	defer c.workerPool.Stop()

 	// get last seq number from checkpoint
 	lastSeqNum, err := c.group.GetCheckpoint(ctx, c.streamName, shardID)
@@ -200,55 +203,10 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error {
 				return fmt.Errorf("get shard iterator error: %w", err)
 			}
 		} else {
-			// loop over records, call callback func
-			var records []types.Record
-
-			// desegregate records
-			if c.isAggregated {
-				records, err = disaggregateRecords(resp.Records)
-				if err != nil {
-					return err
-				}
-			} else {
-				records = resp.Records
-			}
-
-			for _, r := range records {
-				select {
-				case <-ctx.Done():
-					return nil
-				default:
-					record := Record{r, shardID, resp.MillisBehindLatest}
-					wp.Submit(record)
-
-					res := wp.Result()
-					var err error
-					if res != nil && res.Err != nil {
-						err = res.Err
-					}
-
-					secondsBehindLatest := float64(time.Duration(*resp.MillisBehindLatest)*time.Millisecond) / float64(time.Second)
-					collectorMillisBehindLatest.
-						With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}).
-						Observe(secondsBehindLatest)
-
-					if err != nil && !errors.Is(err, ErrSkipCheckpoint) {
-						return err
-					}
-
-					if !errors.Is(err, ErrSkipCheckpoint) {
-						if err := c.group.SetCheckpoint(ctx, c.streamName, shardID, *r.SequenceNumber); err != nil {
-							return err
-						}
-						c.counter.Add("checkpoint", 1)
-						counterCheckpointsWritten.With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}).Inc()
-					}
-
-					counterEventsConsumed.With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}).Inc()
-					c.counter.Add("records", 1)
-					lastSeqNum = *r.SequenceNumber
-				}
-			}
+			lastSeqNum, err = c.processRecords(ctx, shardID, resp)
+			if err != nil {
+				return err
+			}

 			if isShardClosed(resp.NextShardIterator, shardIterator) {
 				c.logger.DebugContext(ctx, "shard closed", slog.String("shard-id", shardID))
@@ -276,6 +234,104 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error {
 	}
 }

+func (c *Consumer) processRecords(ctx context.Context, shardID string, resp *kinesis.GetRecordsOutput) (string, error) {
+	if len(resp.Records) == 0 {
+		return "", nil
+	}
+
+	startedAt := time.Now()
+	batchSize := float64(len(resp.Records))
+	gaugeBatchSize.
+		With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}).
+		Set(batchSize)
+
+	secondsBehindLatest := float64(time.Duration(*resp.MillisBehindLatest)*time.Millisecond) / float64(time.Second)
+	collectorMillisBehindLatest.
+		With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}).
+		Observe(secondsBehindLatest)
+
+	// loop over records, call callback func
+	var records []types.Record
+
+	// disaggregate records
+	var err error
+	if c.isAggregated {
+		records, err = disaggregateRecords(resp.Records)
+		if err != nil {
+			return "", err
+		}
+	} else {
+		records = resp.Records
+	}
+
+	if len(records) == 0 {
+		// nothing to do here
+		return "", nil
+	}
+
+	// submit in goroutine
+	go func() {
+		for _, r := range records {
+			select {
+			case <-ctx.Done():
+				return
+			default:
+				record := Record{r, shardID, resp.MillisBehindLatest}
+				// blocks until someone is ready to pick it up
+				c.workerPool.Submit(record)
+			}
+		}
+	}()
+
+	// wait for all tasks to be processed
+	numberOfProcessedTasks := 0
+	timeout := 5 * time.Second
+	countDownTimer := time.NewTimer(timeout)
+	for {
+		if numberOfProcessedTasks == len(records) {
+			break
+		}
+		select {
+		case <-ctx.Done():
+			return "", nil
+		case <-countDownTimer.C:
+			return "", fmt.Errorf("timeline exceeded while awaiting result from workers")
+		default:
+			res, err := c.workerPool.Result()
+			if err != nil && !errors.Is(err, ErrSkipCheckpoint) {
+				return "", err // TODO make it more clever once :)
+			}
+			if errors.Is(err, ErrSkipCheckpoint) || res != nil {
+				numberOfProcessedTasks++
+				countDownTimer.Reset(timeout)
+
+				counterEventsConsumed.With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}).Inc()
+				c.counter.Add("records", 1)
+			}
+		}
+	}
+
+	// we MUST only reach this point if everything is processed
+	lastSeqNum := *records[len(records)-1].SequenceNumber
+
+	if err := c.group.SetCheckpoint(ctx, c.streamName, shardID, lastSeqNum); err != nil {
+		return "", fmt.Errorf("set checkpoint error: %w", err)
+	}
+	c.counter.Add("checkpoint", int64(numberOfProcessedTasks))
+	counterCheckpointsWritten.
+		With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}).
+		Add(float64(numberOfProcessedTasks))
+
+	duration := time.Since(startedAt).Seconds()
+	histogramBatchDuration.
+		With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}).
+		Observe(duration)
+	histogramAverageRecordDuration.
+		With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}).
+		Observe(duration / batchSize)
+	return lastSeqNum, nil
+}

 // temporary conversion func of []types.Record -> DesegregateRecords([]*types.Record) -> []types.Record
 func disaggregateRecords(in []types.Record) ([]types.Record, error) {
 	var recs []types.Record
consumer_test.go
@@ -41,6 +41,7 @@ func TestScan(t *testing.T) {
 			return &kinesis.GetRecordsOutput{
 				NextShardIterator: nil,
 				Records:           records,
+				MillisBehindLatest: aws.Int64(int64(1000)),
 			}, nil
 		},
 		listShardsMock: func(_ context.Context, _ *kinesis.ListShardsInput, _ ...func(*kinesis.Options)) (*kinesis.ListShardsOutput, error) {
@@ -92,7 +93,7 @@ func TestScan(t *testing.T) {
 		t.Errorf("counter error expected %d, got %d", 2, val)
 	}

-	val, err := cp.GetCheckpoint("myStreamName", "myShard")
+	val, err := cp.GetCheckpoint(ctx, "myStreamName", "myShard")
 	if err != nil && val != "lastSeqNum" {
 		t.Errorf("checkout error expected %s, got %s", "lastSeqNum", val)
 	}
@@ -158,7 +159,7 @@ func TestScanShard(t *testing.T) {
 	}

 	// sets checkpoint
-	val, err := cp.GetCheckpoint("myStreamName", "myShard")
+	val, err := cp.GetCheckpoint(ctx, "myStreamName", "myShard")
 	if err != nil && val != "lastSeqNum" {
 		t.Fatalf("checkout error expected %s, got %s", "lastSeqNum", val)
 	}
@@ -242,7 +243,7 @@ func TestScanShard_SkipCheckpoint(t *testing.T) {
 		t.Fatalf("scan shard error: %v", err)
 	}

-	val, err := cp.GetCheckpoint("myStreamName", "myShard")
+	val, err := cp.GetCheckpoint(ctx, "myStreamName", "myShard")
 	if err != nil && val != "firstSeqNum" {
 		t.Fatalf("checkout error expected %s, got %s", "firstSeqNum", val)
 	}
worker.go (10)
@@ -10,7 +10,7 @@ import (
 type Result struct {
 	Record
 	WorkerName string
-	Err        error
+	err        error
 }

 // WorkerPool allows to parallel process records
@@ -55,12 +55,12 @@ func (wp *WorkerPool) Submit(r Record) {
 }

 // Result returns the Result of the Submit-ed Record after it has been processed.
-func (wp *WorkerPool) Result() *Result {
+func (wp *WorkerPool) Result() (Result, error) {
 	select {
 	case r := <-wp.resultC:
-		return &r
+		return r, r.err
 	default:
-		return nil
+		return Result{}, nil
 	}
 }

@@ -91,7 +91,7 @@ func (w *worker) start(ctx context.Context) {
 			res := Result{
 				Record:     r,
 				WorkerName: w.name,
-				Err:        err,
+				err:        err,
 			}

 			w.resultC <- res