From 040fa06efae27ef3b6f455b6a3304258cc333a26 Mon Sep 17 00:00:00 2001 From: Alex Senger Date: Wed, 10 Apr 2024 15:16:07 +0200 Subject: [PATCH] #4 minor cleanups --- allgroup.go | 6 +++--- consumer.go | 23 +++++++++++----------- internal/deaggregator/deaggregator.go | 8 ++++---- internal/deaggregator/deaggregator_test.go | 15 +++++++------- logger.go | 2 +- 5 files changed, 28 insertions(+), 26 deletions(-) diff --git a/allgroup.go b/allgroup.go index 749a380..e62af47 100644 --- a/allgroup.go +++ b/allgroup.go @@ -8,7 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/kinesis/types" ) -// NewAllGroup returns an intitialized AllGroup for consuming +// NewAllGroup returns an initialized AllGroup for consuming // all shards on a stream func NewAllGroup(ksis kinesisClient, store Store, streamName string, logger Logger) *AllGroup { return &AllGroup{ @@ -38,10 +38,10 @@ type AllGroup struct { func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) { // Note: while ticker is a rather naive approach to this problem, // it actually simplifies a few things. i.e. If we miss a new shard - // while AWS is resharding we'll pick it up max 30 seconds later. + // while AWS is re-sharding we'll pick it up max 30 seconds later. // It might be worth refactoring this flow to allow the consumer to - // to notify the broker when a shard is closed. However, shards don't + // notify the broker when a shard is closed. However, shards don't // necessarily close at the same time, so we could potentially get a // thundering heard of notifications from the consumer. diff --git a/consumer.go b/consumer.go index bff0b21..27ff421 100644 --- a/consumer.go +++ b/consumer.go @@ -4,7 +4,7 @@ import ( "context" "errors" "fmt" - "io/ioutil" + "io" "log" "sync" "time" @@ -13,6 +13,7 @@ import ( "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/service/kinesis" "github.com/aws/aws-sdk-go-v2/service/kinesis/types" + "github.com/harlow/kinesis-consumer/internal/deaggregator" ) @@ -38,7 +39,7 @@ func New(streamName string, opts ...Option) (*Consumer, error) { store: &noopStore{}, counter: &noopCounter{}, logger: &noopLogger{ - logger: log.New(ioutil.Discard, "", log.LstdFlags), + logger: log.New(io.Discard, "", log.LstdFlags), }, scanInterval: 250 * time.Millisecond, maxRecords: 10000, @@ -90,7 +91,7 @@ type Consumer struct { type ScanFunc func(*Record) error // ErrSkipCheckpoint is used as a return value from ScanFunc to indicate that -// the current checkpoint should be skipped skipped. It is not returned +// the current checkpoint should be skipped. It is not returned // as an error by any function. var ErrSkipCheckpoint = errors.New("skip checkpoint") @@ -183,9 +184,9 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e // loop over records, call callback func var records []types.Record - // deaggregate records + // desegregate records if c.isAggregated { - records, err = deaggregateRecords(resp.Records) + records, err = disaggregateRecords(resp.Records) if err != nil { return err } @@ -199,11 +200,11 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e return nil default: err := fn(&Record{r, shardID, resp.MillisBehindLatest}) - if err != nil && err != ErrSkipCheckpoint { + if err != nil && !errors.Is(err, ErrSkipCheckpoint) { return err } - if err != ErrSkipCheckpoint { + if !errors.Is(err, ErrSkipCheckpoint) { if err := c.group.SetCheckpoint(c.streamName, shardID, *r.SequenceNumber); err != nil { return err } @@ -240,14 +241,14 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e } } -// temporary conversion func of []types.Record -> DeaggregateRecords([]*types.Record) -> []types.Record -func deaggregateRecords(in []types.Record) ([]types.Record, error) { +// temporary conversion func of []types.Record -> DesegregateRecords([]*types.Record) -> []types.Record +func disaggregateRecords(in []types.Record) ([]types.Record, error) { var recs []*types.Record for _, rec := range in { recs = append(recs, &rec) } - deagg, err := deaggregator.DeaggregateRecords(recs) + deagg, err := deaggregator.DisaggregatedRecords(recs) if err != nil { return nil, err } @@ -272,7 +273,7 @@ func (c *Consumer) getShardIterator(ctx context.Context, streamName, shardID, se params.ShardIteratorType = types.ShardIteratorTypeAtTimestamp params.Timestamp = c.initialTimestamp } else { - params.ShardIteratorType = types.ShardIteratorType(c.initialShardIteratorType) + params.ShardIteratorType = c.initialShardIteratorType } res, err := c.client.GetShardIterator(ctx, params) diff --git a/internal/deaggregator/deaggregator.go b/internal/deaggregator/deaggregator.go index 94782f1..b2395f8 100644 --- a/internal/deaggregator/deaggregator.go +++ b/internal/deaggregator/deaggregator.go @@ -12,7 +12,7 @@ import ( rec "github.com/awslabs/kinesis-aggregation/go/records" ) -// Magic File Header for a KPL Aggregated Record +// KplMagicHeader Magic File Header for a KPL Aggregated Record var KplMagicHeader = fmt.Sprintf("%q", []byte("\xf3\x89\x9a\xc2")) const ( @@ -20,9 +20,9 @@ const ( DigestSize = 16 // MD5 Message size for protobuf. ) -// DeaggregateRecords takes an array of Kinesis records and expands any Protobuf +// DisaggregatedRecords takes an array of Kinesis records and expands any Protobuf // records within that array, returning an array of all records -func DeaggregateRecords(records []*types.Record) ([]*types.Record, error) { +func DisaggregatedRecords(records []*types.Record) ([]*types.Record, error) { var isAggregated bool allRecords := make([]*types.Record, 0) for _, record := range records { @@ -79,7 +79,7 @@ func DeaggregateRecords(records []*types.Record) ([]*types.Record, error) { } // createUserRecord takes in the partitionKeys of the aggregated record, the individual -// deaggregated record, and the original aggregated record builds a kinesis.Record and +// disaggregated record, and the original aggregated record builds a kinesis.Record and // returns it func createUserRecord(partitionKeys []string, aggRec *rec.Record, record *types.Record) *types.Record { partitionKey := partitionKeys[*aggRec.PartitionKeyIndex] diff --git a/internal/deaggregator/deaggregator_test.go b/internal/deaggregator/deaggregator_test.go index d1c33f9..1f2a037 100644 --- a/internal/deaggregator/deaggregator_test.go +++ b/internal/deaggregator/deaggregator_test.go @@ -14,6 +14,7 @@ import ( "github.com/stretchr/testify/assert" rec "github.com/awslabs/kinesis-aggregation/go/records" + deagg "github.com/harlow/kinesis-consumer/internal/deaggregator" ) @@ -81,7 +82,7 @@ func TestSmallLengthReturnsCorrectNumberOfDeaggregatedRecords(t *testing.T) { smallByte := []byte("No") kr = generateKinesisRecord(smallByte) krs = append(krs, kr) - dars, err := deagg.DeaggregateRecords(krs) + dars, err := deagg.DisaggregatedRecords(krs) if err != nil { panic(err) } @@ -108,7 +109,7 @@ func TestNonMatchingMagicHeaderReturnsSingleRecord(t *testing.T) { krs = append(krs, kr) - dars, err := deagg.DeaggregateRecords(krs) + dars, err := deagg.DisaggregatedRecords(krs) if err != nil { panic(err) } @@ -118,7 +119,7 @@ func TestNonMatchingMagicHeaderReturnsSingleRecord(t *testing.T) { assert.Equal(t, 1, len(dars), "Mismatch magic header test should return length of 1.") } -// This function tests that the DeaggregateRecords function returns the correct number of +// This function tests that the DisaggregatedRecords function returns the correct number of // deaggregated records from a single aggregated record. func TestVariableLengthRecordsReturnsCorrectNumberOfDeaggregatedRecords(t *testing.T) { var err error @@ -133,7 +134,7 @@ func TestVariableLengthRecordsReturnsCorrectNumberOfDeaggregatedRecords(t *testi kr = generateKinesisRecord(aggData) krs = append(krs, kr) - dars, err := deagg.DeaggregateRecords(krs) + dars, err := deagg.DisaggregatedRecords(krs) if err != nil { panic(err) } @@ -162,13 +163,13 @@ func TestRecordAfterMagicHeaderWithLengthLessThanDigestSizeReturnsSingleRecord(t krs = append(krs, kr) - dars, err := deagg.DeaggregateRecords(krs) + dars, err := deagg.DisaggregatedRecords(krs) if err != nil { panic(err) } // A byte record with length less than 16 after the magic header should return - // a single record from DeaggregateRecords + // a single record from DisaggregatedRecords assert.Equal(t, 1, len(dars), "Digest size test should return length of 1.") } @@ -191,7 +192,7 @@ func TestRecordWithMismatchMd5SumReturnsSingleRecord(t *testing.T) { krs = append(krs, kr) - dars, err := deagg.DeaggregateRecords(krs) + dars, err := deagg.DisaggregatedRecords(krs) if err != nil { panic(err) } diff --git a/logger.go b/logger.go index ab90d2a..f1896c0 100644 --- a/logger.go +++ b/logger.go @@ -4,7 +4,7 @@ import ( "log" ) -// A Logger is a minimal interface to as a adaptor for external logging library to consumer +// A Logger is a minimal interface to as an adaptor for external logging library to consumer type Logger interface { Log(...interface{}) }