diff --git a/README.md b/README.md index 779b560..ec33b53 100644 --- a/README.md +++ b/README.md @@ -197,6 +197,34 @@ CREATE TABLE kinesis_consumer ( The table name has to be the same that you specify when creating the checkpoint. The primary key composed by namespace and shard_id is mandatory in order to the checkpoint run without issues and also to ensure data integrity. +### MySQL Checkpoint + +The MySQL checkpoint is created from an App Name, a Table Name and a connection string (just like the Postgres checkpoint!): + +```go +import checkpoint "github.com/harlow/kinesis-consumer/checkpoint/mysql" + +// mysql checkpoint +ck, err := checkpoint.New(app, table, connStr) +if err != nil { + log.Fatalf("new checkpoint error: %v", err) +} + +``` + +To use the MySQL checkpoint we also need to create a table: + +```sql +CREATE TABLE kinesis_consumer ( + namespace varchar(255) NOT NULL, + shard_id varchar(255) NOT NULL, + sequence_number numeric(65,0) NOT NULL, + CONSTRAINT kinesis_consumer_pk PRIMARY KEY (namespace, shard_id) +); +``` + +The table name has to be the same one that you specify when creating the checkpoint. The primary key composed of namespace and shard_id is mandatory for the checkpoint to run without issues and to ensure data integrity. + ## Options The consumer allows the following optional overrides. 
diff --git a/checkpoint/mysql/mysql.go b/checkpoint/mysql/mysql.go new file mode 100644 index 0000000..f4a27d2 --- /dev/null +++ b/checkpoint/mysql/mysql.go @@ -0,0 +1,158 @@ +package mysql + +import ( + "database/sql" + "errors" + "fmt" + "sync" + "time" + + _ "github.com/go-sql-driver/mysql" +) + +type key struct { + streamName string + shardID string +} + +// Option is used to override defaults when creating a new Checkpoint +type Option func(*Checkpoint) + +// WithMaxInterval sets the flush interval; it must be positive because it is passed to time.NewTicker +func WithMaxInterval(maxInterval time.Duration) Option { + return func(c *Checkpoint) { + c.maxInterval = maxInterval + } +} + +// Checkpoint keeps the latest sequence number per stream and shard in memory and periodically saves it to a MySQL table +type Checkpoint struct { + appName string + tableName string + conn *sql.DB + mu *sync.Mutex // protects the checkpoints + done chan struct{} + checkpoints map[key]string + maxInterval time.Duration +} + +// New returns a checkpoint that uses MySQL for underlying storage and starts its background flush loop +// Taking a connection string keeps it flexible enough to use specific db configs +func New(appName, tableName, connectionStr string, opts ...Option) (*Checkpoint, error) { + if appName == "" { + return nil, errors.New("application name not defined") + } + + if tableName == "" { + return nil, errors.New("table name not defined") + } + + conn, err := sql.Open("mysql", connectionStr) + + if err != nil { + return nil, err + } + + ck := &Checkpoint{ + conn: conn, + appName: appName, + tableName: tableName, + done: make(chan struct{}), + maxInterval: 1 * time.Minute, + mu: new(sync.Mutex), + checkpoints: map[key]string{}, + } + + for _, opt := range opts { + opt(ck) + } + + go ck.loop() + + return ck, nil +} + +// GetMaxInterval returns the maximum interval between checkpoint flushes +func (c *Checkpoint) GetMaxInterval() time.Duration { + return c.maxInterval +} + +// Get determines if a checkpoint for a particular Shard exists. 
+// Typically used to determine whether we should start processing the shard with +// TRIM_HORIZON or AFTER_SEQUENCE_NUMBER (if checkpoint exists). +func (c *Checkpoint) Get(streamName, shardID string) (string, error) { + namespace := fmt.Sprintf("%s-%s", c.appName, streamName) + + var sequenceNumber string + getCheckpointQuery := fmt.Sprintf(`SELECT sequence_number FROM %s WHERE namespace=? AND shard_id=?;`, c.tableName) //nolint: gas, it replaces only the table name + err := c.conn.QueryRow(getCheckpointQuery, namespace, shardID).Scan(&sequenceNumber) + + if err != nil { + if err == sql.ErrNoRows { + return "", nil + } + return "", err + } + + return sequenceNumber, nil +} + +// Set buffers a checkpoint for a shard in memory (e.g. sequence number of last record processed by application); it reaches the table on the next periodic save or on Shutdown. +// Upon failover, record processing is resumed from the last saved point. +func (c *Checkpoint) Set(streamName, shardID, sequenceNumber string) error { + c.mu.Lock() + defer c.mu.Unlock() + + if sequenceNumber == "" { + return fmt.Errorf("sequence number should not be empty") + } + + key := key{ + streamName: streamName, + shardID: shardID, + } + + c.checkpoints[key] = sequenceNumber + + return nil +} + +// Shutdown stops the flush loop, saves any in-flight data and closes the connection. 
+func (c *Checkpoint) Shutdown() error { + defer c.conn.Close() + + c.done <- struct{}{} + + return c.save() +} + +func (c *Checkpoint) loop() { + tick := time.NewTicker(c.maxInterval) + defer tick.Stop() + defer close(c.done) + + for { + select { + case <-tick.C: + c.save() + case <-c.done: + return + } + } +} + +func (c *Checkpoint) save() error { + c.mu.Lock() + defer c.mu.Unlock() + + //nolint: gas, it replaces only the table name + upsertCheckpoint := fmt.Sprintf(`REPLACE INTO %s (namespace, shard_id, sequence_number) VALUES (?, ?, ?)`, c.tableName) + + for key, sequenceNumber := range c.checkpoints { + if _, err := c.conn.Exec(upsertCheckpoint, fmt.Sprintf("%s-%s", c.appName, key.streamName), key.shardID, sequenceNumber); err != nil { + return err + } + } + + return nil +} diff --git a/checkpoint/mysql/mysql_databaseutils_test.go b/checkpoint/mysql/mysql_databaseutils_test.go new file mode 100644 index 0000000..d3b8d1d --- /dev/null +++ b/checkpoint/mysql/mysql_databaseutils_test.go @@ -0,0 +1,7 @@ +package mysql + +import "database/sql" + +func (c *Checkpoint) SetConn(conn *sql.DB) { + c.conn = conn +} diff --git a/checkpoint/mysql/mysql_test.go b/checkpoint/mysql/mysql_test.go new file mode 100644 index 0000000..137726c --- /dev/null +++ b/checkpoint/mysql/mysql_test.go @@ -0,0 +1,304 @@ +package mysql + +import ( + "database/sql" + "fmt" + "testing" + "time" + + "github.com/pkg/errors" + sqlmock "gopkg.in/DATA-DOG/go-sqlmock.v1" +) + +func TestNew(t *testing.T) { + appName := "streamConsumer" + tableName := "checkpoint" + connString := "user:password@/dbname" + ck, err := New(appName, tableName, connString) + + if ck == nil { + t.Errorf("expected checkpointer not equal nil, but got %v", ck) + } + if err != nil { + t.Errorf("expected error equals nil, but got %v", err) + } + ck.Shutdown() +} + +func TestNew_AppNameEmpty(t *testing.T) { + appName := "" + tableName := "checkpoint" + connString := "" + ck, err := New(appName, tableName, connString) + + if ck 
!= nil { + t.Errorf("expected checkpointer equal nil, but got %v", ck) + } + if err == nil { + t.Errorf("expected error equals not nil, but got %v", err) + } +} + +func TestNew_TableNameEmpty(t *testing.T) { + appName := "streamConsumer" + tableName := "" + connString := "" + ck, err := New(appName, tableName, connString) + + if ck != nil { + t.Errorf("expected checkpointer equal nil, but got %v", ck) + } + if err == nil { + t.Errorf("expected error equals not nil, but got %v", err) + } +} + +func TestNew_WithMaxIntervalOption(t *testing.T) { + appName := "streamConsumer" + tableName := "checkpoint" + connString := "user:password@/dbname" + maxInterval := time.Second + ck, err := New(appName, tableName, connString, WithMaxInterval(maxInterval)) + + if ck == nil { + t.Errorf("expected checkpointer not equal nil, but got %v", ck) + } + if ck.GetMaxInterval() != time.Second { + t.Errorf("expected max interval equals %v, but got %v", maxInterval, ck.GetMaxInterval()) + } + if err != nil { + t.Errorf("expected error equals nil, but got %v", err) + } + ck.Shutdown() +} + +func TestCheckpoint_Get(t *testing.T) { + appName := "streamConsumer" + tableName := "checkpoint" + connString := "user:password@/dbname" + streamName := "myStreamName" + shardID := "shardId-00000000" + expectedSequenceNumber := "49578481031144599192696750682534686652010819674221576194" + maxInterval := time.Second + connMock, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("error occurred during the sqlmock creation. cause: %v", err) + } + ck, err := New(appName, tableName, connString, WithMaxInterval(maxInterval)) + if err != nil { + t.Fatalf("error occurred during the checkpoint creation. 
cause: %v", err) + } + ck.SetConn(connMock) // nolint: gotypex, the function available only in test + + rows := []string{"sequence_number"} + namespace := fmt.Sprintf("%s-%s", appName, streamName) + expectedRows := sqlmock.NewRows(rows) + expectedRows.AddRow(expectedSequenceNumber) + expectedSQLRegexString := fmt.Sprintf(`SELECT sequence_number FROM %s WHERE namespace=\? AND shard_id=\?;`, + tableName) + mock.ExpectQuery(expectedSQLRegexString).WithArgs(namespace, shardID).WillReturnRows(expectedRows) + + gotSequenceNumber, err := ck.Get(streamName, shardID) + + if gotSequenceNumber != expectedSequenceNumber { + t.Errorf("expected sequence number equals %v, but got %v", expectedSequenceNumber, gotSequenceNumber) + } + if err != nil { + t.Errorf("expected error equals nil, but got %v", err) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } + ck.Shutdown() +} + +func TestCheckpoint_Get_NoRows(t *testing.T) { + appName := "streamConsumer" + tableName := "checkpoint" + connString := "user:password@/dbname" + streamName := "myStreamName" + shardID := "shardId-00000000" + maxInterval := time.Second + connMock, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("error occurred during the sqlmock creation. cause: %v", err) + } + ck, err := New(appName, tableName, connString, WithMaxInterval(maxInterval)) + if err != nil { + t.Fatalf("error occurred during the checkpoint creation. cause: %v", err) + } + ck.SetConn(connMock) // nolint: gotypex, the function available only in test + + namespace := fmt.Sprintf("%s-%s", appName, streamName) + expectedSQLRegexString := fmt.Sprintf(`SELECT sequence_number FROM %s WHERE namespace=\? 
AND shard_id=\?;`, + tableName) + mock.ExpectQuery(expectedSQLRegexString).WithArgs(namespace, shardID).WillReturnError(sql.ErrNoRows) + + gotSequenceNumber, err := ck.Get(streamName, shardID) + + if gotSequenceNumber != "" { + t.Errorf("expected sequence number equals empty, but got %v", gotSequenceNumber) + } + if err != nil { + t.Errorf("expected error equals nil, but got %v", err) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } + ck.Shutdown() +} + +func TestCheckpoint_Get_QueryError(t *testing.T) { + appName := "streamConsumer" + tableName := "checkpoint" + connString := "user:password@/dbname" + streamName := "myStreamName" + shardID := "shardId-00000000" + maxInterval := time.Second + connMock, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("error occurred during the sqlmock creation. cause: %v", err) + } + ck, err := New(appName, tableName, connString, WithMaxInterval(maxInterval)) + if err != nil { + t.Fatalf("error occurred during the checkpoint creation. cause: %v", err) + } + ck.SetConn(connMock) // nolint: gotypex, the function available only in test + + namespace := fmt.Sprintf("%s-%s", appName, streamName) + expectedSQLRegexString := fmt.Sprintf(`SELECT sequence_number FROM %s WHERE namespace=\? 
AND shard_id=\?;`, + tableName) + mock.ExpectQuery(expectedSQLRegexString).WithArgs(namespace, shardID).WillReturnError(errors.New("an error")) + + gotSequenceNumber, err := ck.Get(streamName, shardID) + + if gotSequenceNumber != "" { + t.Errorf("expected sequence number equals empty, but got %v", gotSequenceNumber) + } + if err == nil { + t.Errorf("expected error equals not nil, but got %v", err) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } + ck.Shutdown() +} + +func TestCheckpoint_Set(t *testing.T) { + appName := "streamConsumer" + tableName := "checkpoint" + connString := "user:password@/dbname" + streamName := "myStreamName" + shardID := "shardId-00000000" + expectedSequenceNumber := "49578481031144599192696750682534686652010819674221576194" + maxInterval := time.Second + ck, err := New(appName, tableName, connString, WithMaxInterval(maxInterval)) + if err != nil { + t.Fatalf("error occurred during the checkpoint creation. cause: %v", err) + } + + err = ck.Set(streamName, shardID, expectedSequenceNumber) + + if err != nil { + t.Errorf("expected error equals nil, but got %v", err) + } + ck.Shutdown() +} + +func TestCheckpoint_Set_SequenceNumberEmpty(t *testing.T) { + appName := "streamConsumer" + tableName := "checkpoint" + connString := "user:password@/dbname" + streamName := "myStreamName" + shardID := "shardId-00000000" + expectedSequenceNumber := "" + maxInterval := time.Second + ck, err := New(appName, tableName, connString, WithMaxInterval(maxInterval)) + if err != nil { + t.Fatalf("error occurred during the checkpoint creation. 
cause: %v", err) + } + + err = ck.Set(streamName, shardID, expectedSequenceNumber) + + if err == nil { + t.Errorf("expected error equals not nil, but got %v", err) + } + ck.Shutdown() +} + +func TestCheckpoint_Shutdown(t *testing.T) { + appName := "streamConsumer" + tableName := "checkpoint" + connString := "user:password@/dbname" + streamName := "myStreamName" + shardID := "shardId-00000000" + expectedSequenceNumber := "49578481031144599192696750682534686652010819674221576194" + maxInterval := time.Second + connMock, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("error occurred during the sqlmock creation. cause: %v", err) + } + ck, err := New(appName, tableName, connString, WithMaxInterval(maxInterval)) + if err != nil { + t.Fatalf("error occurred during the checkpoint creation. cause: %v", err) + } + ck.SetConn(connMock) // nolint: gotypex, the function available only in test + + namespace := fmt.Sprintf("%s-%s", appName, streamName) + expectedSQLRegexString := fmt.Sprintf(`REPLACE INTO %s \(namespace, shard_id, sequence_number\) VALUES \(\?, \?, \?\)`, tableName) + result := sqlmock.NewResult(0, 1) + mock.ExpectExec(expectedSQLRegexString).WithArgs(namespace, shardID, expectedSequenceNumber).WillReturnResult(result) + + err = ck.Set(streamName, shardID, expectedSequenceNumber) + + if err != nil { + t.Fatalf("unable to set checkpoint for data initialization. 
cause: %v", err) + } + + err = ck.Shutdown() + + if err != nil { + t.Errorf("expected error equals not nil, but got %v", err) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func TestCheckpoint_Shutdown_SaveError(t *testing.T) { + appName := "streamConsumer" + tableName := "checkpoint" + connString := "user:password@/dbname" + streamName := "myStreamName" + shardID := "shardId-00000000" + expectedSequenceNumber := "49578481031144599192696750682534686652010819674221576194" + maxInterval := time.Second + connMock, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("error occurred during the sqlmock creation. cause: %v", err) + } + ck, err := New(appName, tableName, connString, WithMaxInterval(maxInterval)) + if err != nil { + t.Fatalf("error occurred during the checkpoint creation. cause: %v", err) + } + ck.SetConn(connMock) // nolint: gotypex, the function available only in test + + namespace := fmt.Sprintf("%s-%s", appName, streamName) + expectedSQLRegexString := fmt.Sprintf(`REPLACE INTO %s \(namespace, shard_id, sequence_number\) VALUES \(\?, \?, \?\)`, tableName) + mock.ExpectExec(expectedSQLRegexString).WithArgs(namespace, shardID, expectedSequenceNumber).WillReturnError(errors.New("an error")) + + err = ck.Set(streamName, shardID, expectedSequenceNumber) + + if err != nil { + t.Fatalf("unable to set checkpoint for data initialization. 
cause: %v", err) + } + + err = ck.Shutdown() + + if err == nil { + t.Errorf("expected error equals nil, but got %v", err) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} diff --git a/examples/consumer/cp-mysql/README.md b/examples/consumer/cp-mysql/README.md new file mode 100644 index 0000000..dd3b317 --- /dev/null +++ b/examples/consumer/cp-mysql/README.md @@ -0,0 +1,21 @@ +# Consumer with MySQL checkpoint + +Read records from the Kinesis stream using MySQL as the checkpoint store + +## Environment Variables + +Export the required environment vars for connecting to the Kinesis stream: + +```shell +export AWS_ACCESS_KEY= +export AWS_REGION= +export AWS_SECRET_KEY= +``` + +## Run the consumer + + go run main.go --app <app-name> --stream <stream-name> --table <table-name> --connection <connection-string> + +The connection string should look something like: + + user:password@/dbname \ No newline at end of file diff --git a/examples/consumer/cp-mysql/main.go b/examples/consumer/cp-mysql/main.go new file mode 100644 index 0000000..386ea94 --- /dev/null +++ b/examples/consumer/cp-mysql/main.go @@ -0,0 +1,68 @@ +package main + +import ( + "context" + "expvar" + "flag" + "fmt" + "log" + "os" + "os/signal" + + consumer "github.com/harlow/kinesis-consumer" + checkpoint "github.com/harlow/kinesis-consumer/checkpoint/mysql" +) + +func main() { + var ( + app = flag.String("app", "", "Consumer app name") + stream = flag.String("stream", "", "Stream name") + table = flag.String("table", "", "Table name") + connStr = flag.String("connection", "", "Connection Str") + ) + flag.Parse() + + // mysql checkpoint + ck, err := checkpoint.New(*app, *table, *connStr) + if err != nil { + log.Fatalf("checkpoint error: %v", err) + } + + var counter = expvar.NewMap("counters") + + // consumer + c, err := consumer.New( + *stream, + consumer.WithCheckpoint(ck), + consumer.WithCounter(counter), + ) + if err != nil { + log.Fatalf("consumer error: %v", err) + } + + // use cancel func to signal 
shutdown + ctx, cancel := context.WithCancel(context.Background()) + + // trap SIGINT, wait to trigger shutdown + signals := make(chan os.Signal, 1) + signal.Notify(signals, os.Interrupt) + + go func() { + <-signals + cancel() + }() + + // scan stream + err = c.Scan(ctx, func(r *consumer.Record) error { + fmt.Println(string(r.Data)) + return nil + }) + + if err != nil { + log.Fatalf("scan error: %v", err) + } + + if err := ck.Shutdown(); err != nil { + log.Fatalf("checkpoint shutdown error: %v", err) + } +}