#23 threads context through to stores
This commit is contained in:
parent
8c86963bd7
commit
8a4cb38940
15 changed files with 98 additions and 87 deletions
|
|
@ -155,7 +155,7 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
|
|||
// for each record and checkpoints the progress of scan.
|
||||
func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error {
|
||||
// get last seq number from checkpoint
|
||||
lastSeqNum, err := c.group.GetCheckpoint(c.streamName, shardID)
|
||||
lastSeqNum, err := c.group.GetCheckpoint(ctx, c.streamName, shardID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get checkpoint error: %w", err)
|
||||
}
|
||||
|
|
@ -223,7 +223,7 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
|
|||
}
|
||||
|
||||
if !errors.Is(err, ErrSkipCheckpoint) {
|
||||
if err := c.group.SetCheckpoint(c.streamName, shardID, *r.SequenceNumber); err != nil {
|
||||
if err := c.group.SetCheckpoint(ctx, c.streamName, shardID, *r.SequenceNumber); err != nil {
|
||||
return err
|
||||
}
|
||||
c.counter.Add("checkpoint", 1)
|
||||
|
|
|
|||
|
|
@ -7,14 +7,13 @@ import (
|
|||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"time"
|
||||
|
||||
alog "github.com/apex/log"
|
||||
"github.com/apex/log/handlers/text"
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||
"github.com/aws/aws-sdk-go-v2/service/dynamodb"
|
||||
|
|
@ -38,25 +37,7 @@ func init() {
|
|||
}()
|
||||
}
|
||||
|
||||
// A myLogger provides a minimalistic logger satisfying the Logger interface.
|
||||
type myLogger struct {
|
||||
logger alog.Logger
|
||||
}
|
||||
|
||||
// Log logs the parameters to the stdlib logger. See log.Println.
|
||||
func (l *myLogger) Log(args ...interface{}) {
|
||||
l.logger.Infof("producer: %v", args...)
|
||||
}
|
||||
|
||||
func main() {
|
||||
// Wrap myLogger around apex logger
|
||||
myLog := &myLogger{
|
||||
logger: alog.Logger{
|
||||
Handler: text.New(os.Stdout),
|
||||
Level: alog.DebugLevel,
|
||||
},
|
||||
}
|
||||
|
||||
var (
|
||||
app = flag.String("app", "", "Consumer app name")
|
||||
stream = flag.String("stream", "", "Stream name")
|
||||
|
|
@ -100,7 +81,7 @@ func main() {
|
|||
c, err := consumer.New(
|
||||
*stream,
|
||||
consumer.WithStore(ddb),
|
||||
consumer.WithLogger(myLog),
|
||||
consumer.WithLogger(slog.Default()),
|
||||
consumer.WithCounter(counter),
|
||||
consumer.WithClient(client),
|
||||
)
|
||||
|
|
@ -129,7 +110,7 @@ func main() {
|
|||
log.Fatalf("scan error: %v", err)
|
||||
}
|
||||
|
||||
if err := ddb.Shutdown(); err != nil {
|
||||
if err := ddb.Shutdown(ctx); err != nil {
|
||||
log.Fatalf("storage shutdown error: %v", err)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -85,7 +85,7 @@ func main() {
|
|||
os.Exit(1)
|
||||
}
|
||||
|
||||
if err := checkpointStore.Shutdown(); err != nil {
|
||||
if err := checkpointStore.Shutdown(context.Background()); err != nil {
|
||||
slog.Error("store shutdown error", slog.String("error", err.Error()))
|
||||
os.Exit(1)
|
||||
}
|
||||
|
|
|
|||
4
group.go
4
group.go
|
|
@ -9,6 +9,6 @@ import (
|
|||
// Group interface used to manage which shard to process
|
||||
type Group interface {
|
||||
Start(ctx context.Context, shardc chan types.Shard)
|
||||
GetCheckpoint(streamName, shardID string) (string, error)
|
||||
SetCheckpoint(streamName, shardID, sequenceNumber string) error
|
||||
GetCheckpoint(ctx context.Context, streamName, shardID string) (string, error)
|
||||
SetCheckpoint(ctx context.Context, streamName, shardID, sequenceNumber string) error
|
||||
}
|
||||
|
|
|
|||
12
store.go
12
store.go
|
|
@ -1,13 +1,17 @@
|
|||
package consumer
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
// Store interface used to persist scan progress
|
||||
type Store interface {
|
||||
GetCheckpoint(streamName, shardID string) (string, error)
|
||||
SetCheckpoint(streamName, shardID, sequenceNumber string) error
|
||||
GetCheckpoint(ctx context.Context, streamName, shardID string) (string, error)
|
||||
SetCheckpoint(ctx context.Context, streamName, shardID, sequenceNumber string) error
|
||||
}
|
||||
|
||||
// noopStore implements the storage interface with discard
|
||||
type noopStore struct{}
|
||||
|
||||
func (n noopStore) GetCheckpoint(string, string) (string, error) { return "", nil }
|
||||
func (n noopStore) SetCheckpoint(string, string, string) error { return nil }
|
||||
func (n noopStore) GetCheckpoint(context.Context, string, string) (string, error) { return "", nil }
|
||||
func (n noopStore) SetCheckpoint(context.Context, string, string, string) error { return nil }
|
||||
|
|
|
|||
|
|
@ -94,7 +94,7 @@ type item struct {
|
|||
// GetCheckpoint determines if a checkpoint for a particular Shard exists.
|
||||
// Typically used to determine whether we should start processing the shard with
|
||||
// TRIM_HORIZON or AFTER_SEQUENCE_NUMBER (if checkpoint exists).
|
||||
func (c *Checkpoint) GetCheckpoint(streamName, shardID string) (string, error) {
|
||||
func (c *Checkpoint) GetCheckpoint(ctx context.Context, streamName, shardID string) (string, error) {
|
||||
namespace := fmt.Sprintf("%s-%s", c.appName, streamName)
|
||||
|
||||
params := &dynamodb.GetItemInput{
|
||||
|
|
@ -106,10 +106,10 @@ func (c *Checkpoint) GetCheckpoint(streamName, shardID string) (string, error) {
|
|||
},
|
||||
}
|
||||
|
||||
resp, err := c.client.GetItem(context.Background(), params)
|
||||
resp, err := c.client.GetItem(ctx, params)
|
||||
if err != nil {
|
||||
if c.retryer.ShouldRetry(err) {
|
||||
return c.GetCheckpoint(streamName, shardID)
|
||||
return c.GetCheckpoint(ctx, streamName, shardID)
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
|
|
@ -121,7 +121,7 @@ func (c *Checkpoint) GetCheckpoint(streamName, shardID string) (string, error) {
|
|||
|
||||
// SetCheckpoint stores a checkpoint for a shard (e.g. sequence number of last record processed by application).
|
||||
// Upon fail over, record processing is resumed from this point.
|
||||
func (c *Checkpoint) SetCheckpoint(streamName, shardID, sequenceNumber string) error {
|
||||
func (c *Checkpoint) SetCheckpoint(_ context.Context, streamName, shardID, sequenceNumber string) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
|
|
@ -139,12 +139,13 @@ func (c *Checkpoint) SetCheckpoint(streamName, shardID, sequenceNumber string) e
|
|||
}
|
||||
|
||||
// Shutdown the checkpoint. Save any in-flight data.
|
||||
func (c *Checkpoint) Shutdown() error {
|
||||
func (c *Checkpoint) Shutdown(ctx context.Context) error {
|
||||
c.done <- struct{}{}
|
||||
return c.save()
|
||||
return c.save(ctx)
|
||||
}
|
||||
|
||||
func (c *Checkpoint) loop() {
|
||||
ctx := context.Background()
|
||||
tick := time.NewTicker(c.maxInterval)
|
||||
defer tick.Stop()
|
||||
defer close(c.done)
|
||||
|
|
@ -152,14 +153,14 @@ func (c *Checkpoint) loop() {
|
|||
for {
|
||||
select {
|
||||
case <-tick.C:
|
||||
_ = c.save()
|
||||
_ = c.save(ctx)
|
||||
case <-c.done:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Checkpoint) save() error {
|
||||
func (c *Checkpoint) save(ctx context.Context) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
|
|
@ -175,7 +176,7 @@ func (c *Checkpoint) save() error {
|
|||
}
|
||||
|
||||
_, err = c.client.PutItem(
|
||||
context.TODO(),
|
||||
ctx,
|
||||
&dynamodb.PutItemInput{
|
||||
TableName: aws.String(c.tableName),
|
||||
Item: item,
|
||||
|
|
@ -184,7 +185,7 @@ func (c *Checkpoint) save() error {
|
|||
if !c.retryer.ShouldRetry(err) {
|
||||
return err
|
||||
}
|
||||
return c.save()
|
||||
return c.save(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
|
@ -21,7 +22,7 @@ type Store struct {
|
|||
}
|
||||
|
||||
// SetCheckpoint stores a checkpoint for a shard (e.g. sequence number of last record processed by application).
|
||||
func (c *Store) SetCheckpoint(streamName, shardID, sequenceNumber string) error {
|
||||
func (c *Store) SetCheckpoint(_ context.Context, streamName, shardID, sequenceNumber string) error {
|
||||
if sequenceNumber == "" {
|
||||
return fmt.Errorf("sequence number should not be empty")
|
||||
}
|
||||
|
|
@ -32,7 +33,7 @@ func (c *Store) SetCheckpoint(streamName, shardID, sequenceNumber string) error
|
|||
// GetCheckpoint determines if a checkpoint for a particular Shard exists.
|
||||
// Typically, this is used to determine whether processing should start with TRIM_HORIZON or AFTER_SEQUENCE_NUMBER
|
||||
// (if checkpoint exists).
|
||||
func (c *Store) GetCheckpoint(streamName, shardID string) (string, error) {
|
||||
func (c *Store) GetCheckpoint(_ context.Context, streamName, shardID string) (string, error) {
|
||||
val, ok := c.Load(streamName + ":" + shardID)
|
||||
if !ok {
|
||||
return "", nil
|
||||
|
|
|
|||
|
|
@ -1,17 +1,19 @@
|
|||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_CheckpointLifecycle(t *testing.T) {
|
||||
c := New()
|
||||
ctx := context.Background()
|
||||
|
||||
// set
|
||||
_ = c.SetCheckpoint("streamName", "shardID", "testSeqNum")
|
||||
_ = c.SetCheckpoint(ctx, "streamName", "shardID", "testSeqNum")
|
||||
|
||||
// get
|
||||
val, err := c.GetCheckpoint("streamName", "shardID")
|
||||
val, err := c.GetCheckpoint(ctx, "streamName", "shardID")
|
||||
if err != nil {
|
||||
t.Fatalf("get checkpoint error: %v", err)
|
||||
}
|
||||
|
|
@ -22,8 +24,9 @@ func Test_CheckpointLifecycle(t *testing.T) {
|
|||
|
||||
func Test_SetEmptySeqNum(t *testing.T) {
|
||||
c := New()
|
||||
ctx := context.Background()
|
||||
|
||||
err := c.SetCheckpoint("streamName", "shardID", "")
|
||||
err := c.SetCheckpoint(ctx, "streamName", "shardID", "")
|
||||
if err == nil || err.Error() != "sequence number should not be empty" {
|
||||
t.Fatalf("should not allow empty sequence number")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
|
@ -81,12 +82,12 @@ func (c *Checkpoint) GetMaxInterval() time.Duration {
|
|||
// GetCheckpoint determines if a checkpoint for a particular Shard exists.
|
||||
// Typically used to determine whether we should start processing the shard with
|
||||
// TRIM_HORIZON or AFTER_SEQUENCE_NUMBER (if checkpoint exists).
|
||||
func (c *Checkpoint) GetCheckpoint(streamName, shardID string) (string, error) {
|
||||
func (c *Checkpoint) GetCheckpoint(ctx context.Context, streamName, shardID string) (string, error) {
|
||||
namespace := fmt.Sprintf("%s-%s", c.appName, streamName)
|
||||
|
||||
var sequenceNumber string
|
||||
getCheckpointQuery := fmt.Sprintf(`SELECT sequence_number FROM %s WHERE namespace=? AND shard_id=?;`, c.tableName) // nolint: gas, it replaces only the table name
|
||||
err := c.conn.QueryRow(getCheckpointQuery, namespace, shardID).Scan(&sequenceNumber)
|
||||
err := c.conn.QueryRowContext(ctx, getCheckpointQuery, namespace, shardID).Scan(&sequenceNumber)
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
|
|
@ -100,7 +101,7 @@ func (c *Checkpoint) GetCheckpoint(streamName, shardID string) (string, error) {
|
|||
|
||||
// SetCheckpoint stores a checkpoint for a shard (e.g. sequence number of last record processed by application).
|
||||
// Upon fail over, record processing is resumed from this point.
|
||||
func (c *Checkpoint) SetCheckpoint(streamName, shardID, sequenceNumber string) error {
|
||||
func (c *Checkpoint) SetCheckpoint(_ context.Context, streamName, shardID, sequenceNumber string) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
|
@ -73,6 +74,7 @@ func TestNew_WithMaxIntervalOption(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestCheckpoint_GetCheckpoint(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
appName := "streamConsumer"
|
||||
tableName := "checkpoint"
|
||||
connString := "user:password@/dbname"
|
||||
|
|
@ -98,7 +100,7 @@ func TestCheckpoint_GetCheckpoint(t *testing.T) {
|
|||
tableName)
|
||||
mock.ExpectQuery(expectedSQLRegexString).WithArgs(namespace, shardID).WillReturnRows(expectedRows)
|
||||
|
||||
gotSequenceNumber, err := ck.GetCheckpoint(streamName, shardID)
|
||||
gotSequenceNumber, err := ck.GetCheckpoint(ctx, streamName, shardID)
|
||||
|
||||
if gotSequenceNumber != expectedSequenceNumber {
|
||||
t.Errorf("expected sequence number equals %v, but got %v", expectedSequenceNumber, gotSequenceNumber)
|
||||
|
|
@ -113,6 +115,7 @@ func TestCheckpoint_GetCheckpoint(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestCheckpoint_Get_NoRows(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
appName := "streamConsumer"
|
||||
tableName := "checkpoint"
|
||||
connString := "user:password@/dbname"
|
||||
|
|
@ -134,7 +137,7 @@ func TestCheckpoint_Get_NoRows(t *testing.T) {
|
|||
tableName)
|
||||
mock.ExpectQuery(expectedSQLRegexString).WithArgs(namespace, shardID).WillReturnError(sql.ErrNoRows)
|
||||
|
||||
gotSequenceNumber, err := ck.GetCheckpoint(streamName, shardID)
|
||||
gotSequenceNumber, err := ck.GetCheckpoint(ctx, streamName, shardID)
|
||||
|
||||
if gotSequenceNumber != "" {
|
||||
t.Errorf("expected sequence number equals empty, but got %v", gotSequenceNumber)
|
||||
|
|
@ -149,6 +152,7 @@ func TestCheckpoint_Get_NoRows(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestCheckpoint_Get_QueryError(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
appName := "streamConsumer"
|
||||
tableName := "checkpoint"
|
||||
connString := "user:password@/dbname"
|
||||
|
|
@ -170,7 +174,7 @@ func TestCheckpoint_Get_QueryError(t *testing.T) {
|
|||
tableName)
|
||||
mock.ExpectQuery(expectedSQLRegexString).WithArgs(namespace, shardID).WillReturnError(errors.New("an error"))
|
||||
|
||||
gotSequenceNumber, err := ck.GetCheckpoint(streamName, shardID)
|
||||
gotSequenceNumber, err := ck.GetCheckpoint(ctx, streamName, shardID)
|
||||
|
||||
if gotSequenceNumber != "" {
|
||||
t.Errorf("expected sequence number equals empty, but got %v", gotSequenceNumber)
|
||||
|
|
@ -185,6 +189,7 @@ func TestCheckpoint_Get_QueryError(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestCheckpoint_SetCheckpoint(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
appName := "streamConsumer"
|
||||
tableName := "checkpoint"
|
||||
connString := "user:password@/dbname"
|
||||
|
|
@ -197,7 +202,7 @@ func TestCheckpoint_SetCheckpoint(t *testing.T) {
|
|||
t.Fatalf("error occurred during the checkpoint creation. cause: %v", err)
|
||||
}
|
||||
|
||||
err = ck.SetCheckpoint(streamName, shardID, expectedSequenceNumber)
|
||||
err = ck.SetCheckpoint(ctx, streamName, shardID, expectedSequenceNumber)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("expected error equals nil, but got %v", err)
|
||||
|
|
@ -206,6 +211,7 @@ func TestCheckpoint_SetCheckpoint(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestCheckpoint_Set_SequenceNumberEmpty(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
appName := "streamConsumer"
|
||||
tableName := "checkpoint"
|
||||
connString := "user:password@/dbname"
|
||||
|
|
@ -218,7 +224,7 @@ func TestCheckpoint_Set_SequenceNumberEmpty(t *testing.T) {
|
|||
t.Fatalf("error occurred during the checkpoint creation. cause: %v", err)
|
||||
}
|
||||
|
||||
err = ck.SetCheckpoint(streamName, shardID, expectedSequenceNumber)
|
||||
err = ck.SetCheckpoint(ctx, streamName, shardID, expectedSequenceNumber)
|
||||
|
||||
if err == nil {
|
||||
t.Errorf("expected error equals not nil, but got %v", err)
|
||||
|
|
@ -227,6 +233,7 @@ func TestCheckpoint_Set_SequenceNumberEmpty(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestCheckpoint_Shutdown(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
appName := "streamConsumer"
|
||||
tableName := "checkpoint"
|
||||
connString := "user:password@/dbname"
|
||||
|
|
@ -249,7 +256,7 @@ func TestCheckpoint_Shutdown(t *testing.T) {
|
|||
result := sqlmock.NewResult(0, 1)
|
||||
mock.ExpectExec(expectedSQLRegexString).WithArgs(namespace, shardID, expectedSequenceNumber).WillReturnResult(result)
|
||||
|
||||
err = ck.SetCheckpoint(streamName, shardID, expectedSequenceNumber)
|
||||
err = ck.SetCheckpoint(ctx, streamName, shardID, expectedSequenceNumber)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unable to set checkpoint for data initialization. cause: %v", err)
|
||||
|
|
@ -266,6 +273,7 @@ func TestCheckpoint_Shutdown(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestCheckpoint_Shutdown_SaveError(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
appName := "streamConsumer"
|
||||
tableName := "checkpoint"
|
||||
connString := "user:password@/dbname"
|
||||
|
|
@ -287,7 +295,7 @@ func TestCheckpoint_Shutdown_SaveError(t *testing.T) {
|
|||
expectedSQLRegexString := fmt.Sprintf(`REPLACE INTO %s \(namespace, shard_id, sequence_number\) VALUES \(\?, \?, \?\)`, tableName)
|
||||
mock.ExpectExec(expectedSQLRegexString).WithArgs(namespace, shardID, expectedSequenceNumber).WillReturnError(errors.New("an error"))
|
||||
|
||||
err = ck.SetCheckpoint(streamName, shardID, expectedSequenceNumber)
|
||||
err = ck.SetCheckpoint(ctx, streamName, shardID, expectedSequenceNumber)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unable to set checkpoint for data initialization. cause: %v", err)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
|
@ -88,12 +89,12 @@ func (c *Checkpoint) GetMaxInterval() time.Duration {
|
|||
// GetCheckpoint determines if a checkpoint for a particular Shard exists.
|
||||
// Typically used to determine whether we should start processing the shard with
|
||||
// TRIM_HORIZON or AFTER_SEQUENCE_NUMBER (if checkpoint exists).
|
||||
func (c *Checkpoint) GetCheckpoint(streamName, shardID string) (string, error) {
|
||||
func (c *Checkpoint) GetCheckpoint(ctx context.Context, streamName, shardID string) (string, error) {
|
||||
namespace := fmt.Sprintf("%s-%s", c.appName, streamName)
|
||||
|
||||
var sequenceNumber string
|
||||
getCheckpointQuery := fmt.Sprintf(`SELECT sequence_number FROM %s WHERE namespace=$1 AND shard_id=$2;`, c.tableName) // nolint: gas, it replaces only the table name
|
||||
err := c.conn.QueryRow(getCheckpointQuery, namespace, shardID).Scan(&sequenceNumber)
|
||||
err := c.conn.QueryRowContext(ctx, getCheckpointQuery, namespace, shardID).Scan(&sequenceNumber)
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
|
|
@ -107,7 +108,7 @@ func (c *Checkpoint) GetCheckpoint(streamName, shardID string) (string, error) {
|
|||
|
||||
// SetCheckpoint stores a checkpoint for a shard (e.g. sequence number of last record processed by application).
|
||||
// Upon fail over, record processing is resumed from this point.
|
||||
func (c *Checkpoint) SetCheckpoint(streamName, shardID, sequenceNumber string) error {
|
||||
func (c *Checkpoint) SetCheckpoint(_ context.Context, streamName, shardID, sequenceNumber string) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
|
|
@ -126,15 +127,16 @@ func (c *Checkpoint) SetCheckpoint(streamName, shardID, sequenceNumber string) e
|
|||
}
|
||||
|
||||
// Shutdown the checkpoint. Save any in-flight data.
|
||||
func (c *Checkpoint) Shutdown() error {
|
||||
func (c *Checkpoint) Shutdown(ctx context.Context) error {
|
||||
defer c.conn.Close()
|
||||
|
||||
c.done <- struct{}{}
|
||||
|
||||
return c.save()
|
||||
return c.save(ctx)
|
||||
}
|
||||
|
||||
func (c *Checkpoint) loop() {
|
||||
ctx := context.Background()
|
||||
tick := time.NewTicker(c.maxInterval)
|
||||
defer tick.Stop()
|
||||
defer close(c.done)
|
||||
|
|
@ -142,14 +144,14 @@ func (c *Checkpoint) loop() {
|
|||
for {
|
||||
select {
|
||||
case <-tick.C:
|
||||
_ = c.save()
|
||||
_ = c.save(ctx)
|
||||
case <-c.done:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Checkpoint) save() error {
|
||||
func (c *Checkpoint) save(ctx context.Context) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
|
|
@ -159,10 +161,10 @@ func (c *Checkpoint) save() error {
|
|||
ON CONFLICT (namespace, shard_id)
|
||||
DO
|
||||
UPDATE
|
||||
SET sequence_number= $3;`, c.tableName)
|
||||
SET sequence_number=$3;`, c.tableName)
|
||||
|
||||
for key, sequenceNumber := range c.checkpoints {
|
||||
if _, err := c.conn.Exec(upsertCheckpoint, fmt.Sprintf("%s-%s", c.appName, key.streamName), key.shardID, sequenceNumber); err != nil {
|
||||
if _, err := c.conn.ExecContext(ctx, upsertCheckpoint, fmt.Sprintf("%s-%s", c.appName, key.streamName), key.shardID, sequenceNumber); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
|
@ -13,6 +14,7 @@ import (
|
|||
)
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
appName := "streamConsumer"
|
||||
tableName := "checkpoint"
|
||||
connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;"
|
||||
|
|
@ -24,7 +26,7 @@ func TestNew(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Errorf("expected error equals nil, but got %v", err)
|
||||
}
|
||||
_ = ck.Shutdown()
|
||||
_ = ck.Shutdown(ctx)
|
||||
}
|
||||
|
||||
func TestNew_AppNameEmpty(t *testing.T) {
|
||||
|
|
@ -56,6 +58,7 @@ func TestNew_TableNameEmpty(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestNew_WithMaxIntervalOption(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
appName := "streamConsumer"
|
||||
tableName := "checkpoint"
|
||||
connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;"
|
||||
|
|
@ -71,10 +74,11 @@ func TestNew_WithMaxIntervalOption(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Errorf("expected error equals nil, but got %v", err)
|
||||
}
|
||||
_ = ck.Shutdown()
|
||||
_ = ck.Shutdown(ctx)
|
||||
}
|
||||
|
||||
func TestCheckpoint_GetCheckpoint(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
appName := "streamConsumer"
|
||||
tableName := "checkpoint"
|
||||
connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;"
|
||||
|
|
@ -100,7 +104,7 @@ func TestCheckpoint_GetCheckpoint(t *testing.T) {
|
|||
tableName)
|
||||
mock.ExpectQuery(expectedSQLRegexString).WithArgs(namespace, shardID).WillReturnRows(expectedRows)
|
||||
|
||||
gotSequenceNumber, err := ck.GetCheckpoint(streamName, shardID)
|
||||
gotSequenceNumber, err := ck.GetCheckpoint(ctx, streamName, shardID)
|
||||
|
||||
if gotSequenceNumber != expectedSequenceNumber {
|
||||
t.Errorf("expected sequence number equals %v, but got %v", expectedSequenceNumber, gotSequenceNumber)
|
||||
|
|
@ -111,10 +115,11 @@ func TestCheckpoint_GetCheckpoint(t *testing.T) {
|
|||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("there were unfulfilled expectations: %s", err)
|
||||
}
|
||||
_ = ck.Shutdown()
|
||||
_ = ck.Shutdown(ctx)
|
||||
}
|
||||
|
||||
func TestCheckpoint_Get_NoRows(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
appName := "streamConsumer"
|
||||
tableName := "checkpoint"
|
||||
connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;"
|
||||
|
|
@ -136,7 +141,7 @@ func TestCheckpoint_Get_NoRows(t *testing.T) {
|
|||
tableName)
|
||||
mock.ExpectQuery(expectedSQLRegexString).WithArgs(namespace, shardID).WillReturnError(sql.ErrNoRows)
|
||||
|
||||
gotSequenceNumber, err := ck.GetCheckpoint(streamName, shardID)
|
||||
gotSequenceNumber, err := ck.GetCheckpoint(ctx, streamName, shardID)
|
||||
|
||||
if gotSequenceNumber != "" {
|
||||
t.Errorf("expected sequence number equals empty, but got %v", gotSequenceNumber)
|
||||
|
|
@ -151,6 +156,7 @@ func TestCheckpoint_Get_NoRows(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestCheckpoint_Get_QueryError(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
appName := "streamConsumer"
|
||||
tableName := "checkpoint"
|
||||
connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;"
|
||||
|
|
@ -172,7 +178,7 @@ func TestCheckpoint_Get_QueryError(t *testing.T) {
|
|||
tableName)
|
||||
mock.ExpectQuery(expectedSQLRegexString).WithArgs(namespace, shardID).WillReturnError(errors.New("an error"))
|
||||
|
||||
gotSequenceNumber, err := ck.GetCheckpoint(streamName, shardID)
|
||||
gotSequenceNumber, err := ck.GetCheckpoint(ctx, streamName, shardID)
|
||||
|
||||
if gotSequenceNumber != "" {
|
||||
t.Errorf("expected sequence number equals empty, but got %v", gotSequenceNumber)
|
||||
|
|
@ -183,10 +189,11 @@ func TestCheckpoint_Get_QueryError(t *testing.T) {
|
|||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("there were unfulfilled expectations: %s", err)
|
||||
}
|
||||
_ = ck.Shutdown()
|
||||
_ = ck.Shutdown(ctx)
|
||||
}
|
||||
|
||||
func TestCheckpoint_SetCheckpoint(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
appName := "streamConsumer"
|
||||
tableName := "checkpoint"
|
||||
connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;"
|
||||
|
|
@ -199,15 +206,16 @@ func TestCheckpoint_SetCheckpoint(t *testing.T) {
|
|||
t.Fatalf("error occurred during the checkpoint creation. cause: %v", err)
|
||||
}
|
||||
|
||||
err = ck.SetCheckpoint(streamName, shardID, expectedSequenceNumber)
|
||||
err = ck.SetCheckpoint(ctx, streamName, shardID, expectedSequenceNumber)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("expected error equals nil, but got %v", err)
|
||||
}
|
||||
_ = ck.Shutdown()
|
||||
_ = ck.Shutdown(ctx)
|
||||
}
|
||||
|
||||
func TestCheckpoint_Set_SequenceNumberEmpty(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
appName := "streamConsumer"
|
||||
tableName := "checkpoint"
|
||||
connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;"
|
||||
|
|
@ -220,15 +228,16 @@ func TestCheckpoint_Set_SequenceNumberEmpty(t *testing.T) {
|
|||
t.Fatalf("error occurred during the checkpoint creation. cause: %v", err)
|
||||
}
|
||||
|
||||
err = ck.SetCheckpoint(streamName, shardID, expectedSequenceNumber)
|
||||
err = ck.SetCheckpoint(ctx, streamName, shardID, expectedSequenceNumber)
|
||||
|
||||
if err == nil {
|
||||
t.Errorf("expected error equals not nil, but got %v", err)
|
||||
}
|
||||
_ = ck.Shutdown()
|
||||
_ = ck.Shutdown(ctx)
|
||||
}
|
||||
|
||||
func TestCheckpoint_Shutdown(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
appName := "streamConsumer"
|
||||
tableName := "checkpoint"
|
||||
connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;"
|
||||
|
|
@ -251,13 +260,13 @@ func TestCheckpoint_Shutdown(t *testing.T) {
|
|||
result := sqlmock.NewResult(0, 1)
|
||||
mock.ExpectExec(expectedSQLRegexString).WithArgs(namespace, shardID, expectedSequenceNumber).WillReturnResult(result)
|
||||
|
||||
err = ck.SetCheckpoint(streamName, shardID, expectedSequenceNumber)
|
||||
err = ck.SetCheckpoint(ctx, streamName, shardID, expectedSequenceNumber)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unable to set checkpoint for data initialization. cause: %v", err)
|
||||
}
|
||||
|
||||
err = ck.Shutdown()
|
||||
err = ck.Shutdown(ctx)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("expected error equals not nil, but got %v", err)
|
||||
|
|
@ -268,6 +277,7 @@ func TestCheckpoint_Shutdown(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestCheckpoint_Shutdown_SaveError(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
appName := "streamConsumer"
|
||||
tableName := "checkpoint"
|
||||
connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;"
|
||||
|
|
@ -289,13 +299,13 @@ func TestCheckpoint_Shutdown_SaveError(t *testing.T) {
|
|||
expectedSQLRegexString := fmt.Sprintf(`INSERT INTO %s \(namespace, shard_id, sequence_number\) VALUES\(\$1, \$2, \$3\) ON CONFLICT \(namespace, shard_id\) DO UPDATE SET sequence_number= \$3;`, tableName)
|
||||
mock.ExpectExec(expectedSQLRegexString).WithArgs(namespace, shardID, expectedSequenceNumber).WillReturnError(errors.New("an error"))
|
||||
|
||||
err = ck.SetCheckpoint(streamName, shardID, expectedSequenceNumber)
|
||||
err = ck.SetCheckpoint(ctx, streamName, shardID, expectedSequenceNumber)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unable to set checkpoint for data initialization. cause: %v", err)
|
||||
}
|
||||
|
||||
err = ck.Shutdown()
|
||||
err = ck.Shutdown(ctx)
|
||||
|
||||
if err == nil {
|
||||
t.Errorf("expected error equals nil, but got %v", err)
|
||||
|
|
|
|||
|
|
@ -52,19 +52,17 @@ type Checkpoint struct {
|
|||
}
|
||||
|
||||
// GetCheckpoint fetches the checkpoint for a particular Shard.
|
||||
func (c *Checkpoint) GetCheckpoint(streamName, shardID string) (string, error) {
|
||||
ctx := context.Background()
|
||||
func (c *Checkpoint) GetCheckpoint(ctx context.Context, streamName, shardID string) (string, error) {
|
||||
val, _ := c.client.Get(ctx, c.key(streamName, shardID)).Result()
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// SetCheckpoint stores a checkpoint for a shard (e.g. sequence number of last record processed by application).
|
||||
// Upon fail over, record processing is resumed from this point.
|
||||
func (c *Checkpoint) SetCheckpoint(streamName, shardID, sequenceNumber string) error {
|
||||
func (c *Checkpoint) SetCheckpoint(ctx context.Context, streamName, shardID, sequenceNumber string) error {
|
||||
if sequenceNumber == "" {
|
||||
return fmt.Errorf("sequence number should not be empty")
|
||||
}
|
||||
ctx := context.Background()
|
||||
err := c.client.Set(ctx, c.key(streamName, shardID), sequenceNumber, 0).Err()
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
package redis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/alicebob/miniredis"
|
||||
|
|
@ -34,10 +35,10 @@ func Test_CheckpointLifecycle(t *testing.T) {
|
|||
}
|
||||
|
||||
// set
|
||||
_ = c.SetCheckpoint("streamName", "shardID", "testSeqNum")
|
||||
_ = c.SetCheckpoint(context.Background(), "streamName", "shardID", "testSeqNum")
|
||||
|
||||
// get
|
||||
val, err := c.GetCheckpoint("streamName", "shardID")
|
||||
val, err := c.GetCheckpoint(context.Background(), "streamName", "shardID")
|
||||
if err != nil {
|
||||
t.Fatalf("get checkpoint error: %v", err)
|
||||
}
|
||||
|
|
@ -52,7 +53,7 @@ func Test_SetEmptySeqNum(t *testing.T) {
|
|||
t.Fatalf("new checkpoint error: %v", err)
|
||||
}
|
||||
|
||||
err = c.SetCheckpoint("streamName", "shardID", "")
|
||||
err = c.SetCheckpoint(context.Background(), "streamName", "shardID", "")
|
||||
if err == nil {
|
||||
t.Fatalf("should not allow empty sequence number")
|
||||
}
|
||||
|
|
|
|||
1
worker.go
Normal file
1
worker.go
Normal file
|
|
@ -0,0 +1 @@
|
|||
package consumer
|
||||
Loading…
Reference in a new issue