diff --git a/consumer.go b/consumer.go index fbb827b..7322d68 100644 --- a/consumer.go +++ b/consumer.go @@ -46,11 +46,6 @@ func New(streamName string, opts ...Option) (*Consumer, error) { maxRecords: 10000, metricRegistry: nil, numWorkers: 1, - logger: &noopLogger{ - logger: log.New(io.Discard, "", log.LstdFlags), - }, - scanInterval: 250 * time.Millisecond, - maxRecords: 10000, } // override defaults @@ -146,8 +141,8 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error { // process each of the shards s := newShardsInProcess() for shard := range shardC { - shardId := aws.ToString(shard.ShardId) - if s.doesShardExist(shardId) { + shardID := aws.ToString(shard.ShardId) + if s.doesShardExist(shardID) { // safetynet: if shard already in process by another goroutine, just skipping the request continue } @@ -174,7 +169,7 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error { // error has already occurred } } - }(shardId) + }(shardID) } go func() { @@ -185,8 +180,8 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error { return <-errC } -// ScanShard loops over records on a specific shard, calls the callback func -// for each record and checkpoints the progress of scan. +// ScanShard loops over records on a specific shard, calls the callback func for each record and checkpoints the +// progress of scan. func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error { c.workerPool = NewWorkerPool(c.streamName, c.numWorkers, fn) c.workerPool.Start(ctx) @@ -263,20 +258,12 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e } func (c *Consumer) processRecords(ctx context.Context, shardID string, resp *kinesis.GetRecordsOutput, fn ScanFunc) (string, error) { - if len(resp.Records) == 0 { - return "", nil - } - startedAt := time.Now() batchSize := float64(len(resp.Records)) - gaugeBatchSize. - With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}). 
- Set(batchSize) - secondsBehindLatest := float64(time.Duration(*resp.MillisBehindLatest)*time.Millisecond) / float64(time.Second) - collectorMillisBehindLatest. - With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}). - Observe(secondsBehindLatest) + labels := prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID} + gaugeBatchSize.With(labels).Set(batchSize) + collectorMillisBehindLatest.With(labels).Observe(float64(time.Duration(*resp.MillisBehindLatest)*time.Millisecond) / float64(time.Second)) // loop over records, call callback func var records []types.Record @@ -297,7 +284,19 @@ func (c *Consumer) processRecords(ctx context.Context, shardID string, resp *kin return "", nil } - err = c.runWorkers(ctx, shardID, resp, fn, records) + eg, ctx := errgroup.WithContext(ctx) + eg.SetLimit(c.numWorkers) + for _, r := range records { + eg.Go(func() error { + counterEventsConsumed.With(labels).Inc() + err := fn(&Record{Record: r, ShardID: shardID, MillisBehindLatest: resp.MillisBehindLatest}) + if !errors.Is(err, ErrSkipCheckpoint) { + return err + } + return nil + }) + } + err = eg.Wait() if err != nil { return "", err } @@ -312,37 +311,14 @@ func (c *Consumer) processRecords(ctx context.Context, shardID string, resp *kin numberOfProcessedTasks := len(records) c.counter.Add("checkpoint", int64(numberOfProcessedTasks)) - counterCheckpointsWritten. - With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}). - Add(float64(numberOfProcessedTasks)) + counterCheckpointsWritten.With(labels).Add(float64(numberOfProcessedTasks)) duration := time.Since(startedAt).Seconds() - histogramBatchDuration. - With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}). - Observe(duration) - histogramAverageRecordDuration. - With(prometheus.Labels{labelStreamName: c.streamName, labelShardID: shardID}). 
- Observe(duration / batchSize) + histogramAverageRecordDuration.With(labels).Observe(duration / batchSize) + histogramBatchDuration.With(labels).Observe(duration) return lastSeqNum, nil } -// runWorkers launches a worker pool to process the records -func (c *Consumer) runWorkers(ctx context.Context, shardID string, resp *kinesis.GetRecordsOutput, fn ScanFunc, records []types.Record) error { - errGroup, ctx := errgroup.WithContext(ctx) - errGroup.SetLimit(c.numWorkers) - for _, r := range records { - errGroup.Go(func() error { - err := fn(&Record{Record: r, ShardID: shardID, MillisBehindLatest: resp.MillisBehindLatest}) - if !errors.Is(err, ErrSkipCheckpoint) { - return err - } - return nil - }) - } - - return errGroup.Wait() -} - // temporary conversion func of []types.Record -> DesegregateRecords([]*types.Record) -> []types.Record func disaggregateRecords(in []types.Record) ([]types.Record, error) { var recs []types.Record diff --git a/group.go b/group.go index 0ccc2e0..8ac5229 100644 --- a/group.go +++ b/group.go @@ -8,7 +8,15 @@ import ( // Group interface used to manage which shard to process type Group interface { - Start(ctx context.Context, shardc chan types.Shard) + Start(ctx context.Context, shardc chan types.Shard) error GetCheckpoint(ctx context.Context, streamName, shardID string) (string, error) SetCheckpoint(ctx context.Context, streamName, shardID, sequenceNumber string) error } + +// CloseableGroup extends Group with the ability to close a shard. +type CloseableGroup interface { + Group + // CloseShard allows shard processors to tell the group when the shard has been fully processed. Should be called + // only once per shardID. + CloseShard(ctx context.Context, shardID string) error +}