#4 minor cleanups

This commit is contained in:
Alex Senger 2024-04-10 15:16:07 +02:00
parent a5089538fd
commit 040fa06efa
No known key found for this signature in database
GPG key ID: 0B4A96F8AF6934CF
5 changed files with 28 additions and 26 deletions

View file

@ -8,7 +8,7 @@ import (
"github.com/aws/aws-sdk-go-v2/service/kinesis/types" "github.com/aws/aws-sdk-go-v2/service/kinesis/types"
) )
// NewAllGroup returns an intitialized AllGroup for consuming // NewAllGroup returns an initialized AllGroup for consuming
// all shards on a stream // all shards on a stream
func NewAllGroup(ksis kinesisClient, store Store, streamName string, logger Logger) *AllGroup { func NewAllGroup(ksis kinesisClient, store Store, streamName string, logger Logger) *AllGroup {
return &AllGroup{ return &AllGroup{
@ -38,10 +38,10 @@ type AllGroup struct {
func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) { func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) {
// Note: while ticker is a rather naive approach to this problem, // Note: while ticker is a rather naive approach to this problem,
// it actually simplifies a few things. i.e. If we miss a new shard // it actually simplifies a few things. i.e. If we miss a new shard
// while AWS is resharding we'll pick it up max 30 seconds later. // while AWS is re-sharding we'll pick it up max 30 seconds later.
// It might be worth refactoring this flow to allow the consumer to // It might be worth refactoring this flow to allow the consumer to
// to notify the broker when a shard is closed. However, shards don't // notify the broker when a shard is closed. However, shards don't
// necessarily close at the same time, so we could potentially get a // necessarily close at the same time, so we could potentially get a
// thundering herd of notifications from the consumer. // thundering herd of notifications from the consumer.

View file

@ -4,7 +4,7 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"io/ioutil" "io"
"log" "log"
"sync" "sync"
"time" "time"
@ -13,6 +13,7 @@ import (
"github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/kinesis" "github.com/aws/aws-sdk-go-v2/service/kinesis"
"github.com/aws/aws-sdk-go-v2/service/kinesis/types" "github.com/aws/aws-sdk-go-v2/service/kinesis/types"
"github.com/harlow/kinesis-consumer/internal/deaggregator" "github.com/harlow/kinesis-consumer/internal/deaggregator"
) )
@ -38,7 +39,7 @@ func New(streamName string, opts ...Option) (*Consumer, error) {
store: &noopStore{}, store: &noopStore{},
counter: &noopCounter{}, counter: &noopCounter{},
logger: &noopLogger{ logger: &noopLogger{
logger: log.New(ioutil.Discard, "", log.LstdFlags), logger: log.New(io.Discard, "", log.LstdFlags),
}, },
scanInterval: 250 * time.Millisecond, scanInterval: 250 * time.Millisecond,
maxRecords: 10000, maxRecords: 10000,
@ -90,7 +91,7 @@ type Consumer struct {
type ScanFunc func(*Record) error type ScanFunc func(*Record) error
// ErrSkipCheckpoint is used as a return value from ScanFunc to indicate that // ErrSkipCheckpoint is used as a return value from ScanFunc to indicate that
// the current checkpoint should be skipped skipped. It is not returned // the current checkpoint should be skipped. It is not returned
// as an error by any function. // as an error by any function.
var ErrSkipCheckpoint = errors.New("skip checkpoint") var ErrSkipCheckpoint = errors.New("skip checkpoint")
@ -183,9 +184,9 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
// loop over records, call callback func // loop over records, call callback func
var records []types.Record var records []types.Record
// deaggregate records // disaggregate records
if c.isAggregated { if c.isAggregated {
records, err = deaggregateRecords(resp.Records) records, err = disaggregateRecords(resp.Records)
if err != nil { if err != nil {
return err return err
} }
@ -199,11 +200,11 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
return nil return nil
default: default:
err := fn(&Record{r, shardID, resp.MillisBehindLatest}) err := fn(&Record{r, shardID, resp.MillisBehindLatest})
if err != nil && err != ErrSkipCheckpoint { if err != nil && !errors.Is(err, ErrSkipCheckpoint) {
return err return err
} }
if err != ErrSkipCheckpoint { if !errors.Is(err, ErrSkipCheckpoint) {
if err := c.group.SetCheckpoint(c.streamName, shardID, *r.SequenceNumber); err != nil { if err := c.group.SetCheckpoint(c.streamName, shardID, *r.SequenceNumber); err != nil {
return err return err
} }
@ -240,14 +241,14 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
} }
} }
// temporary conversion func of []types.Record -> DeaggregateRecords([]*types.Record) -> []types.Record // temporary conversion func of []types.Record -> DisaggregatedRecords([]*types.Record) -> []types.Record
func deaggregateRecords(in []types.Record) ([]types.Record, error) { func disaggregateRecords(in []types.Record) ([]types.Record, error) {
var recs []*types.Record var recs []*types.Record
for _, rec := range in { for _, rec := range in {
recs = append(recs, &rec) recs = append(recs, &rec)
} }
deagg, err := deaggregator.DeaggregateRecords(recs) deagg, err := deaggregator.DisaggregatedRecords(recs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -272,7 +273,7 @@ func (c *Consumer) getShardIterator(ctx context.Context, streamName, shardID, se
params.ShardIteratorType = types.ShardIteratorTypeAtTimestamp params.ShardIteratorType = types.ShardIteratorTypeAtTimestamp
params.Timestamp = c.initialTimestamp params.Timestamp = c.initialTimestamp
} else { } else {
params.ShardIteratorType = types.ShardIteratorType(c.initialShardIteratorType) params.ShardIteratorType = c.initialShardIteratorType
} }
res, err := c.client.GetShardIterator(ctx, params) res, err := c.client.GetShardIterator(ctx, params)

View file

@ -12,7 +12,7 @@ import (
rec "github.com/awslabs/kinesis-aggregation/go/records" rec "github.com/awslabs/kinesis-aggregation/go/records"
) )
// Magic File Header for a KPL Aggregated Record // KplMagicHeader Magic File Header for a KPL Aggregated Record
var KplMagicHeader = fmt.Sprintf("%q", []byte("\xf3\x89\x9a\xc2")) var KplMagicHeader = fmt.Sprintf("%q", []byte("\xf3\x89\x9a\xc2"))
const ( const (
@ -20,9 +20,9 @@ const (
DigestSize = 16 // MD5 Message size for protobuf. DigestSize = 16 // MD5 Message size for protobuf.
) )
// DeaggregateRecords takes an array of Kinesis records and expands any Protobuf // DisaggregatedRecords takes an array of Kinesis records and expands any Protobuf
// records within that array, returning an array of all records // records within that array, returning an array of all records
func DeaggregateRecords(records []*types.Record) ([]*types.Record, error) { func DisaggregatedRecords(records []*types.Record) ([]*types.Record, error) {
var isAggregated bool var isAggregated bool
allRecords := make([]*types.Record, 0) allRecords := make([]*types.Record, 0)
for _, record := range records { for _, record := range records {
@ -79,7 +79,7 @@ func DeaggregateRecords(records []*types.Record) ([]*types.Record, error) {
} }
// createUserRecord takes in the partitionKeys of the aggregated record, the individual // createUserRecord takes in the partitionKeys of the aggregated record, the individual
// deaggregated record, and the original aggregated record builds a kinesis.Record and // disaggregated record, and the original aggregated record builds a kinesis.Record and
// returns it // returns it
func createUserRecord(partitionKeys []string, aggRec *rec.Record, record *types.Record) *types.Record { func createUserRecord(partitionKeys []string, aggRec *rec.Record, record *types.Record) *types.Record {
partitionKey := partitionKeys[*aggRec.PartitionKeyIndex] partitionKey := partitionKeys[*aggRec.PartitionKeyIndex]

View file

@ -14,6 +14,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
rec "github.com/awslabs/kinesis-aggregation/go/records" rec "github.com/awslabs/kinesis-aggregation/go/records"
deagg "github.com/harlow/kinesis-consumer/internal/deaggregator" deagg "github.com/harlow/kinesis-consumer/internal/deaggregator"
) )
@ -81,7 +82,7 @@ func TestSmallLengthReturnsCorrectNumberOfDeaggregatedRecords(t *testing.T) {
smallByte := []byte("No") smallByte := []byte("No")
kr = generateKinesisRecord(smallByte) kr = generateKinesisRecord(smallByte)
krs = append(krs, kr) krs = append(krs, kr)
dars, err := deagg.DeaggregateRecords(krs) dars, err := deagg.DisaggregatedRecords(krs)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -108,7 +109,7 @@ func TestNonMatchingMagicHeaderReturnsSingleRecord(t *testing.T) {
krs = append(krs, kr) krs = append(krs, kr)
dars, err := deagg.DeaggregateRecords(krs) dars, err := deagg.DisaggregatedRecords(krs)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -118,7 +119,7 @@ func TestNonMatchingMagicHeaderReturnsSingleRecord(t *testing.T) {
assert.Equal(t, 1, len(dars), "Mismatch magic header test should return length of 1.") assert.Equal(t, 1, len(dars), "Mismatch magic header test should return length of 1.")
} }
// This function tests that the DeaggregateRecords function returns the correct number of // This function tests that the DisaggregatedRecords function returns the correct number of
// deaggregated records from a single aggregated record. // deaggregated records from a single aggregated record.
func TestVariableLengthRecordsReturnsCorrectNumberOfDeaggregatedRecords(t *testing.T) { func TestVariableLengthRecordsReturnsCorrectNumberOfDeaggregatedRecords(t *testing.T) {
var err error var err error
@ -133,7 +134,7 @@ func TestVariableLengthRecordsReturnsCorrectNumberOfDeaggregatedRecords(t *testi
kr = generateKinesisRecord(aggData) kr = generateKinesisRecord(aggData)
krs = append(krs, kr) krs = append(krs, kr)
dars, err := deagg.DeaggregateRecords(krs) dars, err := deagg.DisaggregatedRecords(krs)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -162,13 +163,13 @@ func TestRecordAfterMagicHeaderWithLengthLessThanDigestSizeReturnsSingleRecord(t
krs = append(krs, kr) krs = append(krs, kr)
dars, err := deagg.DeaggregateRecords(krs) dars, err := deagg.DisaggregatedRecords(krs)
if err != nil { if err != nil {
panic(err) panic(err)
} }
// A byte record with length less than 16 after the magic header should return // A byte record with length less than 16 after the magic header should return
// a single record from DeaggregateRecords // a single record from DisaggregatedRecords
assert.Equal(t, 1, len(dars), "Digest size test should return length of 1.") assert.Equal(t, 1, len(dars), "Digest size test should return length of 1.")
} }
@ -191,7 +192,7 @@ func TestRecordWithMismatchMd5SumReturnsSingleRecord(t *testing.T) {
krs = append(krs, kr) krs = append(krs, kr)
dars, err := deagg.DeaggregateRecords(krs) dars, err := deagg.DisaggregatedRecords(krs)
if err != nil { if err != nil {
panic(err) panic(err)
} }

View file

@ -4,7 +4,7 @@ import (
"log" "log"
) )
// A Logger is a minimal interface to as a adaptor for external logging library to consumer // A Logger is a minimal interface to act as an adaptor for an external logging library to the consumer
type Logger interface { type Logger interface {
Log(...interface{}) Log(...interface{})
} }