#9 fixes linting issues

This commit is contained in:
Alex Senger 2024-04-10 16:45:34 +02:00
parent eaf4defe57
commit f0acb329f7
No known key found for this signature in database
GPG key ID: 0B4A96F8AF6934CF
13 changed files with 84 additions and 78 deletions

View file

@ -243,20 +243,16 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
// temporary conversion func of []types.Record -> DesegregateRecords([]*types.Record) -> []types.Record // temporary conversion func of []types.Record -> DesegregateRecords([]*types.Record) -> []types.Record
func disaggregateRecords(in []types.Record) ([]types.Record, error) { func disaggregateRecords(in []types.Record) ([]types.Record, error) {
var recs []*types.Record var recs []types.Record
for _, rec := range in { recs = append(recs, in...)
recs = append(recs, &rec)
}
deagg, err := deaggregator.DisaggregatedRecords(recs) deagg, err := deaggregator.DeaggregateRecords(recs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var out []types.Record var out []types.Record
for _, rec := range deagg { out = append(out, deagg...)
out = append(out, *rec)
}
return out, nil return out, nil
} }

View file

@ -32,18 +32,18 @@ func TestNew(t *testing.T) {
func TestScan(t *testing.T) { func TestScan(t *testing.T) {
client := &kinesisClientMock{ client := &kinesisClientMock{
getShardIteratorMock: func(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) { getShardIteratorMock: func(_ context.Context, _ *kinesis.GetShardIteratorInput, _ ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) {
return &kinesis.GetShardIteratorOutput{ return &kinesis.GetShardIteratorOutput{
ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"), ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
}, nil }, nil
}, },
getRecordsMock: func(ctx context.Context, params *kinesis.GetRecordsInput, optFns ...func(*kinesis.Options)) (*kinesis.GetRecordsOutput, error) { getRecordsMock: func(_ context.Context, _ *kinesis.GetRecordsInput, _ ...func(*kinesis.Options)) (*kinesis.GetRecordsOutput, error) {
return &kinesis.GetRecordsOutput{ return &kinesis.GetRecordsOutput{
NextShardIterator: nil, NextShardIterator: nil,
Records: records, Records: records,
}, nil }, nil
}, },
listShardsMock: func(ctx context.Context, params *kinesis.ListShardsInput, optFns ...func(*kinesis.Options)) (*kinesis.ListShardsOutput, error) { listShardsMock: func(_ context.Context, _ *kinesis.ListShardsInput, _ ...func(*kinesis.Options)) (*kinesis.ListShardsOutput, error) {
return &kinesis.ListShardsOutput{ return &kinesis.ListShardsOutput{
Shards: []types.Shard{ Shards: []types.Shard{
{ShardId: aws.String("myShard")}, {ShardId: aws.String("myShard")},
@ -100,12 +100,12 @@ func TestScan(t *testing.T) {
func TestScanShard(t *testing.T) { func TestScanShard(t *testing.T) {
var client = &kinesisClientMock{ var client = &kinesisClientMock{
getShardIteratorMock: func(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) { getShardIteratorMock: func(_ context.Context, _ *kinesis.GetShardIteratorInput, _ ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) {
return &kinesis.GetShardIteratorOutput{ return &kinesis.GetShardIteratorOutput{
ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"), ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
}, nil }, nil
}, },
getRecordsMock: func(ctx context.Context, params *kinesis.GetRecordsInput, optFns ...func(*kinesis.Options)) (*kinesis.GetRecordsOutput, error) { getRecordsMock: func(_ context.Context, _ *kinesis.GetRecordsInput, _ ...func(*kinesis.Options)) (*kinesis.GetRecordsOutput, error) {
return &kinesis.GetRecordsOutput{ return &kinesis.GetRecordsOutput{
NextShardIterator: nil, NextShardIterator: nil,
Records: records, Records: records,
@ -166,12 +166,12 @@ func TestScanShard(t *testing.T) {
func TestScanShard_Cancellation(t *testing.T) { func TestScanShard_Cancellation(t *testing.T) {
var client = &kinesisClientMock{ var client = &kinesisClientMock{
getShardIteratorMock: func(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) { getShardIteratorMock: func(_ context.Context, _ *kinesis.GetShardIteratorInput, _ ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) {
return &kinesis.GetShardIteratorOutput{ return &kinesis.GetShardIteratorOutput{
ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"), ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
}, nil }, nil
}, },
getRecordsMock: func(ctx context.Context, params *kinesis.GetRecordsInput, optFns ...func(*kinesis.Options)) (*kinesis.GetRecordsOutput, error) { getRecordsMock: func(_ context.Context, _ *kinesis.GetRecordsInput, _ ...func(*kinesis.Options)) (*kinesis.GetRecordsOutput, error) {
return &kinesis.GetRecordsOutput{ return &kinesis.GetRecordsOutput{
NextShardIterator: nil, NextShardIterator: nil,
Records: records, Records: records,
@ -206,12 +206,12 @@ func TestScanShard_Cancellation(t *testing.T) {
func TestScanShard_SkipCheckpoint(t *testing.T) { func TestScanShard_SkipCheckpoint(t *testing.T) {
var client = &kinesisClientMock{ var client = &kinesisClientMock{
getShardIteratorMock: func(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) { getShardIteratorMock: func(_ context.Context, _ *kinesis.GetShardIteratorInput, _ ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) {
return &kinesis.GetShardIteratorOutput{ return &kinesis.GetShardIteratorOutput{
ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"), ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
}, nil }, nil
}, },
getRecordsMock: func(ctx context.Context, params *kinesis.GetRecordsInput, optFns ...func(*kinesis.Options)) (*kinesis.GetRecordsOutput, error) { getRecordsMock: func(_ context.Context, _ *kinesis.GetRecordsInput, _ ...func(*kinesis.Options)) (*kinesis.GetRecordsOutput, error) {
return &kinesis.GetRecordsOutput{ return &kinesis.GetRecordsOutput{
NextShardIterator: nil, NextShardIterator: nil,
Records: records, Records: records,
@ -250,12 +250,12 @@ func TestScanShard_SkipCheckpoint(t *testing.T) {
func TestScanShard_ShardIsClosed(t *testing.T) { func TestScanShard_ShardIsClosed(t *testing.T) {
var client = &kinesisClientMock{ var client = &kinesisClientMock{
getShardIteratorMock: func(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) { getShardIteratorMock: func(_ context.Context, _ *kinesis.GetShardIteratorInput, _ ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) {
return &kinesis.GetShardIteratorOutput{ return &kinesis.GetShardIteratorOutput{
ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"), ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
}, nil }, nil
}, },
getRecordsMock: func(ctx context.Context, params *kinesis.GetRecordsInput, optFns ...func(*kinesis.Options)) (*kinesis.GetRecordsOutput, error) { getRecordsMock: func(_ context.Context, _ *kinesis.GetRecordsInput, _ ...func(*kinesis.Options)) (*kinesis.GetRecordsOutput, error) {
return &kinesis.GetRecordsOutput{ return &kinesis.GetRecordsOutput{
NextShardIterator: nil, NextShardIterator: nil,
Records: make([]types.Record, 0), Records: make([]types.Record, 0),
@ -268,7 +268,7 @@ func TestScanShard_ShardIsClosed(t *testing.T) {
t.Fatalf("new consumer error: %v", err) t.Fatalf("new consumer error: %v", err)
} }
var fn = func(r *Record) error { var fn = func(_ *Record) error {
return nil return nil
} }
@ -280,12 +280,12 @@ func TestScanShard_ShardIsClosed(t *testing.T) {
func TestScanShard_ShardIsClosed_WithShardClosedHandler(t *testing.T) { func TestScanShard_ShardIsClosed_WithShardClosedHandler(t *testing.T) {
var client = &kinesisClientMock{ var client = &kinesisClientMock{
getShardIteratorMock: func(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) { getShardIteratorMock: func(_ context.Context, _ *kinesis.GetShardIteratorInput, _ ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) {
return &kinesis.GetShardIteratorOutput{ return &kinesis.GetShardIteratorOutput{
ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"), ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
}, nil }, nil
}, },
getRecordsMock: func(ctx context.Context, params *kinesis.GetRecordsInput, optFns ...func(*kinesis.Options)) (*kinesis.GetRecordsOutput, error) { getRecordsMock: func(_ context.Context, _ *kinesis.GetRecordsInput, _ ...func(*kinesis.Options)) (*kinesis.GetRecordsOutput, error) {
return &kinesis.GetRecordsOutput{ return &kinesis.GetRecordsOutput{
NextShardIterator: nil, NextShardIterator: nil,
Records: make([]types.Record, 0), Records: make([]types.Record, 0),
@ -293,13 +293,13 @@ func TestScanShard_ShardIsClosed_WithShardClosedHandler(t *testing.T) {
}, },
} }
var fn = func(r *Record) error { var fn = func(_ *Record) error {
return nil return nil
} }
c, err := New("myStreamName", c, err := New("myStreamName",
WithClient(client), WithClient(client),
WithShardClosedHandler(func(streamName, shardID string) error { WithShardClosedHandler(func(_, _ string) error {
return fmt.Errorf("closed shard error") return fmt.Errorf("closed shard error")
})) }))
if err != nil { if err != nil {
@ -317,12 +317,12 @@ func TestScanShard_ShardIsClosed_WithShardClosedHandler(t *testing.T) {
func TestScanShard_GetRecordsError(t *testing.T) { func TestScanShard_GetRecordsError(t *testing.T) {
var client = &kinesisClientMock{ var client = &kinesisClientMock{
getShardIteratorMock: func(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) { getShardIteratorMock: func(_ context.Context, _ *kinesis.GetShardIteratorInput, _ ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) {
return &kinesis.GetShardIteratorOutput{ return &kinesis.GetShardIteratorOutput{
ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"), ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
}, nil }, nil
}, },
getRecordsMock: func(ctx context.Context, params *kinesis.GetRecordsInput, optFns ...func(*kinesis.Options)) (*kinesis.GetRecordsOutput, error) { getRecordsMock: func(_ context.Context, _ *kinesis.GetRecordsInput, _ ...func(*kinesis.Options)) (*kinesis.GetRecordsOutput, error) {
return &kinesis.GetRecordsOutput{ return &kinesis.GetRecordsOutput{
NextShardIterator: nil, NextShardIterator: nil,
Records: nil, Records: nil,
@ -331,7 +331,7 @@ func TestScanShard_GetRecordsError(t *testing.T) {
}, },
} }
var fn = func(r *Record) error { var fn = func(_ *Record) error {
return nil return nil
} }
@ -353,15 +353,15 @@ type kinesisClientMock struct {
listShardsMock func(ctx context.Context, params *kinesis.ListShardsInput, optFns ...func(*kinesis.Options)) (*kinesis.ListShardsOutput, error) listShardsMock func(ctx context.Context, params *kinesis.ListShardsInput, optFns ...func(*kinesis.Options)) (*kinesis.ListShardsOutput, error)
} }
func (c *kinesisClientMock) ListShards(ctx context.Context, params *kinesis.ListShardsInput, optFns ...func(*kinesis.Options)) (*kinesis.ListShardsOutput, error) { func (c *kinesisClientMock) ListShards(ctx context.Context, params *kinesis.ListShardsInput, _ ...func(*kinesis.Options)) (*kinesis.ListShardsOutput, error) {
return c.listShardsMock(ctx, params) return c.listShardsMock(ctx, params)
} }
func (c *kinesisClientMock) GetRecords(ctx context.Context, params *kinesis.GetRecordsInput, optFns ...func(*kinesis.Options)) (*kinesis.GetRecordsOutput, error) { func (c *kinesisClientMock) GetRecords(ctx context.Context, params *kinesis.GetRecordsInput, _ ...func(*kinesis.Options)) (*kinesis.GetRecordsOutput, error) {
return c.getRecordsMock(ctx, params) return c.getRecordsMock(ctx, params)
} }
func (c *kinesisClientMock) GetShardIterator(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) { func (c *kinesisClientMock) GetShardIterator(ctx context.Context, params *kinesis.GetShardIteratorInput, _ ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) {
return c.getShardIteratorMock(ctx, params) return c.getShardIteratorMock(ctx, params)
} }
@ -378,7 +378,7 @@ func (fc *fakeCounter) Get() int64 {
return fc.counter return fc.counter
} }
func (fc *fakeCounter) Add(streamName string, count int64) { func (fc *fakeCounter) Add(_ string, count int64) {
fc.mu.Lock() fc.mu.Lock()
defer fc.mu.Unlock() defer fc.mu.Unlock()

View file

@ -74,19 +74,21 @@ func WithMaxRecords(n int64) Option {
} }
} }
// WithAggregation overrides the default option for aggregating records
func WithAggregation(a bool) Option { func WithAggregation(a bool) Option {
return func(c *Consumer) { return func(c *Consumer) {
c.isAggregated = a c.isAggregated = a
} }
} }
// WithShardClosedHandler defines a custom handler for closed shards.
func WithShardClosedHandler(h ShardClosedHandler) Option {
return func(c *Consumer) {
c.shardClosedHandler = h
}
}
// ShardClosedHandler is a handler that will be called when the consumer has reached the end of a closed shard. // ShardClosedHandler is a handler that will be called when the consumer has reached the end of a closed shard.
// No more records for that shard will be provided by the consumer. // No more records for that shard will be provided by the consumer.
// An error can be returned to stop the consumer. // An error can be returned to stop the consumer.
type ShardClosedHandler = func(streamName, shardID string) error type ShardClosedHandler = func(streamName, shardID string) error
func WithShardClosedHandler(h ShardClosedHandler) Option {
return func(c *Consumer) {
c.shardClosedHandler = h
}
}

View file

@ -43,7 +43,7 @@ func New(appName, tableName string, opts ...Option) (*Checkpoint, error) {
ck := &Checkpoint{ ck := &Checkpoint{
tableName: tableName, tableName: tableName,
appName: appName, appName: appName,
maxInterval: time.Duration(1 * time.Minute), maxInterval: 1 * time.Minute,
done: make(chan struct{}), done: make(chan struct{}),
mu: &sync.Mutex{}, mu: &sync.Mutex{},
checkpoints: map[key]string{}, checkpoints: map[key]string{},
@ -68,7 +68,7 @@ func New(appName, tableName string, opts ...Option) (*Checkpoint, error) {
return ck, nil return ck, nil
} }
// Checkpoint stores and retreives the last evaluated key from a DDB scan // Checkpoint stores and retrieves the last evaluated key from a DDB scan
type Checkpoint struct { type Checkpoint struct {
tableName string tableName string
appName string appName string
@ -115,12 +115,12 @@ func (c *Checkpoint) GetCheckpoint(streamName, shardID string) (string, error) {
} }
var i item var i item
attributevalue.UnmarshalMap(resp.Item, &i) _ = attributevalue.UnmarshalMap(resp.Item, &i)
return i.SequenceNumber, nil return i.SequenceNumber, nil
} }
// SetCheckpoint stores a checkpoint for a shard (e.g. sequence number of last record processed by application). // SetCheckpoint stores a checkpoint for a shard (e.g. sequence number of last record processed by application).
// Upon failover, record processing is resumed from this point. // Upon fail over, record processing is resumed from this point.
func (c *Checkpoint) SetCheckpoint(streamName, shardID, sequenceNumber string) error { func (c *Checkpoint) SetCheckpoint(streamName, shardID, sequenceNumber string) error {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
@ -152,7 +152,7 @@ func (c *Checkpoint) loop() {
for { for {
select { select {
case <-tick.C: case <-tick.C:
c.save() _ = c.save()
case <-c.done: case <-c.done:
return return
} }

View file

@ -14,7 +14,7 @@ type fakeRetryer struct {
Name string Name string
} }
func (r *fakeRetryer) ShouldRetry(err error) bool { func (r *fakeRetryer) ShouldRetry(_ error) bool {
r.Name = "fakeRetryer" r.Name = "fakeRetryer"
return false return false
} }
@ -34,7 +34,7 @@ func TestCheckpointSetting(t *testing.T) {
ckPtr := &ck ckPtr := &ck
// Test WithMaxInterval // Test WithMaxInterval
setInterval := WithMaxInterval(time.Duration(2 * time.Minute)) setInterval := WithMaxInterval(2 * time.Minute)
setInterval(ckPtr) setInterval(ckPtr)
// Test WithRetryer // Test WithRetryer
@ -52,7 +52,7 @@ func TestCheckpointSetting(t *testing.T) {
setDDBClient := WithDynamoClient(fakeDbClient) setDDBClient := WithDynamoClient(fakeDbClient)
setDDBClient(ckPtr) setDDBClient(ckPtr)
if ckPtr.maxInterval != time.Duration(2*time.Minute) { if ckPtr.maxInterval != 2*time.Minute {
t.Errorf("new checkpoint maxInterval expected 2 minute. got %v", ckPtr.maxInterval) t.Errorf("new checkpoint maxInterval expected 2 minute. got %v", ckPtr.maxInterval)
} }
if ckPtr.retryer.ShouldRetry(nil) != false { if ckPtr.retryer.ShouldRetry(nil) != false {
@ -65,7 +65,7 @@ func TestCheckpointSetting(t *testing.T) {
func TestNewCheckpointWithOptions(t *testing.T) { func TestNewCheckpointWithOptions(t *testing.T) {
// Test WithMaxInterval // Test WithMaxInterval
setInterval := WithMaxInterval(time.Duration(2 * time.Minute)) setInterval := WithMaxInterval(2 * time.Minute)
// Test WithRetryer // Test WithRetryer
var r fakeRetryer var r fakeRetryer
@ -94,7 +94,7 @@ func TestNewCheckpointWithOptions(t *testing.T) {
t.Errorf("new checkpoint table expected %v. got %v", "testtable", ckPtr.maxInterval) t.Errorf("new checkpoint table expected %v. got %v", "testtable", ckPtr.maxInterval)
} }
if ckPtr.maxInterval != time.Duration(2*time.Minute) { if ckPtr.maxInterval != 2*time.Minute {
t.Errorf("new checkpoint maxInterval expected 2 minute. got %v", ckPtr.maxInterval) t.Errorf("new checkpoint maxInterval expected 2 minute. got %v", ckPtr.maxInterval)
} }
if ckPtr.retryer.ShouldRetry(nil) != false { if ckPtr.retryer.ShouldRetry(nil) != false {
@ -103,5 +103,4 @@ func TestNewCheckpointWithOptions(t *testing.T) {
if ckPtr.client != fakeDbClient { if ckPtr.client != fakeDbClient {
t.Errorf("new checkpoint dynamodb client reference should be %p. got %v", &fakeDbClient, ckPtr.client) t.Errorf("new checkpoint dynamodb client reference should be %p. got %v", &fakeDbClient, ckPtr.client)
} }
} }

View file

@ -14,7 +14,7 @@ type DefaultRetryer struct {
Retryer Retryer
} }
// ShouldRetry when error occured // ShouldRetry when error occurred
func (r *DefaultRetryer) ShouldRetry(err error) bool { func (r *DefaultRetryer) ShouldRetry(err error) bool {
switch err.(type) { switch err.(type) {
case *types.ProvisionedThroughputExceededException: case *types.ProvisionedThroughputExceededException:

View file

@ -1,3 +1,5 @@
// Package store
//
// The memory store provides a store that can be used for testing and single-threaded applications. // The memory store provides a store that can be used for testing and single-threaded applications.
// DO NOT USE this in a production application where persistence beyond a single application lifecycle is necessary // DO NOT USE this in a production application where persistence beyond a single application lifecycle is necessary
// or when there are multiple consumers. // or when there are multiple consumers.
@ -8,14 +10,17 @@ import (
"sync" "sync"
) )
// New returns a new in memory store to persist the last consumed offset.
func New() *Store { func New() *Store {
return &Store{} return &Store{}
} }
// Store is the in-memory data structure that holds the offsets per stream
type Store struct { type Store struct {
sync.Map sync.Map
} }
// SetCheckpoint stores a checkpoint for a shard (e.g. sequence number of last record processed by application).
func (c *Store) SetCheckpoint(streamName, shardID, sequenceNumber string) error { func (c *Store) SetCheckpoint(streamName, shardID, sequenceNumber string) error {
if sequenceNumber == "" { if sequenceNumber == "" {
return fmt.Errorf("sequence number should not be empty") return fmt.Errorf("sequence number should not be empty")
@ -24,6 +29,9 @@ func (c *Store) SetCheckpoint(streamName, shardID, sequenceNumber string) error
return nil return nil
} }
// GetCheckpoint determines if a checkpoint for a particular Shard exists.
// Typically, this is used to determine whether processing should start with TRIM_HORIZON or AFTER_SEQUENCE_NUMBER
// (if checkpoint exists).
func (c *Store) GetCheckpoint(streamName, shardID string) (string, error) { func (c *Store) GetCheckpoint(streamName, shardID string) (string, error) {
val, ok := c.Load(streamName + ":" + shardID) val, ok := c.Load(streamName + ":" + shardID)
if !ok { if !ok {

View file

@ -8,7 +8,7 @@ func Test_CheckpointLifecycle(t *testing.T) {
c := New() c := New()
// set // set
c.SetCheckpoint("streamName", "shardID", "testSeqNum") _ = c.SetCheckpoint("streamName", "shardID", "testSeqNum")
// get // get
val, err := c.GetCheckpoint("streamName", "shardID") val, err := c.GetCheckpoint("streamName", "shardID")

View file

@ -7,6 +7,7 @@ import (
"sync" "sync"
"time" "time"
// this is the mysql package, so it makes sense to be here
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
) )
@ -84,11 +85,11 @@ func (c *Checkpoint) GetCheckpoint(streamName, shardID string) (string, error) {
namespace := fmt.Sprintf("%s-%s", c.appName, streamName) namespace := fmt.Sprintf("%s-%s", c.appName, streamName)
var sequenceNumber string var sequenceNumber string
getCheckpointQuery := fmt.Sprintf(`SELECT sequence_number FROM %s WHERE namespace=? AND shard_id=?;`, c.tableName) //nolint: gas, it replaces only the table name getCheckpointQuery := fmt.Sprintf(`SELECT sequence_number FROM %s WHERE namespace=? AND shard_id=?;`, c.tableName) // nolint: gas, it replaces only the table name
err := c.conn.QueryRow(getCheckpointQuery, namespace, shardID).Scan(&sequenceNumber) err := c.conn.QueryRow(getCheckpointQuery, namespace, shardID).Scan(&sequenceNumber)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if errors.Is(err, sql.ErrNoRows) {
return "", nil return "", nil
} }
return "", err return "", err
@ -98,7 +99,7 @@ func (c *Checkpoint) GetCheckpoint(streamName, shardID string) (string, error) {
} }
// SetCheckpoint stores a checkpoint for a shard (e.g. sequence number of last record processed by application). // SetCheckpoint stores a checkpoint for a shard (e.g. sequence number of last record processed by application).
// Upon failover, record processing is resumed from this point. // Upon fail over, record processing is resumed from this point.
func (c *Checkpoint) SetCheckpoint(streamName, shardID, sequenceNumber string) error { func (c *Checkpoint) SetCheckpoint(streamName, shardID, sequenceNumber string) error {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
@ -134,7 +135,7 @@ func (c *Checkpoint) loop() {
for { for {
select { select {
case <-tick.C: case <-tick.C:
c.save() _ = c.save()
case <-c.done: case <-c.done:
return return
} }
@ -145,7 +146,7 @@ func (c *Checkpoint) save() error {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
//nolint: gas, it replaces only the table name // nolint: gas, it replaces only the table name
upsertCheckpoint := fmt.Sprintf(`REPLACE INTO %s (namespace, shard_id, sequence_number) VALUES (?, ?, ?)`, c.tableName) upsertCheckpoint := fmt.Sprintf(`REPLACE INTO %s (namespace, shard_id, sequence_number) VALUES (?, ?, ?)`, c.tableName)
for key, sequenceNumber := range c.checkpoints { for key, sequenceNumber := range c.checkpoints {

View file

@ -6,7 +6,7 @@ import (
"testing" "testing"
"time" "time"
sqlmock "github.com/DATA-DOG/go-sqlmock" "github.com/DATA-DOG/go-sqlmock"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
@ -22,7 +22,7 @@ func TestNew(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("expected error equals nil, but got %v", err) t.Errorf("expected error equals nil, but got %v", err)
} }
ck.Shutdown() _ = ck.Shutdown()
} }
func TestNew_AppNameEmpty(t *testing.T) { func TestNew_AppNameEmpty(t *testing.T) {
@ -69,7 +69,7 @@ func TestNew_WithMaxIntervalOption(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("expected error equals nil, but got %v", err) t.Errorf("expected error equals nil, but got %v", err)
} }
ck.Shutdown() _ = ck.Shutdown()
} }
func TestCheckpoint_GetCheckpoint(t *testing.T) { func TestCheckpoint_GetCheckpoint(t *testing.T) {
@ -109,7 +109,7 @@ func TestCheckpoint_GetCheckpoint(t *testing.T) {
if err := mock.ExpectationsWereMet(); err != nil { if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expectations: %s", err) t.Errorf("there were unfulfilled expectations: %s", err)
} }
ck.Shutdown() _ = ck.Shutdown()
} }
func TestCheckpoint_Get_NoRows(t *testing.T) { func TestCheckpoint_Get_NoRows(t *testing.T) {
@ -145,7 +145,7 @@ func TestCheckpoint_Get_NoRows(t *testing.T) {
if err := mock.ExpectationsWereMet(); err != nil { if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expectations: %s", err) t.Errorf("there were unfulfilled expectations: %s", err)
} }
ck.Shutdown() _ = ck.Shutdown()
} }
func TestCheckpoint_Get_QueryError(t *testing.T) { func TestCheckpoint_Get_QueryError(t *testing.T) {
@ -181,7 +181,7 @@ func TestCheckpoint_Get_QueryError(t *testing.T) {
if err := mock.ExpectationsWereMet(); err != nil { if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expectations: %s", err) t.Errorf("there were unfulfilled expectations: %s", err)
} }
ck.Shutdown() _ = ck.Shutdown()
} }
func TestCheckpoint_SetCheckpoint(t *testing.T) { func TestCheckpoint_SetCheckpoint(t *testing.T) {
@ -202,7 +202,7 @@ func TestCheckpoint_SetCheckpoint(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("expected error equals nil, but got %v", err) t.Errorf("expected error equals nil, but got %v", err)
} }
ck.Shutdown() _ = ck.Shutdown()
} }
func TestCheckpoint_Set_SequenceNumberEmpty(t *testing.T) { func TestCheckpoint_Set_SequenceNumberEmpty(t *testing.T) {
@ -223,7 +223,7 @@ func TestCheckpoint_Set_SequenceNumberEmpty(t *testing.T) {
if err == nil { if err == nil {
t.Errorf("expected error equals not nil, but got %v", err) t.Errorf("expected error equals not nil, but got %v", err)
} }
ck.Shutdown() _ = ck.Shutdown()
} }
func TestCheckpoint_Shutdown(t *testing.T) { func TestCheckpoint_Shutdown(t *testing.T) {

View file

@ -7,7 +7,7 @@ import (
"sync" "sync"
"time" "time"
// this is the postgres package so it makes sense to be here // this is the postgres package, so it makes sense to be here
_ "github.com/lib/pq" _ "github.com/lib/pq"
) )
@ -85,11 +85,11 @@ func (c *Checkpoint) GetCheckpoint(streamName, shardID string) (string, error) {
namespace := fmt.Sprintf("%s-%s", c.appName, streamName) namespace := fmt.Sprintf("%s-%s", c.appName, streamName)
var sequenceNumber string var sequenceNumber string
getCheckpointQuery := fmt.Sprintf(`SELECT sequence_number FROM %s WHERE namespace=$1 AND shard_id=$2;`, c.tableName) //nolint: gas, it replaces only the table name getCheckpointQuery := fmt.Sprintf(`SELECT sequence_number FROM %s WHERE namespace=$1 AND shard_id=$2;`, c.tableName) // nolint: gas, it replaces only the table name
err := c.conn.QueryRow(getCheckpointQuery, namespace, shardID).Scan(&sequenceNumber) err := c.conn.QueryRow(getCheckpointQuery, namespace, shardID).Scan(&sequenceNumber)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if errors.Is(err, sql.ErrNoRows) {
return "", nil return "", nil
} }
return "", err return "", err
@ -99,7 +99,7 @@ func (c *Checkpoint) GetCheckpoint(streamName, shardID string) (string, error) {
} }
// SetCheckpoint stores a checkpoint for a shard (e.g. sequence number of last record processed by application). // SetCheckpoint stores a checkpoint for a shard (e.g. sequence number of last record processed by application).
// Upon failover, record processing is resumed from this point. // Upon fail over, record processing is resumed from this point.
func (c *Checkpoint) SetCheckpoint(streamName, shardID, sequenceNumber string) error { func (c *Checkpoint) SetCheckpoint(streamName, shardID, sequenceNumber string) error {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
@ -135,7 +135,7 @@ func (c *Checkpoint) loop() {
for { for {
select { select {
case <-tick.C: case <-tick.C:
c.save() _ = c.save()
case <-c.done: case <-c.done:
return return
} }
@ -146,7 +146,7 @@ func (c *Checkpoint) save() error {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
//nolint: gas, it replaces only the table name // nolint: gas, it replaces only the table name
upsertCheckpoint := fmt.Sprintf(`INSERT INTO %s (namespace, shard_id, sequence_number) upsertCheckpoint := fmt.Sprintf(`INSERT INTO %s (namespace, shard_id, sequence_number)
VALUES($1, $2, $3) VALUES($1, $2, $3)
ON CONFLICT (namespace, shard_id) ON CONFLICT (namespace, shard_id)

View file

@ -6,7 +6,7 @@ import (
"testing" "testing"
"time" "time"
sqlmock "github.com/DATA-DOG/go-sqlmock" "github.com/DATA-DOG/go-sqlmock"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
@ -22,7 +22,7 @@ func TestNew(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("expected error equals nil, but got %v", err) t.Errorf("expected error equals nil, but got %v", err)
} }
ck.Shutdown() _ = ck.Shutdown()
} }
func TestNew_AppNameEmpty(t *testing.T) { func TestNew_AppNameEmpty(t *testing.T) {
@ -69,7 +69,7 @@ func TestNew_WithMaxIntervalOption(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("expected error equals nil, but got %v", err) t.Errorf("expected error equals nil, but got %v", err)
} }
ck.Shutdown() _ = ck.Shutdown()
} }
func TestCheckpoint_GetCheckpoint(t *testing.T) { func TestCheckpoint_GetCheckpoint(t *testing.T) {
@ -109,7 +109,7 @@ func TestCheckpoint_GetCheckpoint(t *testing.T) {
if err := mock.ExpectationsWereMet(); err != nil { if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expectations: %s", err) t.Errorf("there were unfulfilled expectations: %s", err)
} }
ck.Shutdown() _ = ck.Shutdown()
} }
func TestCheckpoint_Get_NoRows(t *testing.T) { func TestCheckpoint_Get_NoRows(t *testing.T) {
@ -145,7 +145,7 @@ func TestCheckpoint_Get_NoRows(t *testing.T) {
if err := mock.ExpectationsWereMet(); err != nil { if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expectations: %s", err) t.Errorf("there were unfulfilled expectations: %s", err)
} }
ck.Shutdown() _ = ck.Shutdown()
} }
func TestCheckpoint_Get_QueryError(t *testing.T) { func TestCheckpoint_Get_QueryError(t *testing.T) {
@ -181,7 +181,7 @@ func TestCheckpoint_Get_QueryError(t *testing.T) {
if err := mock.ExpectationsWereMet(); err != nil { if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expectations: %s", err) t.Errorf("there were unfulfilled expectations: %s", err)
} }
ck.Shutdown() _ = ck.Shutdown()
} }
func TestCheckpoint_SetCheckpoint(t *testing.T) { func TestCheckpoint_SetCheckpoint(t *testing.T) {
@ -202,7 +202,7 @@ func TestCheckpoint_SetCheckpoint(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("expected error equals nil, but got %v", err) t.Errorf("expected error equals nil, but got %v", err)
} }
ck.Shutdown() _ = ck.Shutdown()
} }
func TestCheckpoint_Set_SequenceNumberEmpty(t *testing.T) { func TestCheckpoint_Set_SequenceNumberEmpty(t *testing.T) {
@ -223,7 +223,7 @@ func TestCheckpoint_Set_SequenceNumberEmpty(t *testing.T) {
if err == nil { if err == nil {
t.Errorf("expected error equals not nil, but got %v", err) t.Errorf("expected error equals not nil, but got %v", err)
} }
ck.Shutdown() _ = ck.Shutdown()
} }
func TestCheckpoint_Shutdown(t *testing.T) { func TestCheckpoint_Shutdown(t *testing.T) {

View file

@ -32,7 +32,7 @@ func Test_CheckpointLifecycle(t *testing.T) {
} }
// set // set
c.SetCheckpoint("streamName", "shardID", "testSeqNum") _ = c.SetCheckpoint("streamName", "shardID", "testSeqNum")
// get // get
val, err := c.GetCheckpoint("streamName", "shardID") val, err := c.GetCheckpoint("streamName", "shardID")