#202 cleanup and refactoring

Alex Senger 2025-03-06 15:06:08 +01:00
parent f3fb792745
commit 9c8480b777
2 changed files with 33 additions and 49 deletions


@@ -46,11 +46,6 @@ func New(streamName string, opts ...Option) (*Consumer, error) {
         maxRecords:     10000,
         metricRegistry: nil,
         numWorkers:     1,
-        logger: &noopLogger{
-            logger: log.New(io.Discard, "", log.LstdFlags),
-        },
-        scanInterval: 250 * time.Millisecond,
-        maxRecords:   10000,
     }

     // override defaults
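For context on the hunk above: New fills a Consumer with defaults and then applies the variadic Option values over them, so the duplicated default fields removed here were redundant rather than load-bearing. A minimal sketch of that defaults-then-override pattern, using a trimmed-down config struct and a hypothetical WithNumWorkers option (not the package's real types or option names):

```go
package consumer

import "time"

// config is a trimmed-down stand-in for the real Consumer fields.
type config struct {
	maxRecords   int
	numWorkers   int
	scanInterval time.Duration
}

// Option mutates the config before the consumer is returned.
type Option func(*config)

// WithNumWorkers is a hypothetical option; the real option names may differ.
func WithNumWorkers(n int) Option {
	return func(c *config) { c.numWorkers = n }
}

// newConfig mirrors the shape of New: set defaults first, then let options override them.
func newConfig(opts ...Option) *config {
	c := &config{
		maxRecords:   10000,
		numWorkers:   1,
		scanInterval: 250 * time.Millisecond,
	}
	// override defaults
	for _, opt := range opts {
		opt(c)
	}
	return c
}
```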
@@ -146,8 +141,8 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
     // process each of the shards
     s := newShardsInProcess()
     for shard := range shardC {
-        shardId := aws.ToString(shard.ShardId)
-        if s.doesShardExist(shardId) {
+        shardID := aws.ToString(shard.ShardId)
+        if s.doesShardExist(shardID) {
             // safetynet: if shard already in process by another goroutine, just skipping the request
             continue
         }
@@ -174,7 +169,7 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
                     // error has already occurred
                 }
             }
-        }(shardId)
+        }(shardID)
     }

     go func() {
@@ -185,8 +180,8 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
     return <-errC
 }

-// ScanShard loops over records on a specific shard, calls the callback func
-// for each record and checkpoints the progress of scan.
+// ScanShard loops over records on a specific shard, calls the callback func for each record and checkpoints the
+// progress of scan.
 func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error {
     c.workerPool = NewWorkerPool(c.streamName, c.numWorkers, fn)
     c.workerPool.Start(ctx)
@@ -263,20 +258,12 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
 }

 func (c *Consumer) processRecords(ctx context.Context, shardID string, resp *kinesis.GetRecordsOutput, fn ScanFunc) (string, error) {
-    if len(resp.Records) == 0 {
-        return "", nil
-    }
-
     startedAt := time.Now()
     batchSize := float64(len(resp.Records))
-    gaugeBatchSize.
-        With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}).
-        Set(batchSize)
     secondsBehindLatest := float64(time.Duration(*resp.MillisBehindLatest)*time.Millisecond) / float64(time.Second)
-    collectorMillisBehindLatest.
-        With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}).
-        Observe(secondsBehindLatest)
+    labels := prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}
+    gaugeBatchSize.With(labels).Set(batchSize)
+    collectorMillisBehindLatest.With(labels).Observe(secondsBehindLatest)

     // loop over records, call callback func
     var records []types.Record
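The replacement above just builds the prometheus.Labels map once and reuses it for every collector that shares the stream/shard labels. A self-contained sketch of the same idea, using illustrative metric and label names rather than the package's real collectors:

```go
package main

import "github.com/prometheus/client_golang/prometheus"

// Illustrative collectors; the real ones (gaugeBatchSize, collectorMillisBehindLatest, ...)
// are declared elsewhere in the package.
var batchSize = prometheus.NewGaugeVec(
	prometheus.GaugeOpts{Name: "batch_size"},
	[]string{"stream_name", "shard_id"},
)
var behindLatest = prometheus.NewHistogramVec(
	prometheus.HistogramOpts{Name: "seconds_behind_latest"},
	[]string{"stream_name", "shard_id"},
)

func observe(streamName, shardID string, n, lag float64) {
	// Build the label set once; every *Vec that uses the same label names can share it.
	labels := prometheus.Labels{"stream_name": streamName, "shard_id": shardID}
	batchSize.With(labels).Set(n)
	behindLatest.With(labels).Observe(lag)
}

func main() {
	prometheus.MustRegister(batchSize, behindLatest)
	observe("my-stream", "shardId-000000000000", 500, 0.25)
}
```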
@@ -297,7 +284,19 @@ func (c *Consumer) processRecords(ctx context.Context, shardID string, resp *kin
         return "", nil
     }

-    err = c.runWorkers(ctx, shardID, resp, fn, records)
+    eg, ctx := errgroup.WithContext(ctx)
+    eg.SetLimit(c.numWorkers)
+    for _, r := range records {
+        eg.Go(func() error {
+            counterEventsConsumed.With(labels).Inc()
+            err := fn(&Record{Record: r, ShardID: shardID, MillisBehindLatest: resp.MillisBehindLatest})
+            if !errors.Is(err, ErrSkipCheckpoint) {
+                return err
+            }
+            return nil
+        })
+    }
+    err = eg.Wait()
     if err != nil {
         return "", err
     }
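The inlined loop above is the bounded-concurrency errgroup pattern that previously lived in the runWorkers helper (removed in the next hunk): SetLimit caps how many callbacks run at once, and the first error that is not ErrSkipCheckpoint cancels the group's context and is returned from Wait. A standalone sketch of that pattern, with stand-in record and callback types; per-iteration loop variables as in Go 1.22+ are assumed, matching the closure capture in the code above:

```go
package main

import (
	"context"
	"fmt"

	"golang.org/x/sync/errgroup"
)

// processAll runs fn over every record with at most numWorkers concurrent calls.
func processAll(ctx context.Context, records []string, numWorkers int, fn func(string) error) error {
	eg, ctx := errgroup.WithContext(ctx)
	eg.SetLimit(numWorkers) // eg.Go blocks once numWorkers goroutines are active
	for _, r := range records {
		eg.Go(func() error {
			// Stop early if another record's callback has already failed.
			if err := ctx.Err(); err != nil {
				return err
			}
			return fn(r)
		})
	}
	// Wait blocks until all goroutines finish and returns the first non-nil error.
	return eg.Wait()
}

func main() {
	records := []string{"a", "b", "c", "d"}
	err := processAll(context.Background(), records, 2, func(r string) error {
		fmt.Println("processing", r)
		return nil
	})
	fmt.Println("err:", err)
}
```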
@@ -312,37 +311,14 @@ func (c *Consumer) processRecords(ctx context.Context, shardID string, resp *kin
     numberOfProcessedTasks := len(records)
     c.counter.Add("checkpoint", int64(numberOfProcessedTasks))
-    counterCheckpointsWritten.
-        With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}).
-        Add(float64(numberOfProcessedTasks))
+    counterCheckpointsWritten.With(labels).Add(float64(numberOfProcessedTasks))

     duration := time.Since(startedAt).Seconds()
-    histogramBatchDuration.
-        With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}).
-        Observe(duration)
-    histogramAverageRecordDuration.
-        With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}).
-        Observe(duration / batchSize)
+    histogramAverageRecordDuration.With(labels).Observe(duration / batchSize)
+    histogramBatchDuration.With(labels).Observe(duration)

     return lastSeqNum, nil
 }

-// runWorkers launches a worker pool to process the records
-func (c *Consumer) runWorkers(ctx context.Context, shardID string, resp *kinesis.GetRecordsOutput, fn ScanFunc, records []types.Record) error {
-    errGroup, ctx := errgroup.WithContext(ctx)
-    errGroup.SetLimit(c.numWorkers)
-    for _, r := range records {
-        errGroup.Go(func() error {
-            err := fn(&Record{Record: r, ShardID: shardID, MillisBehindLatest: resp.MillisBehindLatest})
-            if !errors.Is(err, ErrSkipCheckpoint) {
-                return err
-            }
-            return nil
-        })
-    }
-    return errGroup.Wait()
-}
-
 // temporary conversion func of []types.Record -> DesegregateRecords([]*types.Record) -> []types.Record
 func disaggregateRecords(in []types.Record) ([]types.Record, error) {
     var recs []types.Record


@@ -8,7 +8,15 @@ import (

 // Group interface used to manage which shard to process
 type Group interface {
-    Start(ctx context.Context, shardc chan types.Shard)
+    Start(ctx context.Context, shardc chan types.Shard) error
     GetCheckpoint(ctx context.Context, streamName, shardID string) (string, error)
     SetCheckpoint(ctx context.Context, streamName, shardID, sequenceNumber string) error
 }
+
+// CloseableGroup extends Group with the ability to close a shard.
+type CloseableGroup interface {
+    Group
+    // CloseShard allows shard processors to tell the group when the shard has been fully processed. Should be called
+    // only once per shardID.
+    CloseShard(ctx context.Context, shardID string) error
+}
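Because CloseableGroup embeds Group rather than adding CloseShard to Group itself, implementations that cannot close shards remain valid Groups, and callers can feature-detect the capability with a type assertion. A hedged sketch of one way a caller could do that (the interfaces are simplified to keep the example self-contained; this is not necessarily how the consumer wires it up):

```go
package consumer

import (
	"context"
	"log"
)

// Simplified stand-ins for the interfaces in the diff (the real Start takes a
// chan types.Shard from the AWS SDK).
type Group interface {
	Start(ctx context.Context, shardc chan string) error
	GetCheckpoint(ctx context.Context, streamName, shardID string) (string, error)
	SetCheckpoint(ctx context.Context, streamName, shardID, sequenceNumber string) error
}

type CloseableGroup interface {
	Group
	// CloseShard must be called at most once per shardID, after the shard is fully processed.
	CloseShard(ctx context.Context, shardID string) error
}

// markShardDone reports shard completion only when the group supports closing shards.
func markShardDone(ctx context.Context, g Group, shardID string) {
	cg, ok := g.(CloseableGroup)
	if !ok {
		return // plain Group: nothing to report
	}
	if err := cg.CloseShard(ctx, shardID); err != nil {
		log.Printf("close shard %s: %v", shardID, err)
	}
}
```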