#23 threads context through to stores

This commit is contained in:
Alex Senger 2024-09-18 14:20:37 +02:00
parent 8c86963bd7
commit 8a4cb38940
No known key found for this signature in database
GPG key ID: 0B4A96F8AF6934CF
15 changed files with 98 additions and 87 deletions

View file

@ -155,7 +155,7 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
// for each record and checkpoints the progress of scan.
func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error {
// get last seq number from checkpoint
lastSeqNum, err := c.group.GetCheckpoint(c.streamName, shardID)
lastSeqNum, err := c.group.GetCheckpoint(ctx, c.streamName, shardID)
if err != nil {
return fmt.Errorf("get checkpoint error: %w", err)
}
@ -223,7 +223,7 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
}
if !errors.Is(err, ErrSkipCheckpoint) {
if err := c.group.SetCheckpoint(c.streamName, shardID, *r.SequenceNumber); err != nil {
if err := c.group.SetCheckpoint(ctx, c.streamName, shardID, *r.SequenceNumber); err != nil {
return err
}
c.counter.Add("checkpoint", 1)

View file

@ -7,14 +7,13 @@ import (
"flag"
"fmt"
"log"
"log/slog"
"net"
"net/http"
"os"
"os/signal"
"time"
alog "github.com/apex/log"
"github.com/apex/log/handlers/text"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/dynamodb"
@ -38,25 +37,7 @@ func init() {
}()
}
// A myLogger provides a minimalistic logger satisfying the Logger interface.
type myLogger struct {
logger alog.Logger
}
// Log logs the parameters to the stdlib logger. See log.Println.
func (l *myLogger) Log(args ...interface{}) {
l.logger.Infof("producer: %v", args...)
}
func main() {
// Wrap myLogger around apex logger
myLog := &myLogger{
logger: alog.Logger{
Handler: text.New(os.Stdout),
Level: alog.DebugLevel,
},
}
var (
app = flag.String("app", "", "Consumer app name")
stream = flag.String("stream", "", "Stream name")
@ -100,7 +81,7 @@ func main() {
c, err := consumer.New(
*stream,
consumer.WithStore(ddb),
consumer.WithLogger(myLog),
consumer.WithLogger(slog.Default()),
consumer.WithCounter(counter),
consumer.WithClient(client),
)
@ -129,7 +110,7 @@ func main() {
log.Fatalf("scan error: %v", err)
}
if err := ddb.Shutdown(); err != nil {
if err := ddb.Shutdown(ctx); err != nil {
log.Fatalf("storage shutdown error: %v", err)
}
}

View file

@ -85,7 +85,7 @@ func main() {
os.Exit(1)
}
if err := checkpointStore.Shutdown(); err != nil {
if err := checkpointStore.Shutdown(context.Background()); err != nil {
slog.Error("store shutdown error", slog.String("error", err.Error()))
os.Exit(1)
}

View file

@ -9,6 +9,6 @@ import (
// Group interface used to manage which shard to process
type Group interface {
Start(ctx context.Context, shardc chan types.Shard)
GetCheckpoint(streamName, shardID string) (string, error)
SetCheckpoint(streamName, shardID, sequenceNumber string) error
GetCheckpoint(ctx context.Context, streamName, shardID string) (string, error)
SetCheckpoint(ctx context.Context, streamName, shardID, sequenceNumber string) error
}

View file

@ -1,13 +1,17 @@
package consumer
import (
"context"
)
// Store interface used to persist scan progress
type Store interface {
GetCheckpoint(streamName, shardID string) (string, error)
SetCheckpoint(streamName, shardID, sequenceNumber string) error
GetCheckpoint(ctx context.Context, streamName, shardID string) (string, error)
SetCheckpoint(ctx context.Context, streamName, shardID, sequenceNumber string) error
}
// noopStore implements the storage interface with discard
type noopStore struct{}
func (n noopStore) GetCheckpoint(string, string) (string, error) { return "", nil }
func (n noopStore) SetCheckpoint(string, string, string) error { return nil }
func (n noopStore) GetCheckpoint(context.Context, string, string) (string, error) { return "", nil }
func (n noopStore) SetCheckpoint(context.Context, string, string, string) error { return nil }

View file

@ -94,7 +94,7 @@ type item struct {
// GetCheckpoint determines if a checkpoint for a particular Shard exists.
// Typically used to determine whether we should start processing the shard with
// TRIM_HORIZON or AFTER_SEQUENCE_NUMBER (if checkpoint exists).
func (c *Checkpoint) GetCheckpoint(streamName, shardID string) (string, error) {
func (c *Checkpoint) GetCheckpoint(ctx context.Context, streamName, shardID string) (string, error) {
namespace := fmt.Sprintf("%s-%s", c.appName, streamName)
params := &dynamodb.GetItemInput{
@ -106,10 +106,10 @@ func (c *Checkpoint) GetCheckpoint(streamName, shardID string) (string, error) {
},
}
resp, err := c.client.GetItem(context.Background(), params)
resp, err := c.client.GetItem(ctx, params)
if err != nil {
if c.retryer.ShouldRetry(err) {
return c.GetCheckpoint(streamName, shardID)
return c.GetCheckpoint(ctx, streamName, shardID)
}
return "", err
}
@ -121,7 +121,7 @@ func (c *Checkpoint) GetCheckpoint(streamName, shardID string) (string, error) {
// SetCheckpoint stores a checkpoint for a shard (e.g. sequence number of last record processed by application).
// Upon fail over, record processing is resumed from this point.
func (c *Checkpoint) SetCheckpoint(streamName, shardID, sequenceNumber string) error {
func (c *Checkpoint) SetCheckpoint(_ context.Context, streamName, shardID, sequenceNumber string) error {
c.mu.Lock()
defer c.mu.Unlock()
@ -139,12 +139,13 @@ func (c *Checkpoint) SetCheckpoint(streamName, shardID, sequenceNumber string) e
}
// Shutdown the checkpoint. Save any in-flight data.
func (c *Checkpoint) Shutdown() error {
func (c *Checkpoint) Shutdown(ctx context.Context) error {
c.done <- struct{}{}
return c.save()
return c.save(ctx)
}
func (c *Checkpoint) loop() {
ctx := context.Background()
tick := time.NewTicker(c.maxInterval)
defer tick.Stop()
defer close(c.done)
@ -152,14 +153,14 @@ func (c *Checkpoint) loop() {
for {
select {
case <-tick.C:
_ = c.save()
_ = c.save(ctx)
case <-c.done:
return
}
}
}
func (c *Checkpoint) save() error {
func (c *Checkpoint) save(ctx context.Context) error {
c.mu.Lock()
defer c.mu.Unlock()
@ -175,7 +176,7 @@ func (c *Checkpoint) save() error {
}
_, err = c.client.PutItem(
context.TODO(),
ctx,
&dynamodb.PutItemInput{
TableName: aws.String(c.tableName),
Item: item,
@ -184,7 +185,7 @@ func (c *Checkpoint) save() error {
if !c.retryer.ShouldRetry(err) {
return err
}
return c.save()
return c.save(ctx)
}
}

View file

@ -6,6 +6,7 @@
package store
import (
"context"
"fmt"
"sync"
)
@ -21,7 +22,7 @@ type Store struct {
}
// SetCheckpoint stores a checkpoint for a shard (e.g. sequence number of last record processed by application).
func (c *Store) SetCheckpoint(streamName, shardID, sequenceNumber string) error {
func (c *Store) SetCheckpoint(_ context.Context, streamName, shardID, sequenceNumber string) error {
if sequenceNumber == "" {
return fmt.Errorf("sequence number should not be empty")
}
@ -32,7 +33,7 @@ func (c *Store) SetCheckpoint(streamName, shardID, sequenceNumber string) error
// GetCheckpoint determines if a checkpoint for a particular Shard exists.
// Typically, this is used to determine whether processing should start with TRIM_HORIZON or AFTER_SEQUENCE_NUMBER
// (if checkpoint exists).
func (c *Store) GetCheckpoint(streamName, shardID string) (string, error) {
func (c *Store) GetCheckpoint(_ context.Context, streamName, shardID string) (string, error) {
val, ok := c.Load(streamName + ":" + shardID)
if !ok {
return "", nil

View file

@ -1,17 +1,19 @@
package store
import (
"context"
"testing"
)
func Test_CheckpointLifecycle(t *testing.T) {
c := New()
ctx := context.Background()
// set
_ = c.SetCheckpoint("streamName", "shardID", "testSeqNum")
_ = c.SetCheckpoint(ctx, "streamName", "shardID", "testSeqNum")
// get
val, err := c.GetCheckpoint("streamName", "shardID")
val, err := c.GetCheckpoint(ctx, "streamName", "shardID")
if err != nil {
t.Fatalf("get checkpoint error: %v", err)
}
@ -22,8 +24,9 @@ func Test_CheckpointLifecycle(t *testing.T) {
func Test_SetEmptySeqNum(t *testing.T) {
c := New()
ctx := context.Background()
err := c.SetCheckpoint("streamName", "shardID", "")
err := c.SetCheckpoint(ctx, "streamName", "shardID", "")
if err == nil || err.Error() != "sequence number should not be empty" {
t.Fatalf("should not allow empty sequence number")
}

View file

@ -1,6 +1,7 @@
package mysql
import (
"context"
"database/sql"
"errors"
"fmt"
@ -81,12 +82,12 @@ func (c *Checkpoint) GetMaxInterval() time.Duration {
// GetCheckpoint determines if a checkpoint for a particular Shard exists.
// Typically used to determine whether we should start processing the shard with
// TRIM_HORIZON or AFTER_SEQUENCE_NUMBER (if checkpoint exists).
func (c *Checkpoint) GetCheckpoint(streamName, shardID string) (string, error) {
func (c *Checkpoint) GetCheckpoint(ctx context.Context, streamName, shardID string) (string, error) {
namespace := fmt.Sprintf("%s-%s", c.appName, streamName)
var sequenceNumber string
getCheckpointQuery := fmt.Sprintf(`SELECT sequence_number FROM %s WHERE namespace=? AND shard_id=?;`, c.tableName) // nolint: gas, it replaces only the table name
err := c.conn.QueryRow(getCheckpointQuery, namespace, shardID).Scan(&sequenceNumber)
err := c.conn.QueryRowContext(ctx, getCheckpointQuery, namespace, shardID).Scan(&sequenceNumber)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
@ -100,7 +101,7 @@ func (c *Checkpoint) GetCheckpoint(streamName, shardID string) (string, error) {
// SetCheckpoint stores a checkpoint for a shard (e.g. sequence number of last record processed by application).
// Upon fail over, record processing is resumed from this point.
func (c *Checkpoint) SetCheckpoint(streamName, shardID, sequenceNumber string) error {
func (c *Checkpoint) SetCheckpoint(_ context.Context, streamName, shardID, sequenceNumber string) error {
c.mu.Lock()
defer c.mu.Unlock()

View file

@ -1,6 +1,7 @@
package mysql
import (
"context"
"database/sql"
"fmt"
"testing"
@ -73,6 +74,7 @@ func TestNew_WithMaxIntervalOption(t *testing.T) {
}
func TestCheckpoint_GetCheckpoint(t *testing.T) {
ctx := context.Background()
appName := "streamConsumer"
tableName := "checkpoint"
connString := "user:password@/dbname"
@ -98,7 +100,7 @@ func TestCheckpoint_GetCheckpoint(t *testing.T) {
tableName)
mock.ExpectQuery(expectedSQLRegexString).WithArgs(namespace, shardID).WillReturnRows(expectedRows)
gotSequenceNumber, err := ck.GetCheckpoint(streamName, shardID)
gotSequenceNumber, err := ck.GetCheckpoint(ctx, streamName, shardID)
if gotSequenceNumber != expectedSequenceNumber {
t.Errorf("expected sequence number equals %v, but got %v", expectedSequenceNumber, gotSequenceNumber)
@ -113,6 +115,7 @@ func TestCheckpoint_GetCheckpoint(t *testing.T) {
}
func TestCheckpoint_Get_NoRows(t *testing.T) {
ctx := context.Background()
appName := "streamConsumer"
tableName := "checkpoint"
connString := "user:password@/dbname"
@ -134,7 +137,7 @@ func TestCheckpoint_Get_NoRows(t *testing.T) {
tableName)
mock.ExpectQuery(expectedSQLRegexString).WithArgs(namespace, shardID).WillReturnError(sql.ErrNoRows)
gotSequenceNumber, err := ck.GetCheckpoint(streamName, shardID)
gotSequenceNumber, err := ck.GetCheckpoint(ctx, streamName, shardID)
if gotSequenceNumber != "" {
t.Errorf("expected sequence number equals empty, but got %v", gotSequenceNumber)
@ -149,6 +152,7 @@ func TestCheckpoint_Get_NoRows(t *testing.T) {
}
func TestCheckpoint_Get_QueryError(t *testing.T) {
ctx := context.Background()
appName := "streamConsumer"
tableName := "checkpoint"
connString := "user:password@/dbname"
@ -170,7 +174,7 @@ func TestCheckpoint_Get_QueryError(t *testing.T) {
tableName)
mock.ExpectQuery(expectedSQLRegexString).WithArgs(namespace, shardID).WillReturnError(errors.New("an error"))
gotSequenceNumber, err := ck.GetCheckpoint(streamName, shardID)
gotSequenceNumber, err := ck.GetCheckpoint(ctx, streamName, shardID)
if gotSequenceNumber != "" {
t.Errorf("expected sequence number equals empty, but got %v", gotSequenceNumber)
@ -185,6 +189,7 @@ func TestCheckpoint_Get_QueryError(t *testing.T) {
}
func TestCheckpoint_SetCheckpoint(t *testing.T) {
ctx := context.Background()
appName := "streamConsumer"
tableName := "checkpoint"
connString := "user:password@/dbname"
@ -197,7 +202,7 @@ func TestCheckpoint_SetCheckpoint(t *testing.T) {
t.Fatalf("error occurred during the checkpoint creation. cause: %v", err)
}
err = ck.SetCheckpoint(streamName, shardID, expectedSequenceNumber)
err = ck.SetCheckpoint(ctx, streamName, shardID, expectedSequenceNumber)
if err != nil {
t.Errorf("expected error equals nil, but got %v", err)
@ -206,6 +211,7 @@ func TestCheckpoint_SetCheckpoint(t *testing.T) {
}
func TestCheckpoint_Set_SequenceNumberEmpty(t *testing.T) {
ctx := context.Background()
appName := "streamConsumer"
tableName := "checkpoint"
connString := "user:password@/dbname"
@ -218,7 +224,7 @@ func TestCheckpoint_Set_SequenceNumberEmpty(t *testing.T) {
t.Fatalf("error occurred during the checkpoint creation. cause: %v", err)
}
err = ck.SetCheckpoint(streamName, shardID, expectedSequenceNumber)
err = ck.SetCheckpoint(ctx, streamName, shardID, expectedSequenceNumber)
if err == nil {
t.Errorf("expected error equals not nil, but got %v", err)
@ -227,6 +233,7 @@ func TestCheckpoint_Set_SequenceNumberEmpty(t *testing.T) {
}
func TestCheckpoint_Shutdown(t *testing.T) {
ctx := context.Background()
appName := "streamConsumer"
tableName := "checkpoint"
connString := "user:password@/dbname"
@ -249,7 +256,7 @@ func TestCheckpoint_Shutdown(t *testing.T) {
result := sqlmock.NewResult(0, 1)
mock.ExpectExec(expectedSQLRegexString).WithArgs(namespace, shardID, expectedSequenceNumber).WillReturnResult(result)
err = ck.SetCheckpoint(streamName, shardID, expectedSequenceNumber)
err = ck.SetCheckpoint(ctx, streamName, shardID, expectedSequenceNumber)
if err != nil {
t.Fatalf("unable to set checkpoint for data initialization. cause: %v", err)
@ -266,6 +273,7 @@ func TestCheckpoint_Shutdown(t *testing.T) {
}
func TestCheckpoint_Shutdown_SaveError(t *testing.T) {
ctx := context.Background()
appName := "streamConsumer"
tableName := "checkpoint"
connString := "user:password@/dbname"
@ -287,7 +295,7 @@ func TestCheckpoint_Shutdown_SaveError(t *testing.T) {
expectedSQLRegexString := fmt.Sprintf(`REPLACE INTO %s \(namespace, shard_id, sequence_number\) VALUES \(\?, \?, \?\)`, tableName)
mock.ExpectExec(expectedSQLRegexString).WithArgs(namespace, shardID, expectedSequenceNumber).WillReturnError(errors.New("an error"))
err = ck.SetCheckpoint(streamName, shardID, expectedSequenceNumber)
err = ck.SetCheckpoint(ctx, streamName, shardID, expectedSequenceNumber)
if err != nil {
t.Fatalf("unable to set checkpoint for data initialization. cause: %v", err)

View file

@ -1,6 +1,7 @@
package postgres
import (
"context"
"database/sql"
"errors"
"fmt"
@ -88,12 +89,12 @@ func (c *Checkpoint) GetMaxInterval() time.Duration {
// GetCheckpoint determines if a checkpoint for a particular Shard exists.
// Typically used to determine whether we should start processing the shard with
// TRIM_HORIZON or AFTER_SEQUENCE_NUMBER (if checkpoint exists).
func (c *Checkpoint) GetCheckpoint(streamName, shardID string) (string, error) {
func (c *Checkpoint) GetCheckpoint(ctx context.Context, streamName, shardID string) (string, error) {
namespace := fmt.Sprintf("%s-%s", c.appName, streamName)
var sequenceNumber string
getCheckpointQuery := fmt.Sprintf(`SELECT sequence_number FROM %s WHERE namespace=$1 AND shard_id=$2;`, c.tableName) // nolint: gas, it replaces only the table name
err := c.conn.QueryRow(getCheckpointQuery, namespace, shardID).Scan(&sequenceNumber)
err := c.conn.QueryRowContext(ctx, getCheckpointQuery, namespace, shardID).Scan(&sequenceNumber)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
@ -107,7 +108,7 @@ func (c *Checkpoint) GetCheckpoint(streamName, shardID string) (string, error) {
// SetCheckpoint stores a checkpoint for a shard (e.g. sequence number of last record processed by application).
// Upon fail over, record processing is resumed from this point.
func (c *Checkpoint) SetCheckpoint(streamName, shardID, sequenceNumber string) error {
func (c *Checkpoint) SetCheckpoint(_ context.Context, streamName, shardID, sequenceNumber string) error {
c.mu.Lock()
defer c.mu.Unlock()
@ -126,15 +127,16 @@ func (c *Checkpoint) SetCheckpoint(streamName, shardID, sequenceNumber string) e
}
// Shutdown the checkpoint. Save any in-flight data.
func (c *Checkpoint) Shutdown() error {
func (c *Checkpoint) Shutdown(ctx context.Context) error {
defer c.conn.Close()
c.done <- struct{}{}
return c.save()
return c.save(ctx)
}
func (c *Checkpoint) loop() {
ctx := context.Background()
tick := time.NewTicker(c.maxInterval)
defer tick.Stop()
defer close(c.done)
@ -142,14 +144,14 @@ func (c *Checkpoint) loop() {
for {
select {
case <-tick.C:
_ = c.save()
_ = c.save(ctx)
case <-c.done:
return
}
}
}
func (c *Checkpoint) save() error {
func (c *Checkpoint) save(ctx context.Context) error {
c.mu.Lock()
defer c.mu.Unlock()
@ -159,10 +161,10 @@ func (c *Checkpoint) save() error {
ON CONFLICT (namespace, shard_id)
DO
UPDATE
SET sequence_number= $3;`, c.tableName)
SET sequence_number=$3;`, c.tableName)
for key, sequenceNumber := range c.checkpoints {
if _, err := c.conn.Exec(upsertCheckpoint, fmt.Sprintf("%s-%s", c.appName, key.streamName), key.shardID, sequenceNumber); err != nil {
if _, err := c.conn.ExecContext(ctx, upsertCheckpoint, fmt.Sprintf("%s-%s", c.appName, key.streamName), key.shardID, sequenceNumber); err != nil {
return err
}
}

View file

@ -3,6 +3,7 @@
package postgres
import (
"context"
"database/sql"
"fmt"
"testing"
@ -13,6 +14,7 @@ import (
)
func TestNew(t *testing.T) {
ctx := context.Background()
appName := "streamConsumer"
tableName := "checkpoint"
connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;"
@ -24,7 +26,7 @@ func TestNew(t *testing.T) {
if err != nil {
t.Errorf("expected error equals nil, but got %v", err)
}
_ = ck.Shutdown()
_ = ck.Shutdown(ctx)
}
func TestNew_AppNameEmpty(t *testing.T) {
@ -56,6 +58,7 @@ func TestNew_TableNameEmpty(t *testing.T) {
}
func TestNew_WithMaxIntervalOption(t *testing.T) {
ctx := context.Background()
appName := "streamConsumer"
tableName := "checkpoint"
connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;"
@ -71,10 +74,11 @@ func TestNew_WithMaxIntervalOption(t *testing.T) {
if err != nil {
t.Errorf("expected error equals nil, but got %v", err)
}
_ = ck.Shutdown()
_ = ck.Shutdown(ctx)
}
func TestCheckpoint_GetCheckpoint(t *testing.T) {
ctx := context.Background()
appName := "streamConsumer"
tableName := "checkpoint"
connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;"
@ -100,7 +104,7 @@ func TestCheckpoint_GetCheckpoint(t *testing.T) {
tableName)
mock.ExpectQuery(expectedSQLRegexString).WithArgs(namespace, shardID).WillReturnRows(expectedRows)
gotSequenceNumber, err := ck.GetCheckpoint(streamName, shardID)
gotSequenceNumber, err := ck.GetCheckpoint(ctx, streamName, shardID)
if gotSequenceNumber != expectedSequenceNumber {
t.Errorf("expected sequence number equals %v, but got %v", expectedSequenceNumber, gotSequenceNumber)
@ -111,10 +115,11 @@ func TestCheckpoint_GetCheckpoint(t *testing.T) {
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expectations: %s", err)
}
_ = ck.Shutdown()
_ = ck.Shutdown(ctx)
}
func TestCheckpoint_Get_NoRows(t *testing.T) {
ctx := context.Background()
appName := "streamConsumer"
tableName := "checkpoint"
connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;"
@ -136,7 +141,7 @@ func TestCheckpoint_Get_NoRows(t *testing.T) {
tableName)
mock.ExpectQuery(expectedSQLRegexString).WithArgs(namespace, shardID).WillReturnError(sql.ErrNoRows)
gotSequenceNumber, err := ck.GetCheckpoint(streamName, shardID)
gotSequenceNumber, err := ck.GetCheckpoint(ctx, streamName, shardID)
if gotSequenceNumber != "" {
t.Errorf("expected sequence number equals empty, but got %v", gotSequenceNumber)
@ -151,6 +156,7 @@ func TestCheckpoint_Get_NoRows(t *testing.T) {
}
func TestCheckpoint_Get_QueryError(t *testing.T) {
ctx := context.Background()
appName := "streamConsumer"
tableName := "checkpoint"
connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;"
@ -172,7 +178,7 @@ func TestCheckpoint_Get_QueryError(t *testing.T) {
tableName)
mock.ExpectQuery(expectedSQLRegexString).WithArgs(namespace, shardID).WillReturnError(errors.New("an error"))
gotSequenceNumber, err := ck.GetCheckpoint(streamName, shardID)
gotSequenceNumber, err := ck.GetCheckpoint(ctx, streamName, shardID)
if gotSequenceNumber != "" {
t.Errorf("expected sequence number equals empty, but got %v", gotSequenceNumber)
@ -183,10 +189,11 @@ func TestCheckpoint_Get_QueryError(t *testing.T) {
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expectations: %s", err)
}
_ = ck.Shutdown()
_ = ck.Shutdown(ctx)
}
func TestCheckpoint_SetCheckpoint(t *testing.T) {
ctx := context.Background()
appName := "streamConsumer"
tableName := "checkpoint"
connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;"
@ -199,15 +206,16 @@ func TestCheckpoint_SetCheckpoint(t *testing.T) {
t.Fatalf("error occurred during the checkpoint creation. cause: %v", err)
}
err = ck.SetCheckpoint(streamName, shardID, expectedSequenceNumber)
err = ck.SetCheckpoint(ctx, streamName, shardID, expectedSequenceNumber)
if err != nil {
t.Errorf("expected error equals nil, but got %v", err)
}
_ = ck.Shutdown()
_ = ck.Shutdown(ctx)
}
func TestCheckpoint_Set_SequenceNumberEmpty(t *testing.T) {
ctx := context.Background()
appName := "streamConsumer"
tableName := "checkpoint"
connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;"
@ -220,15 +228,16 @@ func TestCheckpoint_Set_SequenceNumberEmpty(t *testing.T) {
t.Fatalf("error occurred during the checkpoint creation. cause: %v", err)
}
err = ck.SetCheckpoint(streamName, shardID, expectedSequenceNumber)
err = ck.SetCheckpoint(ctx, streamName, shardID, expectedSequenceNumber)
if err == nil {
t.Errorf("expected error equals not nil, but got %v", err)
}
_ = ck.Shutdown()
_ = ck.Shutdown(ctx)
}
func TestCheckpoint_Shutdown(t *testing.T) {
ctx := context.Background()
appName := "streamConsumer"
tableName := "checkpoint"
connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;"
@ -251,13 +260,13 @@ func TestCheckpoint_Shutdown(t *testing.T) {
result := sqlmock.NewResult(0, 1)
mock.ExpectExec(expectedSQLRegexString).WithArgs(namespace, shardID, expectedSequenceNumber).WillReturnResult(result)
err = ck.SetCheckpoint(streamName, shardID, expectedSequenceNumber)
err = ck.SetCheckpoint(ctx, streamName, shardID, expectedSequenceNumber)
if err != nil {
t.Fatalf("unable to set checkpoint for data initialization. cause: %v", err)
}
err = ck.Shutdown()
err = ck.Shutdown(ctx)
if err != nil {
t.Errorf("expected error equals not nil, but got %v", err)
@ -268,6 +277,7 @@ func TestCheckpoint_Shutdown(t *testing.T) {
}
func TestCheckpoint_Shutdown_SaveError(t *testing.T) {
ctx := context.Background()
appName := "streamConsumer"
tableName := "checkpoint"
connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;"
@ -289,13 +299,13 @@ func TestCheckpoint_Shutdown_SaveError(t *testing.T) {
expectedSQLRegexString := fmt.Sprintf(`INSERT INTO %s \(namespace, shard_id, sequence_number\) VALUES\(\$1, \$2, \$3\) ON CONFLICT \(namespace, shard_id\) DO UPDATE SET sequence_number= \$3;`, tableName)
mock.ExpectExec(expectedSQLRegexString).WithArgs(namespace, shardID, expectedSequenceNumber).WillReturnError(errors.New("an error"))
err = ck.SetCheckpoint(streamName, shardID, expectedSequenceNumber)
err = ck.SetCheckpoint(ctx, streamName, shardID, expectedSequenceNumber)
if err != nil {
t.Fatalf("unable to set checkpoint for data initialization. cause: %v", err)
}
err = ck.Shutdown()
err = ck.Shutdown(ctx)
if err == nil {
t.Errorf("expected error equals nil, but got %v", err)

View file

@ -52,19 +52,17 @@ type Checkpoint struct {
}
// GetCheckpoint fetches the checkpoint for a particular Shard.
func (c *Checkpoint) GetCheckpoint(streamName, shardID string) (string, error) {
ctx := context.Background()
func (c *Checkpoint) GetCheckpoint(ctx context.Context, streamName, shardID string) (string, error) {
val, _ := c.client.Get(ctx, c.key(streamName, shardID)).Result()
return val, nil
}
// SetCheckpoint stores a checkpoint for a shard (e.g. sequence number of last record processed by application).
// Upon fail over, record processing is resumed from this point.
func (c *Checkpoint) SetCheckpoint(streamName, shardID, sequenceNumber string) error {
func (c *Checkpoint) SetCheckpoint(ctx context.Context, streamName, shardID, sequenceNumber string) error {
if sequenceNumber == "" {
return fmt.Errorf("sequence number should not be empty")
}
ctx := context.Background()
err := c.client.Set(ctx, c.key(streamName, shardID), sequenceNumber, 0).Err()
if err != nil {
return err

View file

@ -3,6 +3,7 @@
package redis
import (
"context"
"testing"
"github.com/alicebob/miniredis"
@ -34,10 +35,10 @@ func Test_CheckpointLifecycle(t *testing.T) {
}
// set
_ = c.SetCheckpoint("streamName", "shardID", "testSeqNum")
_ = c.SetCheckpoint(context.Background(), "streamName", "shardID", "testSeqNum")
// get
val, err := c.GetCheckpoint("streamName", "shardID")
val, err := c.GetCheckpoint(context.Background(), "streamName", "shardID")
if err != nil {
t.Fatalf("get checkpoint error: %v", err)
}
@ -52,7 +53,7 @@ func Test_SetEmptySeqNum(t *testing.T) {
t.Fatalf("new checkpoint error: %v", err)
}
err = c.SetCheckpoint("streamName", "shardID", "")
err = c.SetCheckpoint(context.Background(), "streamName", "shardID", "")
if err == nil {
t.Fatalf("should not allow empty sequence number")
}

1
worker.go Normal file
View file

@ -0,0 +1 @@
package consumer