#4 minor cleanups
parent a5089538fd
commit 040fa06efa
5 changed files with 28 additions and 26 deletions
@@ -8,7 +8,7 @@ import (
 	"github.com/aws/aws-sdk-go-v2/service/kinesis/types"
 )

-// NewAllGroup returns an intitialized AllGroup for consuming
+// NewAllGroup returns an initialized AllGroup for consuming
 // all shards on a stream
 func NewAllGroup(ksis kinesisClient, store Store, streamName string, logger Logger) *AllGroup {
 	return &AllGroup{

@@ -38,10 +38,10 @@ type AllGroup struct {
 func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) {
 	// Note: while ticker is a rather naive approach to this problem,
 	// it actually simplifies a few things. i.e. If we miss a new shard
-	// while AWS is resharding we'll pick it up max 30 seconds later.
+	// while AWS is re-sharding we'll pick it up max 30 seconds later.

 	// It might be worth refactoring this flow to allow the consumer to
-	// to notify the broker when a shard is closed. However, shards don't
+	// notify the broker when a shard is closed. However, shards don't
 	// necessarily close at the same time, so we could potentially get a
 	// thundering heard of notifications from the consumer.

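The ticker comment above describes the shard-discovery loop. A rough sketch of that pattern, where findNewShards is a hypothetical helper that lists the stream's shards and forwards unseen ones to shardc:

    // Poll on a fixed interval so a shard created during re-sharding is
    // picked up at most one tick (30 seconds) later.
    ticker := time.NewTicker(30 * time.Second)
    defer ticker.Stop()

    for {
        g.findNewShards(ctx, shardc) // hypothetical: list shards, send unseen ones
        select {
        case <-ctx.Done():
            return
        case <-ticker.C:
        }
    }
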
consumer.go (23 changed lines)
@@ -4,7 +4,7 @@ import (
 	"context"
 	"errors"
 	"fmt"
-	"io/ioutil"
+	"io"
 	"log"
 	"sync"
 	"time"

@@ -13,6 +13,7 @@ import (
 	"github.com/aws/aws-sdk-go-v2/config"
 	"github.com/aws/aws-sdk-go-v2/service/kinesis"
 	"github.com/aws/aws-sdk-go-v2/service/kinesis/types"
+
 	"github.com/harlow/kinesis-consumer/internal/deaggregator"
 )

@@ -38,7 +39,7 @@ func New(streamName string, opts ...Option) (*Consumer, error) {
 		store:   &noopStore{},
 		counter: &noopCounter{},
 		logger: &noopLogger{
-			logger: log.New(ioutil.Discard, "", log.LstdFlags),
+			logger: log.New(io.Discard, "", log.LstdFlags),
 		},
 		scanInterval: 250 * time.Millisecond,
 		maxRecords:   10000,

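A side note on the import swap: io/ioutil has been deprecated since Go 1.16, and io.Discard is the drop-in replacement for ioutil.Discard. A minimal sketch of the same noop logger in isolation:

    package main

    import (
        "io"
        "log"
    )

    func main() {
        // A logger writing to io.Discard silently drops all output.
        silent := log.New(io.Discard, "", log.LstdFlags)
        silent.Println("never printed")
    }
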
@@ -90,7 +91,7 @@ type Consumer struct {
 type ScanFunc func(*Record) error

 // ErrSkipCheckpoint is used as a return value from ScanFunc to indicate that
-// the current checkpoint should be skipped skipped. It is not returned
+// the current checkpoint should be skipped. It is not returned
 // as an error by any function.
 var ErrSkipCheckpoint = errors.New("skip checkpoint")

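A sketch of how a caller might use the sentinel, assuming Record embeds the underlying Kinesis record so its Data field is reachable (handle is a hypothetical application function):

    package main

    import (
        consumer "github.com/harlow/kinesis-consumer"
    )

    func handle(b []byte) error { return nil } // hypothetical application handler

    var fn consumer.ScanFunc = func(r *consumer.Record) error {
        if len(r.Data) == 0 {
            // Record counts as handled, but its checkpoint is not stored.
            return consumer.ErrSkipCheckpoint
        }
        return handle(r.Data)
    }
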
@@ -183,9 +184,9 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
 		// loop over records, call callback func
 		var records []types.Record

-		// deaggregate records
+		// desegregate records
 		if c.isAggregated {
-			records, err = deaggregateRecords(resp.Records)
+			records, err = disaggregateRecords(resp.Records)
 			if err != nil {
 				return err
 			}

@@ -199,11 +200,11 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
 			return nil
 		default:
 			err := fn(&Record{r, shardID, resp.MillisBehindLatest})
-			if err != nil && err != ErrSkipCheckpoint {
+			if err != nil && !errors.Is(err, ErrSkipCheckpoint) {
 				return err
 			}

-			if err != ErrSkipCheckpoint {
+			if !errors.Is(err, ErrSkipCheckpoint) {
 				if err := c.group.SetCheckpoint(c.streamName, shardID, *r.SequenceNumber); err != nil {
 					return err
 				}

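Replacing == with errors.Is makes the sentinel check robust to wrapping: a plain comparison fails as soon as a caller wraps the error with fmt.Errorf("... %w", ...). A self-contained illustration:

    package main

    import (
        "errors"
        "fmt"
    )

    var ErrSkipCheckpoint = errors.New("skip checkpoint")

    func main() {
        wrapped := fmt.Errorf("scan failed: %w", ErrSkipCheckpoint)
        fmt.Println(wrapped == ErrSkipCheckpoint)          // false
        fmt.Println(errors.Is(wrapped, ErrSkipCheckpoint)) // true
    }
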
@@ -240,14 +241,14 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
 	}
 }

-// temporary conversion func of []types.Record -> DeaggregateRecords([]*types.Record) -> []types.Record
-func deaggregateRecords(in []types.Record) ([]types.Record, error) {
+// temporary conversion func of []types.Record -> DesegregateRecords([]*types.Record) -> []types.Record
+func disaggregateRecords(in []types.Record) ([]types.Record, error) {
 	var recs []*types.Record
 	for _, rec := range in {
 		recs = append(recs, &rec)
 	}

-	deagg, err := deaggregator.DeaggregateRecords(recs)
+	deagg, err := deaggregator.DisaggregatedRecords(recs)
 	if err != nil {
 		return nil, err
 	}

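One caveat the cleanup leaves untouched: on Go versions before 1.22, &rec takes the address of the single range variable, so every pointer in recs ends up referring to the last record. A version-independent sketch indexes the input slice instead:

    recs := make([]*types.Record, len(in))
    for i := range in {
        recs[i] = &in[i] // address the slice element, not the shared loop variable
    }
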
@@ -272,7 +273,7 @@ func (c *Consumer) getShardIterator(ctx context.Context, streamName, shardID, se
 		params.ShardIteratorType = types.ShardIteratorTypeAtTimestamp
 		params.Timestamp = c.initialTimestamp
 	} else {
-		params.ShardIteratorType = types.ShardIteratorType(c.initialShardIteratorType)
+		params.ShardIteratorType = c.initialShardIteratorType
 	}

 	res, err := c.client.GetShardIterator(ctx, params)

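Dropping the conversion suggests initialShardIteratorType is already declared as types.ShardIteratorType, making the cast a no-op. The SDK models the iterator type as a string-backed enum; a small standalone illustration:

    package main

    import (
        "fmt"

        "github.com/aws/aws-sdk-go-v2/service/kinesis/types"
    )

    func main() {
        // ShardIteratorType values are plain strings under the hood.
        var it types.ShardIteratorType = types.ShardIteratorTypeTrimHorizon
        fmt.Println(string(it)) // "TRIM_HORIZON"
    }
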
@@ -12,7 +12,7 @@ import (
 	rec "github.com/awslabs/kinesis-aggregation/go/records"
 )

-// Magic File Header for a KPL Aggregated Record
+// KplMagicHeader Magic File Header for a KPL Aggregated Record
 var KplMagicHeader = fmt.Sprintf("%q", []byte("\xf3\x89\x9a\xc2"))

 const (

@@ -20,9 +20,9 @@ const (
 	DigestSize = 16 // MD5 Message size for protobuf.
 )

-// DeaggregateRecords takes an array of Kinesis records and expands any Protobuf
+// DisaggregatedRecords takes an array of Kinesis records and expands any Protobuf
 // records within that array, returning an array of all records
-func DeaggregateRecords(records []*types.Record) ([]*types.Record, error) {
+func DisaggregatedRecords(records []*types.Record) ([]*types.Record, error) {
 	var isAggregated bool
 	allRecords := make([]*types.Record, 0)
 	for _, record := range records {

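Per the constants above, a KPL-aggregated payload starts with the 4-byte magic header and carries a trailing 16-byte MD5 digest of the embedded protobuf message. A minimal detection sketch (isKplAggregated is a hypothetical helper; it needs the standard bytes package):

    // isKplAggregated reports whether a raw payload looks like a KPL
    // aggregated record: the 4-byte magic header plus room for the
    // 16-byte trailing MD5 digest of the protobuf message.
    func isKplAggregated(data []byte) bool {
        magic := []byte("\xf3\x89\x9a\xc2")
        return len(data) >= len(magic)+16 && bytes.HasPrefix(data, magic)
    }
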
@@ -79,7 +79,7 @@ func DeaggregateRecords(records []*types.Record) ([]*types.Record, error) {
 }

 // createUserRecord takes in the partitionKeys of the aggregated record, the individual
-// deaggregated record, and the original aggregated record builds a kinesis.Record and
+// disaggregated record, and the original aggregated record builds a kinesis.Record and
 // returns it
 func createUserRecord(partitionKeys []string, aggRec *rec.Record, record *types.Record) *types.Record {
 	partitionKey := partitionKeys[*aggRec.PartitionKeyIndex]

@@ -14,6 +14,7 @@ import (
 	"github.com/stretchr/testify/assert"

 	rec "github.com/awslabs/kinesis-aggregation/go/records"
+
 	deagg "github.com/harlow/kinesis-consumer/internal/deaggregator"
 )

@@ -81,7 +82,7 @@ func TestSmallLengthReturnsCorrectNumberOfDeaggregatedRecords(t *testing.T) {
 	smallByte := []byte("No")
 	kr = generateKinesisRecord(smallByte)
 	krs = append(krs, kr)
-	dars, err := deagg.DeaggregateRecords(krs)
+	dars, err := deagg.DisaggregatedRecords(krs)
 	if err != nil {
 		panic(err)
 	}

@@ -108,7 +109,7 @@ func TestNonMatchingMagicHeaderReturnsSingleRecord(t *testing.T) {

 	krs = append(krs, kr)

-	dars, err := deagg.DeaggregateRecords(krs)
+	dars, err := deagg.DisaggregatedRecords(krs)
 	if err != nil {
 		panic(err)
 	}

@@ -118,7 +119,7 @@
 	assert.Equal(t, 1, len(dars), "Mismatch magic header test should return length of 1.")
 }

-// This function tests that the DeaggregateRecords function returns the correct number of
+// This function tests that the DisaggregatedRecords function returns the correct number of
 // deaggregated records from a single aggregated record.
 func TestVariableLengthRecordsReturnsCorrectNumberOfDeaggregatedRecords(t *testing.T) {
 	var err error

@@ -133,7 +134,7 @@ func TestVariableLengthRecordsReturnsCorrectNumberOfDeaggregatedRecords(t *testi
 	kr = generateKinesisRecord(aggData)
 	krs = append(krs, kr)

-	dars, err := deagg.DeaggregateRecords(krs)
+	dars, err := deagg.DisaggregatedRecords(krs)
 	if err != nil {
 		panic(err)
 	}

@@ -162,13 +163,13 @@ func TestRecordAfterMagicHeaderWithLengthLessThanDigestSizeReturnsSingleRecord(t

 	krs = append(krs, kr)

-	dars, err := deagg.DeaggregateRecords(krs)
+	dars, err := deagg.DisaggregatedRecords(krs)
 	if err != nil {
 		panic(err)
 	}

 	// A byte record with length less than 16 after the magic header should return
-	// a single record from DeaggregateRecords
+	// a single record from DisaggregatedRecords
 	assert.Equal(t, 1, len(dars), "Digest size test should return length of 1.")
 }

@@ -191,7 +192,7 @@ func TestRecordWithMismatchMd5SumReturnsSingleRecord(t *testing.T) {

 	krs = append(krs, kr)

-	dars, err := deagg.DeaggregateRecords(krs)
+	dars, err := deagg.DisaggregatedRecords(krs)
 	if err != nil {
 		panic(err)
 	}

@@ -4,7 +4,7 @@ import (
 	"log"
 )

-// A Logger is a minimal interface to as a adaptor for external logging library to consumer
+// A Logger is a minimal interface to as an adaptor for external logging library to consumer
 type Logger interface {
 	Log(...interface{})
 }

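Since Logger is a single-method interface, adapting an external logging library only takes a thin wrapper. A minimal sketch against the standard library (stdLogger is a hypothetical name):

    // stdLogger adapts a standard library *log.Logger to the Logger interface.
    type stdLogger struct{ l *log.Logger }

    // Log satisfies the Logger interface by delegating to log.Logger.
    func (s stdLogger) Log(args ...interface{}) { s.l.Println(args...) }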