diff --git a/consumer.go b/consumer.go index 2e3a7b9..bff3fb8 100644 --- a/consumer.go +++ b/consumer.go @@ -155,7 +155,7 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error { // for each record and checkpoints the progress of scan. func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error { // get last seq number from checkpoint - lastSeqNum, err := c.group.GetCheckpoint(c.streamName, shardID) + lastSeqNum, err := c.group.GetCheckpoint(ctx, c.streamName, shardID) if err != nil { return fmt.Errorf("get checkpoint error: %w", err) } @@ -223,7 +223,7 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e } if !errors.Is(err, ErrSkipCheckpoint) { - if err := c.group.SetCheckpoint(c.streamName, shardID, *r.SequenceNumber); err != nil { + if err := c.group.SetCheckpoint(ctx, c.streamName, shardID, *r.SequenceNumber); err != nil { return err } c.counter.Add("checkpoint", 1) diff --git a/examples/consumer-dynamodb/main.go b/examples/consumer-dynamodb/main.go index b3edb66..59b3603 100644 --- a/examples/consumer-dynamodb/main.go +++ b/examples/consumer-dynamodb/main.go @@ -7,14 +7,13 @@ import ( "flag" "fmt" "log" + "log/slog" "net" "net/http" "os" "os/signal" "time" - alog "github.com/apex/log" - "github.com/apex/log/handlers/text" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/service/dynamodb" @@ -38,25 +37,7 @@ func init() { }() } -// A myLogger provides a minimalistic logger satisfying the Logger interface. -type myLogger struct { - logger alog.Logger -} - -// Log logs the parameters to the stdlib logger. See log.Println. -func (l *myLogger) Log(args ...interface{}) { - l.logger.Infof("producer: %v", args...) -} - func main() { - // Wrap myLogger around apex logger - myLog := &myLogger{ - logger: alog.Logger{ - Handler: text.New(os.Stdout), - Level: alog.DebugLevel, - }, - } - var ( app = flag.String("app", "", "Consumer app name") stream = flag.String("stream", "", "Stream name") @@ -100,7 +81,7 @@ func main() { c, err := consumer.New( *stream, consumer.WithStore(ddb), - consumer.WithLogger(myLog), + consumer.WithLogger(slog.Default()), consumer.WithCounter(counter), consumer.WithClient(client), ) @@ -129,7 +110,7 @@ func main() { log.Fatalf("scan error: %v", err) } - if err := ddb.Shutdown(); err != nil { + if err := ddb.Shutdown(ctx); err != nil { log.Fatalf("storage shutdown error: %v", err) } } diff --git a/examples/consumer-postgres/main.go b/examples/consumer-postgres/main.go index 1c459ac..1296469 100644 --- a/examples/consumer-postgres/main.go +++ b/examples/consumer-postgres/main.go @@ -85,7 +85,7 @@ func main() { os.Exit(1) } - if err := checkpointStore.Shutdown(); err != nil { + if err := checkpointStore.Shutdown(context.Background()); err != nil { slog.Error("store shutdown error", slog.String("error", err.Error())) os.Exit(1) } diff --git a/group.go b/group.go index a092dc3..0ccc2e0 100644 --- a/group.go +++ b/group.go @@ -9,6 +9,6 @@ import ( // Group interface used to manage which shard to process type Group interface { Start(ctx context.Context, shardc chan types.Shard) - GetCheckpoint(streamName, shardID string) (string, error) - SetCheckpoint(streamName, shardID, sequenceNumber string) error + GetCheckpoint(ctx context.Context, streamName, shardID string) (string, error) + SetCheckpoint(ctx context.Context, streamName, shardID, sequenceNumber string) error } diff --git a/store.go b/store.go index 4b4a9d4..24ccb50 100644 --- a/store.go +++ b/store.go @@ -1,13 +1,17 @@ package consumer +import ( + "context" +) + // Store interface used to persist scan progress type Store interface { - GetCheckpoint(streamName, shardID string) (string, error) - SetCheckpoint(streamName, shardID, sequenceNumber string) error + GetCheckpoint(ctx context.Context, streamName, shardID string) (string, error) + SetCheckpoint(ctx context.Context, streamName, shardID, sequenceNumber string) error } // noopStore implements the storage interface with discard type noopStore struct{} -func (n noopStore) GetCheckpoint(string, string) (string, error) { return "", nil } -func (n noopStore) SetCheckpoint(string, string, string) error { return nil } +func (n noopStore) GetCheckpoint(context.Context, string, string) (string, error) { return "", nil } +func (n noopStore) SetCheckpoint(context.Context, string, string, string) error { return nil } diff --git a/store/ddb/ddb.go b/store/ddb/ddb.go index d9de968..f6ba67a 100644 --- a/store/ddb/ddb.go +++ b/store/ddb/ddb.go @@ -94,7 +94,7 @@ type item struct { // GetCheckpoint determines if a checkpoint for a particular Shard exists. // Typically used to determine whether we should start processing the shard with // TRIM_HORIZON or AFTER_SEQUENCE_NUMBER (if checkpoint exists). -func (c *Checkpoint) GetCheckpoint(streamName, shardID string) (string, error) { +func (c *Checkpoint) GetCheckpoint(ctx context.Context, streamName, shardID string) (string, error) { namespace := fmt.Sprintf("%s-%s", c.appName, streamName) params := &dynamodb.GetItemInput{ @@ -106,10 +106,10 @@ func (c *Checkpoint) GetCheckpoint(streamName, shardID string) (string, error) { }, } - resp, err := c.client.GetItem(context.Background(), params) + resp, err := c.client.GetItem(ctx, params) if err != nil { if c.retryer.ShouldRetry(err) { - return c.GetCheckpoint(streamName, shardID) + return c.GetCheckpoint(ctx, streamName, shardID) } return "", err } @@ -121,7 +121,7 @@ func (c *Checkpoint) GetCheckpoint(streamName, shardID string) (string, error) { // SetCheckpoint stores a checkpoint for a shard (e.g. sequence number of last record processed by application). // Upon fail over, record processing is resumed from this point. -func (c *Checkpoint) SetCheckpoint(streamName, shardID, sequenceNumber string) error { +func (c *Checkpoint) SetCheckpoint(_ context.Context, streamName, shardID, sequenceNumber string) error { c.mu.Lock() defer c.mu.Unlock() @@ -139,12 +139,13 @@ func (c *Checkpoint) SetCheckpoint(streamName, shardID, sequenceNumber string) e } // Shutdown the checkpoint. Save any in-flight data. -func (c *Checkpoint) Shutdown() error { +func (c *Checkpoint) Shutdown(ctx context.Context) error { c.done <- struct{}{} - return c.save() + return c.save(ctx) } func (c *Checkpoint) loop() { + ctx := context.Background() tick := time.NewTicker(c.maxInterval) defer tick.Stop() defer close(c.done) @@ -152,14 +153,14 @@ func (c *Checkpoint) loop() { for { select { case <-tick.C: - _ = c.save() + _ = c.save(ctx) case <-c.done: return } } } -func (c *Checkpoint) save() error { +func (c *Checkpoint) save(ctx context.Context) error { c.mu.Lock() defer c.mu.Unlock() @@ -175,7 +176,7 @@ func (c *Checkpoint) save() error { } _, err = c.client.PutItem( - context.TODO(), + ctx, &dynamodb.PutItemInput{ TableName: aws.String(c.tableName), Item: item, @@ -184,7 +185,7 @@ func (c *Checkpoint) save() error { if !c.retryer.ShouldRetry(err) { return err } - return c.save() + return c.save(ctx) } } diff --git a/store/memory/store.go b/store/memory/store.go index 22981f6..1e526ed 100644 --- a/store/memory/store.go +++ b/store/memory/store.go @@ -6,6 +6,7 @@ package store import ( + "context" "fmt" "sync" ) @@ -21,7 +22,7 @@ type Store struct { } // SetCheckpoint stores a checkpoint for a shard (e.g. sequence number of last record processed by application). -func (c *Store) SetCheckpoint(streamName, shardID, sequenceNumber string) error { +func (c *Store) SetCheckpoint(_ context.Context, streamName, shardID, sequenceNumber string) error { if sequenceNumber == "" { return fmt.Errorf("sequence number should not be empty") } @@ -32,7 +33,7 @@ func (c *Store) SetCheckpoint(streamName, shardID, sequenceNumber string) error // GetCheckpoint determines if a checkpoint for a particular Shard exists. // Typically, this is used to determine whether processing should start with TRIM_HORIZON or AFTER_SEQUENCE_NUMBER // (if checkpoint exists). -func (c *Store) GetCheckpoint(streamName, shardID string) (string, error) { +func (c *Store) GetCheckpoint(_ context.Context, streamName, shardID string) (string, error) { val, ok := c.Load(streamName + ":" + shardID) if !ok { return "", nil diff --git a/store/memory/store_test.go b/store/memory/store_test.go index 05c006f..3e1fee0 100644 --- a/store/memory/store_test.go +++ b/store/memory/store_test.go @@ -1,17 +1,19 @@ package store import ( + "context" "testing" ) func Test_CheckpointLifecycle(t *testing.T) { c := New() + ctx := context.Background() // set - _ = c.SetCheckpoint("streamName", "shardID", "testSeqNum") + _ = c.SetCheckpoint(ctx, "streamName", "shardID", "testSeqNum") // get - val, err := c.GetCheckpoint("streamName", "shardID") + val, err := c.GetCheckpoint(ctx, "streamName", "shardID") if err != nil { t.Fatalf("get checkpoint error: %v", err) } @@ -22,8 +24,9 @@ func Test_CheckpointLifecycle(t *testing.T) { func Test_SetEmptySeqNum(t *testing.T) { c := New() + ctx := context.Background() - err := c.SetCheckpoint("streamName", "shardID", "") + err := c.SetCheckpoint(ctx, "streamName", "shardID", "") if err == nil || err.Error() != "sequence number should not be empty" { t.Fatalf("should not allow empty sequence number") } diff --git a/store/mysql/mysql.go b/store/mysql/mysql.go index 3652f81..84b3379 100644 --- a/store/mysql/mysql.go +++ b/store/mysql/mysql.go @@ -1,6 +1,7 @@ package mysql import ( + "context" "database/sql" "errors" "fmt" @@ -81,12 +82,12 @@ func (c *Checkpoint) GetMaxInterval() time.Duration { // GetCheckpoint determines if a checkpoint for a particular Shard exists. // Typically used to determine whether we should start processing the shard with // TRIM_HORIZON or AFTER_SEQUENCE_NUMBER (if checkpoint exists). -func (c *Checkpoint) GetCheckpoint(streamName, shardID string) (string, error) { +func (c *Checkpoint) GetCheckpoint(ctx context.Context, streamName, shardID string) (string, error) { namespace := fmt.Sprintf("%s-%s", c.appName, streamName) var sequenceNumber string getCheckpointQuery := fmt.Sprintf(`SELECT sequence_number FROM %s WHERE namespace=? AND shard_id=?;`, c.tableName) // nolint: gas, it replaces only the table name - err := c.conn.QueryRow(getCheckpointQuery, namespace, shardID).Scan(&sequenceNumber) + err := c.conn.QueryRowContext(ctx, getCheckpointQuery, namespace, shardID).Scan(&sequenceNumber) if err != nil { if errors.Is(err, sql.ErrNoRows) { @@ -100,7 +101,7 @@ func (c *Checkpoint) GetCheckpoint(streamName, shardID string) (string, error) { // SetCheckpoint stores a checkpoint for a shard (e.g. sequence number of last record processed by application). // Upon fail over, record processing is resumed from this point. -func (c *Checkpoint) SetCheckpoint(streamName, shardID, sequenceNumber string) error { +func (c *Checkpoint) SetCheckpoint(_ context.Context, streamName, shardID, sequenceNumber string) error { c.mu.Lock() defer c.mu.Unlock() diff --git a/store/mysql/mysql_test.go b/store/mysql/mysql_test.go index 5e7ac62..0986ee3 100644 --- a/store/mysql/mysql_test.go +++ b/store/mysql/mysql_test.go @@ -1,6 +1,7 @@ package mysql import ( + "context" "database/sql" "fmt" "testing" @@ -73,6 +74,7 @@ func TestNew_WithMaxIntervalOption(t *testing.T) { } func TestCheckpoint_GetCheckpoint(t *testing.T) { + ctx := context.Background() appName := "streamConsumer" tableName := "checkpoint" connString := "user:password@/dbname" @@ -98,7 +100,7 @@ func TestCheckpoint_GetCheckpoint(t *testing.T) { tableName) mock.ExpectQuery(expectedSQLRegexString).WithArgs(namespace, shardID).WillReturnRows(expectedRows) - gotSequenceNumber, err := ck.GetCheckpoint(streamName, shardID) + gotSequenceNumber, err := ck.GetCheckpoint(ctx, streamName, shardID) if gotSequenceNumber != expectedSequenceNumber { t.Errorf("expected sequence number equals %v, but got %v", expectedSequenceNumber, gotSequenceNumber) @@ -113,6 +115,7 @@ func TestCheckpoint_GetCheckpoint(t *testing.T) { } func TestCheckpoint_Get_NoRows(t *testing.T) { + ctx := context.Background() appName := "streamConsumer" tableName := "checkpoint" connString := "user:password@/dbname" @@ -134,7 +137,7 @@ func TestCheckpoint_Get_NoRows(t *testing.T) { tableName) mock.ExpectQuery(expectedSQLRegexString).WithArgs(namespace, shardID).WillReturnError(sql.ErrNoRows) - gotSequenceNumber, err := ck.GetCheckpoint(streamName, shardID) + gotSequenceNumber, err := ck.GetCheckpoint(ctx, streamName, shardID) if gotSequenceNumber != "" { t.Errorf("expected sequence number equals empty, but got %v", gotSequenceNumber) @@ -149,6 +152,7 @@ func TestCheckpoint_Get_NoRows(t *testing.T) { } func TestCheckpoint_Get_QueryError(t *testing.T) { + ctx := context.Background() appName := "streamConsumer" tableName := "checkpoint" connString := "user:password@/dbname" @@ -170,7 +174,7 @@ func TestCheckpoint_Get_QueryError(t *testing.T) { tableName) mock.ExpectQuery(expectedSQLRegexString).WithArgs(namespace, shardID).WillReturnError(errors.New("an error")) - gotSequenceNumber, err := ck.GetCheckpoint(streamName, shardID) + gotSequenceNumber, err := ck.GetCheckpoint(ctx, streamName, shardID) if gotSequenceNumber != "" { t.Errorf("expected sequence number equals empty, but got %v", gotSequenceNumber) @@ -185,6 +189,7 @@ func TestCheckpoint_Get_QueryError(t *testing.T) { } func TestCheckpoint_SetCheckpoint(t *testing.T) { + ctx := context.Background() appName := "streamConsumer" tableName := "checkpoint" connString := "user:password@/dbname" @@ -197,7 +202,7 @@ func TestCheckpoint_SetCheckpoint(t *testing.T) { t.Fatalf("error occurred during the checkpoint creation. cause: %v", err) } - err = ck.SetCheckpoint(streamName, shardID, expectedSequenceNumber) + err = ck.SetCheckpoint(ctx, streamName, shardID, expectedSequenceNumber) if err != nil { t.Errorf("expected error equals nil, but got %v", err) @@ -206,6 +211,7 @@ func TestCheckpoint_SetCheckpoint(t *testing.T) { } func TestCheckpoint_Set_SequenceNumberEmpty(t *testing.T) { + ctx := context.Background() appName := "streamConsumer" tableName := "checkpoint" connString := "user:password@/dbname" @@ -218,7 +224,7 @@ func TestCheckpoint_Set_SequenceNumberEmpty(t *testing.T) { t.Fatalf("error occurred during the checkpoint creation. cause: %v", err) } - err = ck.SetCheckpoint(streamName, shardID, expectedSequenceNumber) + err = ck.SetCheckpoint(ctx, streamName, shardID, expectedSequenceNumber) if err == nil { t.Errorf("expected error equals not nil, but got %v", err) @@ -227,6 +233,7 @@ func TestCheckpoint_Set_SequenceNumberEmpty(t *testing.T) { } func TestCheckpoint_Shutdown(t *testing.T) { + ctx := context.Background() appName := "streamConsumer" tableName := "checkpoint" connString := "user:password@/dbname" @@ -249,7 +256,7 @@ func TestCheckpoint_Shutdown(t *testing.T) { result := sqlmock.NewResult(0, 1) mock.ExpectExec(expectedSQLRegexString).WithArgs(namespace, shardID, expectedSequenceNumber).WillReturnResult(result) - err = ck.SetCheckpoint(streamName, shardID, expectedSequenceNumber) + err = ck.SetCheckpoint(ctx, streamName, shardID, expectedSequenceNumber) if err != nil { t.Fatalf("unable to set checkpoint for data initialization. cause: %v", err) @@ -266,6 +273,7 @@ func TestCheckpoint_Shutdown(t *testing.T) { } func TestCheckpoint_Shutdown_SaveError(t *testing.T) { + ctx := context.Background() appName := "streamConsumer" tableName := "checkpoint" connString := "user:password@/dbname" @@ -287,7 +295,7 @@ func TestCheckpoint_Shutdown_SaveError(t *testing.T) { expectedSQLRegexString := fmt.Sprintf(`REPLACE INTO %s \(namespace, shard_id, sequence_number\) VALUES \(\?, \?, \?\)`, tableName) mock.ExpectExec(expectedSQLRegexString).WithArgs(namespace, shardID, expectedSequenceNumber).WillReturnError(errors.New("an error")) - err = ck.SetCheckpoint(streamName, shardID, expectedSequenceNumber) + err = ck.SetCheckpoint(ctx, streamName, shardID, expectedSequenceNumber) if err != nil { t.Fatalf("unable to set checkpoint for data initialization. cause: %v", err) diff --git a/store/postgres/postgres.go b/store/postgres/postgres.go index 909e09d..8e5fd3e 100644 --- a/store/postgres/postgres.go +++ b/store/postgres/postgres.go @@ -1,6 +1,7 @@ package postgres import ( + "context" "database/sql" "errors" "fmt" @@ -88,12 +89,12 @@ func (c *Checkpoint) GetMaxInterval() time.Duration { // GetCheckpoint determines if a checkpoint for a particular Shard exists. // Typically used to determine whether we should start processing the shard with // TRIM_HORIZON or AFTER_SEQUENCE_NUMBER (if checkpoint exists). -func (c *Checkpoint) GetCheckpoint(streamName, shardID string) (string, error) { +func (c *Checkpoint) GetCheckpoint(ctx context.Context, streamName, shardID string) (string, error) { namespace := fmt.Sprintf("%s-%s", c.appName, streamName) var sequenceNumber string getCheckpointQuery := fmt.Sprintf(`SELECT sequence_number FROM %s WHERE namespace=$1 AND shard_id=$2;`, c.tableName) // nolint: gas, it replaces only the table name - err := c.conn.QueryRow(getCheckpointQuery, namespace, shardID).Scan(&sequenceNumber) + err := c.conn.QueryRowContext(ctx, getCheckpointQuery, namespace, shardID).Scan(&sequenceNumber) if err != nil { if errors.Is(err, sql.ErrNoRows) { @@ -107,7 +108,7 @@ func (c *Checkpoint) GetCheckpoint(streamName, shardID string) (string, error) { // SetCheckpoint stores a checkpoint for a shard (e.g. sequence number of last record processed by application). // Upon fail over, record processing is resumed from this point. -func (c *Checkpoint) SetCheckpoint(streamName, shardID, sequenceNumber string) error { +func (c *Checkpoint) SetCheckpoint(_ context.Context, streamName, shardID, sequenceNumber string) error { c.mu.Lock() defer c.mu.Unlock() @@ -126,15 +127,16 @@ func (c *Checkpoint) SetCheckpoint(streamName, shardID, sequenceNumber string) e } // Shutdown the checkpoint. Save any in-flight data. -func (c *Checkpoint) Shutdown() error { +func (c *Checkpoint) Shutdown(ctx context.Context) error { defer c.conn.Close() c.done <- struct{}{} - return c.save() + return c.save(ctx) } func (c *Checkpoint) loop() { + ctx := context.Background() tick := time.NewTicker(c.maxInterval) defer tick.Stop() defer close(c.done) @@ -142,14 +144,14 @@ func (c *Checkpoint) loop() { for { select { case <-tick.C: - _ = c.save() + _ = c.save(ctx) case <-c.done: return } } } -func (c *Checkpoint) save() error { +func (c *Checkpoint) save(ctx context.Context) error { c.mu.Lock() defer c.mu.Unlock() @@ -159,10 +161,10 @@ func (c *Checkpoint) save() error { ON CONFLICT (namespace, shard_id) DO UPDATE - SET sequence_number= $3;`, c.tableName) + SET sequence_number=$3;`, c.tableName) for key, sequenceNumber := range c.checkpoints { - if _, err := c.conn.Exec(upsertCheckpoint, fmt.Sprintf("%s-%s", c.appName, key.streamName), key.shardID, sequenceNumber); err != nil { + if _, err := c.conn.ExecContext(ctx, upsertCheckpoint, fmt.Sprintf("%s-%s", c.appName, key.streamName), key.shardID, sequenceNumber); err != nil { return err } } diff --git a/store/postgres/postgres_test.go b/store/postgres/postgres_test.go index b5c44d5..0cc3353 100644 --- a/store/postgres/postgres_test.go +++ b/store/postgres/postgres_test.go @@ -3,6 +3,7 @@ package postgres import ( + "context" "database/sql" "fmt" "testing" @@ -13,6 +14,7 @@ import ( ) func TestNew(t *testing.T) { + ctx := context.Background() appName := "streamConsumer" tableName := "checkpoint" connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;" @@ -24,7 +26,7 @@ func TestNew(t *testing.T) { if err != nil { t.Errorf("expected error equals nil, but got %v", err) } - _ = ck.Shutdown() + _ = ck.Shutdown(ctx) } func TestNew_AppNameEmpty(t *testing.T) { @@ -56,6 +58,7 @@ func TestNew_TableNameEmpty(t *testing.T) { } func TestNew_WithMaxIntervalOption(t *testing.T) { + ctx := context.Background() appName := "streamConsumer" tableName := "checkpoint" connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;" @@ -71,10 +74,11 @@ func TestNew_WithMaxIntervalOption(t *testing.T) { if err != nil { t.Errorf("expected error equals nil, but got %v", err) } - _ = ck.Shutdown() + _ = ck.Shutdown(ctx) } func TestCheckpoint_GetCheckpoint(t *testing.T) { + ctx := context.Background() appName := "streamConsumer" tableName := "checkpoint" connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;" @@ -100,7 +104,7 @@ func TestCheckpoint_GetCheckpoint(t *testing.T) { tableName) mock.ExpectQuery(expectedSQLRegexString).WithArgs(namespace, shardID).WillReturnRows(expectedRows) - gotSequenceNumber, err := ck.GetCheckpoint(streamName, shardID) + gotSequenceNumber, err := ck.GetCheckpoint(ctx, streamName, shardID) if gotSequenceNumber != expectedSequenceNumber { t.Errorf("expected sequence number equals %v, but got %v", expectedSequenceNumber, gotSequenceNumber) @@ -111,10 +115,11 @@ func TestCheckpoint_GetCheckpoint(t *testing.T) { if err := mock.ExpectationsWereMet(); err != nil { t.Errorf("there were unfulfilled expectations: %s", err) } - _ = ck.Shutdown() + _ = ck.Shutdown(ctx) } func TestCheckpoint_Get_NoRows(t *testing.T) { + ctx := context.Background() appName := "streamConsumer" tableName := "checkpoint" connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;" @@ -136,7 +141,7 @@ func TestCheckpoint_Get_NoRows(t *testing.T) { tableName) mock.ExpectQuery(expectedSQLRegexString).WithArgs(namespace, shardID).WillReturnError(sql.ErrNoRows) - gotSequenceNumber, err := ck.GetCheckpoint(streamName, shardID) + gotSequenceNumber, err := ck.GetCheckpoint(ctx, streamName, shardID) if gotSequenceNumber != "" { t.Errorf("expected sequence number equals empty, but got %v", gotSequenceNumber) @@ -151,6 +156,7 @@ func TestCheckpoint_Get_NoRows(t *testing.T) { } func TestCheckpoint_Get_QueryError(t *testing.T) { + ctx := context.Background() appName := "streamConsumer" tableName := "checkpoint" connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;" @@ -172,7 +178,7 @@ func TestCheckpoint_Get_QueryError(t *testing.T) { tableName) mock.ExpectQuery(expectedSQLRegexString).WithArgs(namespace, shardID).WillReturnError(errors.New("an error")) - gotSequenceNumber, err := ck.GetCheckpoint(streamName, shardID) + gotSequenceNumber, err := ck.GetCheckpoint(ctx, streamName, shardID) if gotSequenceNumber != "" { t.Errorf("expected sequence number equals empty, but got %v", gotSequenceNumber) @@ -183,10 +189,11 @@ func TestCheckpoint_Get_QueryError(t *testing.T) { if err := mock.ExpectationsWereMet(); err != nil { t.Errorf("there were unfulfilled expectations: %s", err) } - _ = ck.Shutdown() + _ = ck.Shutdown(ctx) } func TestCheckpoint_SetCheckpoint(t *testing.T) { + ctx := context.Background() appName := "streamConsumer" tableName := "checkpoint" connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;" @@ -199,15 +206,16 @@ func TestCheckpoint_SetCheckpoint(t *testing.T) { t.Fatalf("error occurred during the checkpoint creation. cause: %v", err) } - err = ck.SetCheckpoint(streamName, shardID, expectedSequenceNumber) + err = ck.SetCheckpoint(ctx, streamName, shardID, expectedSequenceNumber) if err != nil { t.Errorf("expected error equals nil, but got %v", err) } - _ = ck.Shutdown() + _ = ck.Shutdown(ctx) } func TestCheckpoint_Set_SequenceNumberEmpty(t *testing.T) { + ctx := context.Background() appName := "streamConsumer" tableName := "checkpoint" connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;" @@ -220,15 +228,16 @@ func TestCheckpoint_Set_SequenceNumberEmpty(t *testing.T) { t.Fatalf("error occurred during the checkpoint creation. cause: %v", err) } - err = ck.SetCheckpoint(streamName, shardID, expectedSequenceNumber) + err = ck.SetCheckpoint(ctx, streamName, shardID, expectedSequenceNumber) if err == nil { t.Errorf("expected error equals not nil, but got %v", err) } - _ = ck.Shutdown() + _ = ck.Shutdown(ctx) } func TestCheckpoint_Shutdown(t *testing.T) { + ctx := context.Background() appName := "streamConsumer" tableName := "checkpoint" connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;" @@ -251,13 +260,13 @@ func TestCheckpoint_Shutdown(t *testing.T) { result := sqlmock.NewResult(0, 1) mock.ExpectExec(expectedSQLRegexString).WithArgs(namespace, shardID, expectedSequenceNumber).WillReturnResult(result) - err = ck.SetCheckpoint(streamName, shardID, expectedSequenceNumber) + err = ck.SetCheckpoint(ctx, streamName, shardID, expectedSequenceNumber) if err != nil { t.Fatalf("unable to set checkpoint for data initialization. cause: %v", err) } - err = ck.Shutdown() + err = ck.Shutdown(ctx) if err != nil { t.Errorf("expected error equals not nil, but got %v", err) @@ -268,6 +277,7 @@ func TestCheckpoint_Shutdown(t *testing.T) { } func TestCheckpoint_Shutdown_SaveError(t *testing.T) { + ctx := context.Background() appName := "streamConsumer" tableName := "checkpoint" connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;" @@ -289,13 +299,13 @@ func TestCheckpoint_Shutdown_SaveError(t *testing.T) { expectedSQLRegexString := fmt.Sprintf(`INSERT INTO %s \(namespace, shard_id, sequence_number\) VALUES\(\$1, \$2, \$3\) ON CONFLICT \(namespace, shard_id\) DO UPDATE SET sequence_number= \$3;`, tableName) mock.ExpectExec(expectedSQLRegexString).WithArgs(namespace, shardID, expectedSequenceNumber).WillReturnError(errors.New("an error")) - err = ck.SetCheckpoint(streamName, shardID, expectedSequenceNumber) + err = ck.SetCheckpoint(ctx, streamName, shardID, expectedSequenceNumber) if err != nil { t.Fatalf("unable to set checkpoint for data initialization. cause: %v", err) } - err = ck.Shutdown() + err = ck.Shutdown(ctx) if err == nil { t.Errorf("expected error equals nil, but got %v", err) diff --git a/store/redis/redis.go b/store/redis/redis.go index 463545a..27bf3f1 100644 --- a/store/redis/redis.go +++ b/store/redis/redis.go @@ -52,19 +52,17 @@ type Checkpoint struct { } // GetCheckpoint fetches the checkpoint for a particular Shard. -func (c *Checkpoint) GetCheckpoint(streamName, shardID string) (string, error) { - ctx := context.Background() +func (c *Checkpoint) GetCheckpoint(ctx context.Context, streamName, shardID string) (string, error) { val, _ := c.client.Get(ctx, c.key(streamName, shardID)).Result() return val, nil } // SetCheckpoint stores a checkpoint for a shard (e.g. sequence number of last record processed by application). // Upon fail over, record processing is resumed from this point. -func (c *Checkpoint) SetCheckpoint(streamName, shardID, sequenceNumber string) error { +func (c *Checkpoint) SetCheckpoint(ctx context.Context, streamName, shardID, sequenceNumber string) error { if sequenceNumber == "" { return fmt.Errorf("sequence number should not be empty") } - ctx := context.Background() err := c.client.Set(ctx, c.key(streamName, shardID), sequenceNumber, 0).Err() if err != nil { return err diff --git a/store/redis/redis_test.go b/store/redis/redis_test.go index c8bb051..da333c3 100644 --- a/store/redis/redis_test.go +++ b/store/redis/redis_test.go @@ -3,6 +3,7 @@ package redis import ( + "context" "testing" "github.com/alicebob/miniredis" @@ -34,10 +35,10 @@ func Test_CheckpointLifecycle(t *testing.T) { } // set - _ = c.SetCheckpoint("streamName", "shardID", "testSeqNum") + _ = c.SetCheckpoint(context.Background(), "streamName", "shardID", "testSeqNum") // get - val, err := c.GetCheckpoint("streamName", "shardID") + val, err := c.GetCheckpoint(context.Background(), "streamName", "shardID") if err != nil { t.Fatalf("get checkpoint error: %v", err) } @@ -52,7 +53,7 @@ func Test_SetEmptySeqNum(t *testing.T) { t.Fatalf("new checkpoint error: %v", err) } - err = c.SetCheckpoint("streamName", "shardID", "") + err = c.SetCheckpoint(context.Background(), "streamName", "shardID", "") if err == nil { t.Fatalf("should not allow empty sequence number") } diff --git a/worker.go b/worker.go new file mode 100644 index 0000000..b78b46c --- /dev/null +++ b/worker.go @@ -0,0 +1 @@ +package consumer