Merge pull request #201 from alexgridx/26-thread-context-through-to-stores

#23 threads context through to stores
This commit is contained in:
Alex 2024-09-19 10:30:44 +02:00 committed by GitHub
commit e09e158483
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 98 additions and 87 deletions

View file

@ -155,7 +155,7 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
// for each record and checkpoints the progress of scan. // for each record and checkpoints the progress of scan.
func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error { func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error {
// get last seq number from checkpoint // get last seq number from checkpoint
lastSeqNum, err := c.group.GetCheckpoint(c.streamName, shardID) lastSeqNum, err := c.group.GetCheckpoint(ctx, c.streamName, shardID)
if err != nil { if err != nil {
return fmt.Errorf("get checkpoint error: %w", err) return fmt.Errorf("get checkpoint error: %w", err)
} }
@ -223,7 +223,7 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
} }
if !errors.Is(err, ErrSkipCheckpoint) { if !errors.Is(err, ErrSkipCheckpoint) {
if err := c.group.SetCheckpoint(c.streamName, shardID, *r.SequenceNumber); err != nil { if err := c.group.SetCheckpoint(ctx, c.streamName, shardID, *r.SequenceNumber); err != nil {
return err return err
} }
c.counter.Add("checkpoint", 1) c.counter.Add("checkpoint", 1)

View file

@ -7,14 +7,13 @@ import (
"flag" "flag"
"fmt" "fmt"
"log" "log"
"log/slog"
"net" "net"
"net/http" "net/http"
"os" "os"
"os/signal" "os/signal"
"time" "time"
alog "github.com/apex/log"
"github.com/apex/log/handlers/text"
"github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/dynamodb" "github.com/aws/aws-sdk-go-v2/service/dynamodb"
@ -38,25 +37,7 @@ func init() {
}() }()
} }
// A myLogger provides a minimalistic logger satisfying the Logger interface.
type myLogger struct {
logger alog.Logger
}
// Log logs the parameters to the stdlib logger. See log.Println.
func (l *myLogger) Log(args ...interface{}) {
l.logger.Infof("producer: %v", args...)
}
func main() { func main() {
// Wrap myLogger around apex logger
myLog := &myLogger{
logger: alog.Logger{
Handler: text.New(os.Stdout),
Level: alog.DebugLevel,
},
}
var ( var (
app = flag.String("app", "", "Consumer app name") app = flag.String("app", "", "Consumer app name")
stream = flag.String("stream", "", "Stream name") stream = flag.String("stream", "", "Stream name")
@ -100,7 +81,7 @@ func main() {
c, err := consumer.New( c, err := consumer.New(
*stream, *stream,
consumer.WithStore(ddb), consumer.WithStore(ddb),
consumer.WithLogger(myLog), consumer.WithLogger(slog.Default()),
consumer.WithCounter(counter), consumer.WithCounter(counter),
consumer.WithClient(client), consumer.WithClient(client),
) )
@ -129,7 +110,7 @@ func main() {
log.Fatalf("scan error: %v", err) log.Fatalf("scan error: %v", err)
} }
if err := ddb.Shutdown(); err != nil { if err := ddb.Shutdown(ctx); err != nil {
log.Fatalf("storage shutdown error: %v", err) log.Fatalf("storage shutdown error: %v", err)
} }
} }

View file

@ -85,7 +85,7 @@ func main() {
os.Exit(1) os.Exit(1)
} }
if err := checkpointStore.Shutdown(); err != nil { if err := checkpointStore.Shutdown(context.Background()); err != nil {
slog.Error("store shutdown error", slog.String("error", err.Error())) slog.Error("store shutdown error", slog.String("error", err.Error()))
os.Exit(1) os.Exit(1)
} }

View file

@ -9,6 +9,6 @@ import (
// Group interface used to manage which shard to process // Group interface used to manage which shard to process
type Group interface { type Group interface {
Start(ctx context.Context, shardc chan types.Shard) Start(ctx context.Context, shardc chan types.Shard)
GetCheckpoint(streamName, shardID string) (string, error) GetCheckpoint(ctx context.Context, streamName, shardID string) (string, error)
SetCheckpoint(streamName, shardID, sequenceNumber string) error SetCheckpoint(ctx context.Context, streamName, shardID, sequenceNumber string) error
} }

View file

@ -1,13 +1,17 @@
package consumer package consumer
import (
"context"
)
// Store interface used to persist scan progress // Store interface used to persist scan progress
type Store interface { type Store interface {
GetCheckpoint(streamName, shardID string) (string, error) GetCheckpoint(ctx context.Context, streamName, shardID string) (string, error)
SetCheckpoint(streamName, shardID, sequenceNumber string) error SetCheckpoint(ctx context.Context, streamName, shardID, sequenceNumber string) error
} }
// noopStore implements the storage interface with discard // noopStore implements the storage interface with discard
type noopStore struct{} type noopStore struct{}
func (n noopStore) GetCheckpoint(string, string) (string, error) { return "", nil } func (n noopStore) GetCheckpoint(context.Context, string, string) (string, error) { return "", nil }
func (n noopStore) SetCheckpoint(string, string, string) error { return nil } func (n noopStore) SetCheckpoint(context.Context, string, string, string) error { return nil }

View file

@ -94,7 +94,7 @@ type item struct {
// GetCheckpoint determines if a checkpoint for a particular Shard exists. // GetCheckpoint determines if a checkpoint for a particular Shard exists.
// Typically used to determine whether we should start processing the shard with // Typically used to determine whether we should start processing the shard with
// TRIM_HORIZON or AFTER_SEQUENCE_NUMBER (if checkpoint exists). // TRIM_HORIZON or AFTER_SEQUENCE_NUMBER (if checkpoint exists).
func (c *Checkpoint) GetCheckpoint(streamName, shardID string) (string, error) { func (c *Checkpoint) GetCheckpoint(ctx context.Context, streamName, shardID string) (string, error) {
namespace := fmt.Sprintf("%s-%s", c.appName, streamName) namespace := fmt.Sprintf("%s-%s", c.appName, streamName)
params := &dynamodb.GetItemInput{ params := &dynamodb.GetItemInput{
@ -106,10 +106,10 @@ func (c *Checkpoint) GetCheckpoint(streamName, shardID string) (string, error) {
}, },
} }
resp, err := c.client.GetItem(context.Background(), params) resp, err := c.client.GetItem(ctx, params)
if err != nil { if err != nil {
if c.retryer.ShouldRetry(err) { if c.retryer.ShouldRetry(err) {
return c.GetCheckpoint(streamName, shardID) return c.GetCheckpoint(ctx, streamName, shardID)
} }
return "", err return "", err
} }
@ -121,7 +121,7 @@ func (c *Checkpoint) GetCheckpoint(streamName, shardID string) (string, error) {
// SetCheckpoint stores a checkpoint for a shard (e.g. sequence number of last record processed by application). // SetCheckpoint stores a checkpoint for a shard (e.g. sequence number of last record processed by application).
// Upon fail over, record processing is resumed from this point. // Upon fail over, record processing is resumed from this point.
func (c *Checkpoint) SetCheckpoint(streamName, shardID, sequenceNumber string) error { func (c *Checkpoint) SetCheckpoint(_ context.Context, streamName, shardID, sequenceNumber string) error {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
@ -139,12 +139,13 @@ func (c *Checkpoint) SetCheckpoint(streamName, shardID, sequenceNumber string) e
} }
// Shutdown the checkpoint. Save any in-flight data. // Shutdown the checkpoint. Save any in-flight data.
func (c *Checkpoint) Shutdown() error { func (c *Checkpoint) Shutdown(ctx context.Context) error {
c.done <- struct{}{} c.done <- struct{}{}
return c.save() return c.save(ctx)
} }
func (c *Checkpoint) loop() { func (c *Checkpoint) loop() {
ctx := context.Background()
tick := time.NewTicker(c.maxInterval) tick := time.NewTicker(c.maxInterval)
defer tick.Stop() defer tick.Stop()
defer close(c.done) defer close(c.done)
@ -152,14 +153,14 @@ func (c *Checkpoint) loop() {
for { for {
select { select {
case <-tick.C: case <-tick.C:
_ = c.save() _ = c.save(ctx)
case <-c.done: case <-c.done:
return return
} }
} }
} }
func (c *Checkpoint) save() error { func (c *Checkpoint) save(ctx context.Context) error {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
@ -175,7 +176,7 @@ func (c *Checkpoint) save() error {
} }
_, err = c.client.PutItem( _, err = c.client.PutItem(
context.TODO(), ctx,
&dynamodb.PutItemInput{ &dynamodb.PutItemInput{
TableName: aws.String(c.tableName), TableName: aws.String(c.tableName),
Item: item, Item: item,
@ -184,7 +185,7 @@ func (c *Checkpoint) save() error {
if !c.retryer.ShouldRetry(err) { if !c.retryer.ShouldRetry(err) {
return err return err
} }
return c.save() return c.save(ctx)
} }
} }

View file

@ -6,6 +6,7 @@
package store package store
import ( import (
"context"
"fmt" "fmt"
"sync" "sync"
) )
@ -21,7 +22,7 @@ type Store struct {
} }
// SetCheckpoint stores a checkpoint for a shard (e.g. sequence number of last record processed by application). // SetCheckpoint stores a checkpoint for a shard (e.g. sequence number of last record processed by application).
func (c *Store) SetCheckpoint(streamName, shardID, sequenceNumber string) error { func (c *Store) SetCheckpoint(_ context.Context, streamName, shardID, sequenceNumber string) error {
if sequenceNumber == "" { if sequenceNumber == "" {
return fmt.Errorf("sequence number should not be empty") return fmt.Errorf("sequence number should not be empty")
} }
@ -32,7 +33,7 @@ func (c *Store) SetCheckpoint(streamName, shardID, sequenceNumber string) error
// GetCheckpoint determines if a checkpoint for a particular Shard exists. // GetCheckpoint determines if a checkpoint for a particular Shard exists.
// Typically, this is used to determine whether processing should start with TRIM_HORIZON or AFTER_SEQUENCE_NUMBER // Typically, this is used to determine whether processing should start with TRIM_HORIZON or AFTER_SEQUENCE_NUMBER
// (if checkpoint exists). // (if checkpoint exists).
func (c *Store) GetCheckpoint(streamName, shardID string) (string, error) { func (c *Store) GetCheckpoint(_ context.Context, streamName, shardID string) (string, error) {
val, ok := c.Load(streamName + ":" + shardID) val, ok := c.Load(streamName + ":" + shardID)
if !ok { if !ok {
return "", nil return "", nil

View file

@ -1,17 +1,19 @@
package store package store
import ( import (
"context"
"testing" "testing"
) )
func Test_CheckpointLifecycle(t *testing.T) { func Test_CheckpointLifecycle(t *testing.T) {
c := New() c := New()
ctx := context.Background()
// set // set
_ = c.SetCheckpoint("streamName", "shardID", "testSeqNum") _ = c.SetCheckpoint(ctx, "streamName", "shardID", "testSeqNum")
// get // get
val, err := c.GetCheckpoint("streamName", "shardID") val, err := c.GetCheckpoint(ctx, "streamName", "shardID")
if err != nil { if err != nil {
t.Fatalf("get checkpoint error: %v", err) t.Fatalf("get checkpoint error: %v", err)
} }
@ -22,8 +24,9 @@ func Test_CheckpointLifecycle(t *testing.T) {
func Test_SetEmptySeqNum(t *testing.T) { func Test_SetEmptySeqNum(t *testing.T) {
c := New() c := New()
ctx := context.Background()
err := c.SetCheckpoint("streamName", "shardID", "") err := c.SetCheckpoint(ctx, "streamName", "shardID", "")
if err == nil || err.Error() != "sequence number should not be empty" { if err == nil || err.Error() != "sequence number should not be empty" {
t.Fatalf("should not allow empty sequence number") t.Fatalf("should not allow empty sequence number")
} }

View file

@ -1,6 +1,7 @@
package mysql package mysql
import ( import (
"context"
"database/sql" "database/sql"
"errors" "errors"
"fmt" "fmt"
@ -81,12 +82,12 @@ func (c *Checkpoint) GetMaxInterval() time.Duration {
// GetCheckpoint determines if a checkpoint for a particular Shard exists. // GetCheckpoint determines if a checkpoint for a particular Shard exists.
// Typically used to determine whether we should start processing the shard with // Typically used to determine whether we should start processing the shard with
// TRIM_HORIZON or AFTER_SEQUENCE_NUMBER (if checkpoint exists). // TRIM_HORIZON or AFTER_SEQUENCE_NUMBER (if checkpoint exists).
func (c *Checkpoint) GetCheckpoint(streamName, shardID string) (string, error) { func (c *Checkpoint) GetCheckpoint(ctx context.Context, streamName, shardID string) (string, error) {
namespace := fmt.Sprintf("%s-%s", c.appName, streamName) namespace := fmt.Sprintf("%s-%s", c.appName, streamName)
var sequenceNumber string var sequenceNumber string
getCheckpointQuery := fmt.Sprintf(`SELECT sequence_number FROM %s WHERE namespace=? AND shard_id=?;`, c.tableName) // nolint: gas, it replaces only the table name getCheckpointQuery := fmt.Sprintf(`SELECT sequence_number FROM %s WHERE namespace=? AND shard_id=?;`, c.tableName) // nolint: gas, it replaces only the table name
err := c.conn.QueryRow(getCheckpointQuery, namespace, shardID).Scan(&sequenceNumber) err := c.conn.QueryRowContext(ctx, getCheckpointQuery, namespace, shardID).Scan(&sequenceNumber)
if err != nil { if err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
@ -100,7 +101,7 @@ func (c *Checkpoint) GetCheckpoint(streamName, shardID string) (string, error) {
// SetCheckpoint stores a checkpoint for a shard (e.g. sequence number of last record processed by application). // SetCheckpoint stores a checkpoint for a shard (e.g. sequence number of last record processed by application).
// Upon fail over, record processing is resumed from this point. // Upon fail over, record processing is resumed from this point.
func (c *Checkpoint) SetCheckpoint(streamName, shardID, sequenceNumber string) error { func (c *Checkpoint) SetCheckpoint(_ context.Context, streamName, shardID, sequenceNumber string) error {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()

View file

@ -1,6 +1,7 @@
package mysql package mysql
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"testing" "testing"
@ -73,6 +74,7 @@ func TestNew_WithMaxIntervalOption(t *testing.T) {
} }
func TestCheckpoint_GetCheckpoint(t *testing.T) { func TestCheckpoint_GetCheckpoint(t *testing.T) {
ctx := context.Background()
appName := "streamConsumer" appName := "streamConsumer"
tableName := "checkpoint" tableName := "checkpoint"
connString := "user:password@/dbname" connString := "user:password@/dbname"
@ -98,7 +100,7 @@ func TestCheckpoint_GetCheckpoint(t *testing.T) {
tableName) tableName)
mock.ExpectQuery(expectedSQLRegexString).WithArgs(namespace, shardID).WillReturnRows(expectedRows) mock.ExpectQuery(expectedSQLRegexString).WithArgs(namespace, shardID).WillReturnRows(expectedRows)
gotSequenceNumber, err := ck.GetCheckpoint(streamName, shardID) gotSequenceNumber, err := ck.GetCheckpoint(ctx, streamName, shardID)
if gotSequenceNumber != expectedSequenceNumber { if gotSequenceNumber != expectedSequenceNumber {
t.Errorf("expected sequence number equals %v, but got %v", expectedSequenceNumber, gotSequenceNumber) t.Errorf("expected sequence number equals %v, but got %v", expectedSequenceNumber, gotSequenceNumber)
@ -113,6 +115,7 @@ func TestCheckpoint_GetCheckpoint(t *testing.T) {
} }
func TestCheckpoint_Get_NoRows(t *testing.T) { func TestCheckpoint_Get_NoRows(t *testing.T) {
ctx := context.Background()
appName := "streamConsumer" appName := "streamConsumer"
tableName := "checkpoint" tableName := "checkpoint"
connString := "user:password@/dbname" connString := "user:password@/dbname"
@ -134,7 +137,7 @@ func TestCheckpoint_Get_NoRows(t *testing.T) {
tableName) tableName)
mock.ExpectQuery(expectedSQLRegexString).WithArgs(namespace, shardID).WillReturnError(sql.ErrNoRows) mock.ExpectQuery(expectedSQLRegexString).WithArgs(namespace, shardID).WillReturnError(sql.ErrNoRows)
gotSequenceNumber, err := ck.GetCheckpoint(streamName, shardID) gotSequenceNumber, err := ck.GetCheckpoint(ctx, streamName, shardID)
if gotSequenceNumber != "" { if gotSequenceNumber != "" {
t.Errorf("expected sequence number equals empty, but got %v", gotSequenceNumber) t.Errorf("expected sequence number equals empty, but got %v", gotSequenceNumber)
@ -149,6 +152,7 @@ func TestCheckpoint_Get_NoRows(t *testing.T) {
} }
func TestCheckpoint_Get_QueryError(t *testing.T) { func TestCheckpoint_Get_QueryError(t *testing.T) {
ctx := context.Background()
appName := "streamConsumer" appName := "streamConsumer"
tableName := "checkpoint" tableName := "checkpoint"
connString := "user:password@/dbname" connString := "user:password@/dbname"
@ -170,7 +174,7 @@ func TestCheckpoint_Get_QueryError(t *testing.T) {
tableName) tableName)
mock.ExpectQuery(expectedSQLRegexString).WithArgs(namespace, shardID).WillReturnError(errors.New("an error")) mock.ExpectQuery(expectedSQLRegexString).WithArgs(namespace, shardID).WillReturnError(errors.New("an error"))
gotSequenceNumber, err := ck.GetCheckpoint(streamName, shardID) gotSequenceNumber, err := ck.GetCheckpoint(ctx, streamName, shardID)
if gotSequenceNumber != "" { if gotSequenceNumber != "" {
t.Errorf("expected sequence number equals empty, but got %v", gotSequenceNumber) t.Errorf("expected sequence number equals empty, but got %v", gotSequenceNumber)
@ -185,6 +189,7 @@ func TestCheckpoint_Get_QueryError(t *testing.T) {
} }
func TestCheckpoint_SetCheckpoint(t *testing.T) { func TestCheckpoint_SetCheckpoint(t *testing.T) {
ctx := context.Background()
appName := "streamConsumer" appName := "streamConsumer"
tableName := "checkpoint" tableName := "checkpoint"
connString := "user:password@/dbname" connString := "user:password@/dbname"
@ -197,7 +202,7 @@ func TestCheckpoint_SetCheckpoint(t *testing.T) {
t.Fatalf("error occurred during the checkpoint creation. cause: %v", err) t.Fatalf("error occurred during the checkpoint creation. cause: %v", err)
} }
err = ck.SetCheckpoint(streamName, shardID, expectedSequenceNumber) err = ck.SetCheckpoint(ctx, streamName, shardID, expectedSequenceNumber)
if err != nil { if err != nil {
t.Errorf("expected error equals nil, but got %v", err) t.Errorf("expected error equals nil, but got %v", err)
@ -206,6 +211,7 @@ func TestCheckpoint_SetCheckpoint(t *testing.T) {
} }
func TestCheckpoint_Set_SequenceNumberEmpty(t *testing.T) { func TestCheckpoint_Set_SequenceNumberEmpty(t *testing.T) {
ctx := context.Background()
appName := "streamConsumer" appName := "streamConsumer"
tableName := "checkpoint" tableName := "checkpoint"
connString := "user:password@/dbname" connString := "user:password@/dbname"
@ -218,7 +224,7 @@ func TestCheckpoint_Set_SequenceNumberEmpty(t *testing.T) {
t.Fatalf("error occurred during the checkpoint creation. cause: %v", err) t.Fatalf("error occurred during the checkpoint creation. cause: %v", err)
} }
err = ck.SetCheckpoint(streamName, shardID, expectedSequenceNumber) err = ck.SetCheckpoint(ctx, streamName, shardID, expectedSequenceNumber)
if err == nil { if err == nil {
t.Errorf("expected error equals not nil, but got %v", err) t.Errorf("expected error equals not nil, but got %v", err)
@ -227,6 +233,7 @@ func TestCheckpoint_Set_SequenceNumberEmpty(t *testing.T) {
} }
func TestCheckpoint_Shutdown(t *testing.T) { func TestCheckpoint_Shutdown(t *testing.T) {
ctx := context.Background()
appName := "streamConsumer" appName := "streamConsumer"
tableName := "checkpoint" tableName := "checkpoint"
connString := "user:password@/dbname" connString := "user:password@/dbname"
@ -249,7 +256,7 @@ func TestCheckpoint_Shutdown(t *testing.T) {
result := sqlmock.NewResult(0, 1) result := sqlmock.NewResult(0, 1)
mock.ExpectExec(expectedSQLRegexString).WithArgs(namespace, shardID, expectedSequenceNumber).WillReturnResult(result) mock.ExpectExec(expectedSQLRegexString).WithArgs(namespace, shardID, expectedSequenceNumber).WillReturnResult(result)
err = ck.SetCheckpoint(streamName, shardID, expectedSequenceNumber) err = ck.SetCheckpoint(ctx, streamName, shardID, expectedSequenceNumber)
if err != nil { if err != nil {
t.Fatalf("unable to set checkpoint for data initialization. cause: %v", err) t.Fatalf("unable to set checkpoint for data initialization. cause: %v", err)
@ -266,6 +273,7 @@ func TestCheckpoint_Shutdown(t *testing.T) {
} }
func TestCheckpoint_Shutdown_SaveError(t *testing.T) { func TestCheckpoint_Shutdown_SaveError(t *testing.T) {
ctx := context.Background()
appName := "streamConsumer" appName := "streamConsumer"
tableName := "checkpoint" tableName := "checkpoint"
connString := "user:password@/dbname" connString := "user:password@/dbname"
@ -287,7 +295,7 @@ func TestCheckpoint_Shutdown_SaveError(t *testing.T) {
expectedSQLRegexString := fmt.Sprintf(`REPLACE INTO %s \(namespace, shard_id, sequence_number\) VALUES \(\?, \?, \?\)`, tableName) expectedSQLRegexString := fmt.Sprintf(`REPLACE INTO %s \(namespace, shard_id, sequence_number\) VALUES \(\?, \?, \?\)`, tableName)
mock.ExpectExec(expectedSQLRegexString).WithArgs(namespace, shardID, expectedSequenceNumber).WillReturnError(errors.New("an error")) mock.ExpectExec(expectedSQLRegexString).WithArgs(namespace, shardID, expectedSequenceNumber).WillReturnError(errors.New("an error"))
err = ck.SetCheckpoint(streamName, shardID, expectedSequenceNumber) err = ck.SetCheckpoint(ctx, streamName, shardID, expectedSequenceNumber)
if err != nil { if err != nil {
t.Fatalf("unable to set checkpoint for data initialization. cause: %v", err) t.Fatalf("unable to set checkpoint for data initialization. cause: %v", err)

View file

@ -1,6 +1,7 @@
package postgres package postgres
import ( import (
"context"
"database/sql" "database/sql"
"errors" "errors"
"fmt" "fmt"
@ -88,12 +89,12 @@ func (c *Checkpoint) GetMaxInterval() time.Duration {
// GetCheckpoint determines if a checkpoint for a particular Shard exists. // GetCheckpoint determines if a checkpoint for a particular Shard exists.
// Typically used to determine whether we should start processing the shard with // Typically used to determine whether we should start processing the shard with
// TRIM_HORIZON or AFTER_SEQUENCE_NUMBER (if checkpoint exists). // TRIM_HORIZON or AFTER_SEQUENCE_NUMBER (if checkpoint exists).
func (c *Checkpoint) GetCheckpoint(streamName, shardID string) (string, error) { func (c *Checkpoint) GetCheckpoint(ctx context.Context, streamName, shardID string) (string, error) {
namespace := fmt.Sprintf("%s-%s", c.appName, streamName) namespace := fmt.Sprintf("%s-%s", c.appName, streamName)
var sequenceNumber string var sequenceNumber string
getCheckpointQuery := fmt.Sprintf(`SELECT sequence_number FROM %s WHERE namespace=$1 AND shard_id=$2;`, c.tableName) // nolint: gas, it replaces only the table name getCheckpointQuery := fmt.Sprintf(`SELECT sequence_number FROM %s WHERE namespace=$1 AND shard_id=$2;`, c.tableName) // nolint: gas, it replaces only the table name
err := c.conn.QueryRow(getCheckpointQuery, namespace, shardID).Scan(&sequenceNumber) err := c.conn.QueryRowContext(ctx, getCheckpointQuery, namespace, shardID).Scan(&sequenceNumber)
if err != nil { if err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
@ -107,7 +108,7 @@ func (c *Checkpoint) GetCheckpoint(streamName, shardID string) (string, error) {
// SetCheckpoint stores a checkpoint for a shard (e.g. sequence number of last record processed by application). // SetCheckpoint stores a checkpoint for a shard (e.g. sequence number of last record processed by application).
// Upon fail over, record processing is resumed from this point. // Upon fail over, record processing is resumed from this point.
func (c *Checkpoint) SetCheckpoint(streamName, shardID, sequenceNumber string) error { func (c *Checkpoint) SetCheckpoint(_ context.Context, streamName, shardID, sequenceNumber string) error {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
@ -126,15 +127,16 @@ func (c *Checkpoint) SetCheckpoint(streamName, shardID, sequenceNumber string) e
} }
// Shutdown the checkpoint. Save any in-flight data. // Shutdown the checkpoint. Save any in-flight data.
func (c *Checkpoint) Shutdown() error { func (c *Checkpoint) Shutdown(ctx context.Context) error {
defer c.conn.Close() defer c.conn.Close()
c.done <- struct{}{} c.done <- struct{}{}
return c.save() return c.save(ctx)
} }
func (c *Checkpoint) loop() { func (c *Checkpoint) loop() {
ctx := context.Background()
tick := time.NewTicker(c.maxInterval) tick := time.NewTicker(c.maxInterval)
defer tick.Stop() defer tick.Stop()
defer close(c.done) defer close(c.done)
@ -142,14 +144,14 @@ func (c *Checkpoint) loop() {
for { for {
select { select {
case <-tick.C: case <-tick.C:
_ = c.save() _ = c.save(ctx)
case <-c.done: case <-c.done:
return return
} }
} }
} }
func (c *Checkpoint) save() error { func (c *Checkpoint) save(ctx context.Context) error {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
@ -159,10 +161,10 @@ func (c *Checkpoint) save() error {
ON CONFLICT (namespace, shard_id) ON CONFLICT (namespace, shard_id)
DO DO
UPDATE UPDATE
SET sequence_number= $3;`, c.tableName) SET sequence_number=$3;`, c.tableName)
for key, sequenceNumber := range c.checkpoints { for key, sequenceNumber := range c.checkpoints {
if _, err := c.conn.Exec(upsertCheckpoint, fmt.Sprintf("%s-%s", c.appName, key.streamName), key.shardID, sequenceNumber); err != nil { if _, err := c.conn.ExecContext(ctx, upsertCheckpoint, fmt.Sprintf("%s-%s", c.appName, key.streamName), key.shardID, sequenceNumber); err != nil {
return err return err
} }
} }

View file

@ -3,6 +3,7 @@
package postgres package postgres
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"testing" "testing"
@ -13,6 +14,7 @@ import (
) )
func TestNew(t *testing.T) { func TestNew(t *testing.T) {
ctx := context.Background()
appName := "streamConsumer" appName := "streamConsumer"
tableName := "checkpoint" tableName := "checkpoint"
connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;" connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;"
@ -24,7 +26,7 @@ func TestNew(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("expected error equals nil, but got %v", err) t.Errorf("expected error equals nil, but got %v", err)
} }
_ = ck.Shutdown() _ = ck.Shutdown(ctx)
} }
func TestNew_AppNameEmpty(t *testing.T) { func TestNew_AppNameEmpty(t *testing.T) {
@ -56,6 +58,7 @@ func TestNew_TableNameEmpty(t *testing.T) {
} }
func TestNew_WithMaxIntervalOption(t *testing.T) { func TestNew_WithMaxIntervalOption(t *testing.T) {
ctx := context.Background()
appName := "streamConsumer" appName := "streamConsumer"
tableName := "checkpoint" tableName := "checkpoint"
connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;" connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;"
@ -71,10 +74,11 @@ func TestNew_WithMaxIntervalOption(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("expected error equals nil, but got %v", err) t.Errorf("expected error equals nil, but got %v", err)
} }
_ = ck.Shutdown() _ = ck.Shutdown(ctx)
} }
func TestCheckpoint_GetCheckpoint(t *testing.T) { func TestCheckpoint_GetCheckpoint(t *testing.T) {
ctx := context.Background()
appName := "streamConsumer" appName := "streamConsumer"
tableName := "checkpoint" tableName := "checkpoint"
connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;" connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;"
@ -100,7 +104,7 @@ func TestCheckpoint_GetCheckpoint(t *testing.T) {
tableName) tableName)
mock.ExpectQuery(expectedSQLRegexString).WithArgs(namespace, shardID).WillReturnRows(expectedRows) mock.ExpectQuery(expectedSQLRegexString).WithArgs(namespace, shardID).WillReturnRows(expectedRows)
gotSequenceNumber, err := ck.GetCheckpoint(streamName, shardID) gotSequenceNumber, err := ck.GetCheckpoint(ctx, streamName, shardID)
if gotSequenceNumber != expectedSequenceNumber { if gotSequenceNumber != expectedSequenceNumber {
t.Errorf("expected sequence number equals %v, but got %v", expectedSequenceNumber, gotSequenceNumber) t.Errorf("expected sequence number equals %v, but got %v", expectedSequenceNumber, gotSequenceNumber)
@ -111,10 +115,11 @@ func TestCheckpoint_GetCheckpoint(t *testing.T) {
if err := mock.ExpectationsWereMet(); err != nil { if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expectations: %s", err) t.Errorf("there were unfulfilled expectations: %s", err)
} }
_ = ck.Shutdown() _ = ck.Shutdown(ctx)
} }
func TestCheckpoint_Get_NoRows(t *testing.T) { func TestCheckpoint_Get_NoRows(t *testing.T) {
ctx := context.Background()
appName := "streamConsumer" appName := "streamConsumer"
tableName := "checkpoint" tableName := "checkpoint"
connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;" connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;"
@ -136,7 +141,7 @@ func TestCheckpoint_Get_NoRows(t *testing.T) {
tableName) tableName)
mock.ExpectQuery(expectedSQLRegexString).WithArgs(namespace, shardID).WillReturnError(sql.ErrNoRows) mock.ExpectQuery(expectedSQLRegexString).WithArgs(namespace, shardID).WillReturnError(sql.ErrNoRows)
gotSequenceNumber, err := ck.GetCheckpoint(streamName, shardID) gotSequenceNumber, err := ck.GetCheckpoint(ctx, streamName, shardID)
if gotSequenceNumber != "" { if gotSequenceNumber != "" {
t.Errorf("expected sequence number equals empty, but got %v", gotSequenceNumber) t.Errorf("expected sequence number equals empty, but got %v", gotSequenceNumber)
@ -151,6 +156,7 @@ func TestCheckpoint_Get_NoRows(t *testing.T) {
} }
func TestCheckpoint_Get_QueryError(t *testing.T) { func TestCheckpoint_Get_QueryError(t *testing.T) {
ctx := context.Background()
appName := "streamConsumer" appName := "streamConsumer"
tableName := "checkpoint" tableName := "checkpoint"
connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;" connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;"
@ -172,7 +178,7 @@ func TestCheckpoint_Get_QueryError(t *testing.T) {
tableName) tableName)
mock.ExpectQuery(expectedSQLRegexString).WithArgs(namespace, shardID).WillReturnError(errors.New("an error")) mock.ExpectQuery(expectedSQLRegexString).WithArgs(namespace, shardID).WillReturnError(errors.New("an error"))
gotSequenceNumber, err := ck.GetCheckpoint(streamName, shardID) gotSequenceNumber, err := ck.GetCheckpoint(ctx, streamName, shardID)
if gotSequenceNumber != "" { if gotSequenceNumber != "" {
t.Errorf("expected sequence number equals empty, but got %v", gotSequenceNumber) t.Errorf("expected sequence number equals empty, but got %v", gotSequenceNumber)
@ -183,10 +189,11 @@ func TestCheckpoint_Get_QueryError(t *testing.T) {
if err := mock.ExpectationsWereMet(); err != nil { if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expectations: %s", err) t.Errorf("there were unfulfilled expectations: %s", err)
} }
_ = ck.Shutdown() _ = ck.Shutdown(ctx)
} }
func TestCheckpoint_SetCheckpoint(t *testing.T) { func TestCheckpoint_SetCheckpoint(t *testing.T) {
ctx := context.Background()
appName := "streamConsumer" appName := "streamConsumer"
tableName := "checkpoint" tableName := "checkpoint"
connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;" connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;"
@ -199,15 +206,16 @@ func TestCheckpoint_SetCheckpoint(t *testing.T) {
t.Fatalf("error occurred during the checkpoint creation. cause: %v", err) t.Fatalf("error occurred during the checkpoint creation. cause: %v", err)
} }
err = ck.SetCheckpoint(streamName, shardID, expectedSequenceNumber) err = ck.SetCheckpoint(ctx, streamName, shardID, expectedSequenceNumber)
if err != nil { if err != nil {
t.Errorf("expected error equals nil, but got %v", err) t.Errorf("expected error equals nil, but got %v", err)
} }
_ = ck.Shutdown() _ = ck.Shutdown(ctx)
} }
func TestCheckpoint_Set_SequenceNumberEmpty(t *testing.T) { func TestCheckpoint_Set_SequenceNumberEmpty(t *testing.T) {
ctx := context.Background()
appName := "streamConsumer" appName := "streamConsumer"
tableName := "checkpoint" tableName := "checkpoint"
connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;" connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;"
@ -220,15 +228,16 @@ func TestCheckpoint_Set_SequenceNumberEmpty(t *testing.T) {
t.Fatalf("error occurred during the checkpoint creation. cause: %v", err) t.Fatalf("error occurred during the checkpoint creation. cause: %v", err)
} }
err = ck.SetCheckpoint(streamName, shardID, expectedSequenceNumber) err = ck.SetCheckpoint(ctx, streamName, shardID, expectedSequenceNumber)
if err == nil { if err == nil {
t.Errorf("expected error equals not nil, but got %v", err) t.Errorf("expected error equals not nil, but got %v", err)
} }
_ = ck.Shutdown() _ = ck.Shutdown(ctx)
} }
func TestCheckpoint_Shutdown(t *testing.T) { func TestCheckpoint_Shutdown(t *testing.T) {
ctx := context.Background()
appName := "streamConsumer" appName := "streamConsumer"
tableName := "checkpoint" tableName := "checkpoint"
connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;" connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;"
@ -251,13 +260,13 @@ func TestCheckpoint_Shutdown(t *testing.T) {
result := sqlmock.NewResult(0, 1) result := sqlmock.NewResult(0, 1)
mock.ExpectExec(expectedSQLRegexString).WithArgs(namespace, shardID, expectedSequenceNumber).WillReturnResult(result) mock.ExpectExec(expectedSQLRegexString).WithArgs(namespace, shardID, expectedSequenceNumber).WillReturnResult(result)
err = ck.SetCheckpoint(streamName, shardID, expectedSequenceNumber) err = ck.SetCheckpoint(ctx, streamName, shardID, expectedSequenceNumber)
if err != nil { if err != nil {
t.Fatalf("unable to set checkpoint for data initialization. cause: %v", err) t.Fatalf("unable to set checkpoint for data initialization. cause: %v", err)
} }
err = ck.Shutdown() err = ck.Shutdown(ctx)
if err != nil { if err != nil {
t.Errorf("expected error equals not nil, but got %v", err) t.Errorf("expected error equals not nil, but got %v", err)
@ -268,6 +277,7 @@ func TestCheckpoint_Shutdown(t *testing.T) {
} }
func TestCheckpoint_Shutdown_SaveError(t *testing.T) { func TestCheckpoint_Shutdown_SaveError(t *testing.T) {
ctx := context.Background()
appName := "streamConsumer" appName := "streamConsumer"
tableName := "checkpoint" tableName := "checkpoint"
connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;" connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;"
@ -289,13 +299,13 @@ func TestCheckpoint_Shutdown_SaveError(t *testing.T) {
expectedSQLRegexString := fmt.Sprintf(`INSERT INTO %s \(namespace, shard_id, sequence_number\) VALUES\(\$1, \$2, \$3\) ON CONFLICT \(namespace, shard_id\) DO UPDATE SET sequence_number= \$3;`, tableName) expectedSQLRegexString := fmt.Sprintf(`INSERT INTO %s \(namespace, shard_id, sequence_number\) VALUES\(\$1, \$2, \$3\) ON CONFLICT \(namespace, shard_id\) DO UPDATE SET sequence_number= \$3;`, tableName)
mock.ExpectExec(expectedSQLRegexString).WithArgs(namespace, shardID, expectedSequenceNumber).WillReturnError(errors.New("an error")) mock.ExpectExec(expectedSQLRegexString).WithArgs(namespace, shardID, expectedSequenceNumber).WillReturnError(errors.New("an error"))
err = ck.SetCheckpoint(streamName, shardID, expectedSequenceNumber) err = ck.SetCheckpoint(ctx, streamName, shardID, expectedSequenceNumber)
if err != nil { if err != nil {
t.Fatalf("unable to set checkpoint for data initialization. cause: %v", err) t.Fatalf("unable to set checkpoint for data initialization. cause: %v", err)
} }
err = ck.Shutdown() err = ck.Shutdown(ctx)
if err == nil { if err == nil {
t.Errorf("expected error equals nil, but got %v", err) t.Errorf("expected error equals nil, but got %v", err)

View file

@ -52,19 +52,17 @@ type Checkpoint struct {
} }
// GetCheckpoint fetches the checkpoint for a particular Shard. // GetCheckpoint fetches the checkpoint for a particular Shard.
func (c *Checkpoint) GetCheckpoint(streamName, shardID string) (string, error) { func (c *Checkpoint) GetCheckpoint(ctx context.Context, streamName, shardID string) (string, error) {
ctx := context.Background()
val, _ := c.client.Get(ctx, c.key(streamName, shardID)).Result() val, _ := c.client.Get(ctx, c.key(streamName, shardID)).Result()
return val, nil return val, nil
} }
// SetCheckpoint stores a checkpoint for a shard (e.g. sequence number of last record processed by application). // SetCheckpoint stores a checkpoint for a shard (e.g. sequence number of last record processed by application).
// Upon fail over, record processing is resumed from this point. // Upon fail over, record processing is resumed from this point.
func (c *Checkpoint) SetCheckpoint(streamName, shardID, sequenceNumber string) error { func (c *Checkpoint) SetCheckpoint(ctx context.Context, streamName, shardID, sequenceNumber string) error {
if sequenceNumber == "" { if sequenceNumber == "" {
return fmt.Errorf("sequence number should not be empty") return fmt.Errorf("sequence number should not be empty")
} }
ctx := context.Background()
err := c.client.Set(ctx, c.key(streamName, shardID), sequenceNumber, 0).Err() err := c.client.Set(ctx, c.key(streamName, shardID), sequenceNumber, 0).Err()
if err != nil { if err != nil {
return err return err

View file

@ -3,6 +3,7 @@
package redis package redis
import ( import (
"context"
"testing" "testing"
"github.com/alicebob/miniredis" "github.com/alicebob/miniredis"
@ -34,10 +35,10 @@ func Test_CheckpointLifecycle(t *testing.T) {
} }
// set // set
_ = c.SetCheckpoint("streamName", "shardID", "testSeqNum") _ = c.SetCheckpoint(context.Background(), "streamName", "shardID", "testSeqNum")
// get // get
val, err := c.GetCheckpoint("streamName", "shardID") val, err := c.GetCheckpoint(context.Background(), "streamName", "shardID")
if err != nil { if err != nil {
t.Fatalf("get checkpoint error: %v", err) t.Fatalf("get checkpoint error: %v", err)
} }
@ -52,7 +53,7 @@ func Test_SetEmptySeqNum(t *testing.T) {
t.Fatalf("new checkpoint error: %v", err) t.Fatalf("new checkpoint error: %v", err)
} }
err = c.SetCheckpoint("streamName", "shardID", "") err = c.SetCheckpoint(context.Background(), "streamName", "shardID", "")
if err == nil { if err == nil {
t.Fatalf("should not allow empty sequence number") t.Fatalf("should not allow empty sequence number")
} }

1
worker.go Normal file
View file

@ -0,0 +1 @@
package consumer