#202 cleanup and refactoring

Alex Senger 2025-03-06 15:06:08 +01:00
parent f3fb792745
commit 9c8480b777
2 changed files with 33 additions and 49 deletions

View file

@@ -46,11 +46,6 @@ func New(streamName string, opts ...Option) (*Consumer, error) {
         maxRecords: 10000,
         metricRegistry: nil,
         numWorkers: 1,
-        logger: &noopLogger{
-            logger: log.New(io.Discard, "", log.LstdFlags),
-        },
-        scanInterval: 250 * time.Millisecond,
-        maxRecords: 10000,
     }
     // override defaults
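
The hunk above drops duplicated defaults from the Consumer literal. For context, a minimal sketch of the defaults-then-override (functional options) pattern the constructor follows; the Option shape and the WithNumWorkers helper are assumptions, not taken from the diff:

```go
package consumer

import "time"

// Consumer is trimmed to the fields visible in the diff; the real struct
// has more (logger, metricRegistry, ...).
type Consumer struct {
	streamName   string
	maxRecords   int64
	numWorkers   int
	scanInterval time.Duration
}

// Option mutates the Consumer during construction (assumed shape).
type Option func(*Consumer)

// WithNumWorkers is a hypothetical option, shown only for illustration.
func WithNumWorkers(n int) Option {
	return func(c *Consumer) { c.numWorkers = n }
}

func New(streamName string, opts ...Option) (*Consumer, error) {
	c := &Consumer{
		streamName:   streamName,
		maxRecords:   10000, // defaults mirror the diff
		numWorkers:   1,
		scanInterval: 250 * time.Millisecond,
	}
	for _, opt := range opts { // override defaults
		opt(c)
	}
	return c, nil
}
```
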
@@ -146,8 +141,8 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
     // process each of the shards
     s := newShardsInProcess()
     for shard := range shardC {
-        shardId := aws.ToString(shard.ShardId)
-        if s.doesShardExist(shardId) {
+        shardID := aws.ToString(shard.ShardId)
+        if s.doesShardExist(shardID) {
             // safetynet: if shard already in process by another goroutine, just skipping the request
             continue
         }
@@ -174,7 +169,7 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
                 // error has already occurred
             }
         }
-        }(shardId)
+        }(shardID)
     }

     go func() {
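
Scan guards against two goroutines picking up the same shard via newShardsInProcess and doesShardExist. That type isn't shown in this diff; a plausible sketch, assuming a mutex-guarded map keyed by shard ID (the addShard name is hypothetical):

```go
package consumer

import "sync"

// shardsInProcess is a guess at the set behind newShardsInProcess:
// a mutex-guarded map of shard IDs currently being processed.
type shardsInProcess struct {
	mu     sync.Mutex
	shards map[string]struct{}
}

func newShardsInProcess() *shardsInProcess {
	return &shardsInProcess{shards: make(map[string]struct{})}
}

// doesShardExist reports whether the shard is already being processed.
func (s *shardsInProcess) doesShardExist(shardID string) bool {
	s.mu.Lock()
	defer s.mu.Unlock()
	_, ok := s.shards[shardID]
	return ok
}

// addShard records a shard as in-process (hypothetical helper).
func (s *shardsInProcess) addShard(shardID string) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.shards[shardID] = struct{}{}
}
```
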
@@ -185,8 +180,8 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
     return <-errC
 }

-// ScanShard loops over records on a specific shard, calls the callback func
-// for each record and checkpoints the progress of scan.
+// ScanShard loops over records on a specific shard, calls the callback func for each record and checkpoints the
+// progress of scan.
 func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error {
     c.workerPool = NewWorkerPool(c.streamName, c.numWorkers, fn)
     c.workerPool.Start(ctx)
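
For orientation, a hedged usage sketch of the API touched here. The module path is hypothetical, and the ScanFunc shape (func(*Record) error) is inferred from the calls in this diff; Record embeds types.Record, so r.Data is the record payload:

```go
package main

import (
	"context"
	"fmt"
	"log"
	"time"

	consumer "example.com/kinesis-consumer" // hypothetical module path
)

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
	defer cancel()

	c, err := consumer.New("my-stream") // stream name is illustrative
	if err != nil {
		log.Fatal(err)
	}

	// Scan blocks until ctx is cancelled or a callback fails; per the diff,
	// ErrSkipCheckpoint returned from the callback is treated as non-fatal.
	if err := c.Scan(ctx, func(r *consumer.Record) error {
		fmt.Println(string(r.Data))
		return nil
	}); err != nil {
		log.Fatal(err)
	}
}
```
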
@@ -263,20 +258,12 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error {
 }

 func (c *Consumer) processRecords(ctx context.Context, shardID string, resp *kinesis.GetRecordsOutput, fn ScanFunc) (string, error) {
     if len(resp.Records) == 0 {
         return "", nil
     }

     startedAt := time.Now()
     batchSize := float64(len(resp.Records))
-    gaugeBatchSize.
-        With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}).
-        Set(batchSize)
     secondsBehindLatest := float64(time.Duration(*resp.MillisBehindLatest)*time.Millisecond) / float64(time.Second)
-    collectorMillisBehindLatest.
-        With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}).
-        Observe(secondsBehindLatest)
+    labels := prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}
+    gaugeBatchSize.With(labels).Set(batchSize)
+    collectorMillisBehindLatest.With(labels).Observe(secondsBehindLatest)

     // loop over records, call callback func
     var records []types.Record
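
This refactor builds prometheus.Labels once and reuses it for every metric in the function. For With(labels) to work, each metric vector must be declared with exactly those label names; a sketch of the assumed declarations (metric names, help strings, label string values, and the histogram choice are all assumptions, and registration is omitted):

```go
package consumer

import "github.com/prometheus/client_golang/prometheus"

// Label name constants as used in the diff; the string values are guesses.
const (
	labelStreamName = "stream_name"
	labelShardID    = "shard_id"
)

// With(labels) only succeeds when the supplied label names match these
// declarations exactly; a mismatch panics at runtime.
var (
	gaugeBatchSize = prometheus.NewGaugeVec(
		prometheus.GaugeOpts{Name: "kinesis_batch_size", Help: "Records returned by the last GetRecords call."},
		[]string{labelStreamName, labelShardID},
	)
	collectorMillisBehindLatest = prometheus.NewHistogramVec(
		prometheus.HistogramOpts{Name: "kinesis_seconds_behind_latest", Help: "How far the consumer lags the stream tip, in seconds."},
		[]string{labelStreamName, labelShardID},
	)
)
```
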
@@ -297,7 +284,19 @@ func (c *Consumer) processRecords(ctx context.Context, shardID string, resp *kinesis.GetRecordsOutput, fn ScanFunc) (string, error) {
         return "", nil
     }

-    err = c.runWorkers(ctx, shardID, resp, fn, records)
+    eg, ctx := errgroup.WithContext(ctx)
+    eg.SetLimit(c.numWorkers)
+    for _, r := range records {
+        eg.Go(func() error {
+            counterEventsConsumed.With(labels).Inc()
+            err := fn(&Record{Record: r, ShardID: shardID, MillisBehindLatest: resp.MillisBehindLatest})
+            if !errors.Is(err, ErrSkipCheckpoint) {
+                return err
+            }
+            return nil
+        })
+    }
+    err = eg.Wait()
     if err != nil {
         return "", err
     }
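
The runWorkers helper is gone; its errgroup body now lives inline. A self-contained sketch of the bounded-concurrency pattern being used: SetLimit caps in-flight goroutines, WithContext cancels the group's context after the first error, and Wait returns that first error. Note that capturing the loop variable r inside the closure, as the diff does, relies on Go 1.22+ per-iteration loop semantics; older Go would need an explicit r := r copy.

```go
package main

import (
	"context"
	"fmt"

	"golang.org/x/sync/errgroup"
)

func main() {
	records := []string{"a", "b", "c", "d", "e"}

	eg, ctx := errgroup.WithContext(context.Background())
	eg.SetLimit(2) // mirrors eg.SetLimit(c.numWorkers) in the diff

	for _, r := range records {
		eg.Go(func() error {
			if err := ctx.Err(); err != nil {
				return err // another record already failed; stop early
			}
			fmt.Println("processing", r) // stand-in for fn(&Record{...})
			return nil
		})
	}
	if err := eg.Wait(); err != nil {
		fmt.Println("first error:", err)
	}
}
```
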
@@ -312,37 +311,14 @@ func (c *Consumer) processRecords(ctx context.Context, shardID string, resp *kinesis.GetRecordsOutput, fn ScanFunc) (string, error) {
     numberOfProcessedTasks := len(records)
     c.counter.Add("checkpoint", int64(numberOfProcessedTasks))
-    counterCheckpointsWritten.
-        With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}).
-        Add(float64(numberOfProcessedTasks))
+    counterCheckpointsWritten.With(labels).Add(float64(numberOfProcessedTasks))

     duration := time.Since(startedAt).Seconds()
-    histogramBatchDuration.
-        With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}).
-        Observe(duration)
-    histogramAverageRecordDuration.
-        With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}).
-        Observe(duration / batchSize)
+    histogramAverageRecordDuration.With(labels).Observe(duration / batchSize)
+    histogramBatchDuration.With(labels).Observe(duration)

     return lastSeqNum, nil
 }
-// runWorkers launches a worker pool to process the records
-func (c *Consumer) runWorkers(ctx context.Context, shardID string, resp *kinesis.GetRecordsOutput, fn ScanFunc, records []types.Record) error {
-    errGroup, ctx := errgroup.WithContext(ctx)
-    errGroup.SetLimit(c.numWorkers)
-    for _, r := range records {
-        errGroup.Go(func() error {
-            err := fn(&Record{Record: r, ShardID: shardID, MillisBehindLatest: resp.MillisBehindLatest})
-            if !errors.Is(err, ErrSkipCheckpoint) {
-                return err
-            }
-            return nil
-        })
-    }
-    return errGroup.Wait()
-}
 // temporary conversion func of []types.Record -> DesegregateRecords([]*types.Record) -> []types.Record
 func disaggregateRecords(in []types.Record) ([]types.Record, error) {
     var recs []types.Record
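
Per its comment, disaggregateRecords bridges a deaggregation helper that works on pointer slices and the consumer's value slices. A hedged sketch of that round trip; the helper is passed in as a parameter because its real package and signature aren't shown in this diff:

```go
package consumer

import "github.com/aws/aws-sdk-go-v2/service/kinesis/types"

// disaggregateWith converts value records to pointers, runs the supplied
// deaggregation helper (assumed signature, per the comment above), and
// converts the result back to values.
func disaggregateWith(
	in []types.Record,
	deaggregate func([]*types.Record) ([]*types.Record, error),
) ([]types.Record, error) {
	ptrs := make([]*types.Record, len(in))
	for i := range in {
		ptrs[i] = &in[i] // &in[i] points into the slice's backing array
	}
	out, err := deaggregate(ptrs)
	if err != nil {
		return nil, err
	}
	recs := make([]types.Record, 0, len(out))
	for _, r := range out {
		recs = append(recs, *r)
	}
	return recs, nil
}
```
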

View file

@@ -8,7 +8,15 @@ import (
 // Group interface used to manage which shard to process
 type Group interface {
-    Start(ctx context.Context, shardc chan types.Shard)
+    Start(ctx context.Context, shardc chan types.Shard) error
     GetCheckpoint(ctx context.Context, streamName, shardID string) (string, error)
     SetCheckpoint(ctx context.Context, streamName, shardID, sequenceNumber string) error
 }
+
+// CloseableGroup extends Group with the ability to close a shard.
+type CloseableGroup interface {
+    Group
+    // CloseShard allows shard processors to tell the group when the shard has been fully processed. Should be called
+    // only once per shardID.
+    CloseShard(ctx context.Context, shardID string) error
+}
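
CloseableGroup is an optional capability layered on Group, so callers would typically discover it with a type assertion. A hypothetical helper showing that pattern, building on the two interfaces above (markShardClosed is not part of the diff):

```go
package consumer

import "context"

// markShardClosed notifies the group that a shard is fully processed, but
// only if the configured Group also implements CloseableGroup; per the
// interface contract it should be invoked at most once per shardID.
func markShardClosed(ctx context.Context, g Group, shardID string) error {
	cg, ok := g.(CloseableGroup)
	if !ok {
		return nil // this Group implementation doesn't track shard closure
	}
	return cg.CloseShard(ctx, shardID)
}
```
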