Add Mysql support for checkpointing (#87)

This commit is contained in:
James Greenhill 2019-04-12 22:15:49 -07:00 committed by Harlow Ward
parent f7f98a4bc6
commit b48acfa5d4
6 changed files with 586 additions and 0 deletions

View file

@ -197,6 +197,34 @@ CREATE TABLE kinesis_consumer (
The table name must match the one you specify when creating the checkpoint. The primary key composed of namespace and shard_id is mandatory: the checkpoint needs it to run without issues, and it also ensures data integrity.
### Mysql Checkpoint
The MySQL checkpoint requires a Table Name, App Name, Stream Name and Connection String (just like the Postgres checkpoint!):
```go
import checkpoint "github.com/harlow/kinesis-consumer/checkpoint/mysql"
// mysql checkpoint
ck, err := checkpoint.New(app, table, connStr)
if err != nil {
log.Fatalf("new checkpoint error: %v", err)
}
```
To leverage the Mysql checkpoint we'll also need to create a table:
```sql
CREATE TABLE kinesis_consumer (
namespace varchar(255) NOT NULL,
shard_id varchar(255) NOT NULL,
sequence_number numeric(65,0) NOT NULL,
CONSTRAINT kinesis_consumer_pk PRIMARY KEY (namespace, shard_id)
);
```
The table name must match the one you specify when creating the checkpoint. The primary key composed of namespace and shard_id is mandatory: the checkpoint needs it to run without issues, and it also ensures data integrity.
## Options
The consumer allows the following optional overrides.

158
checkpoint/mysql/mysql.go Normal file
View file

@ -0,0 +1,158 @@
package mysql
import (
"database/sql"
"errors"
"fmt"
"sync"
"time"
_ "github.com/go-sql-driver/mysql"
)
// key identifies one buffered checkpoint entry by stream and shard. It is the
// map key type of Checkpoint.checkpoints.
type key struct {
	streamName, shardID string
}
// Option is a functional option: New calls each Option with the Checkpoint
// under construction so it can override a default value.
type Option func(*Checkpoint)
// WithMaxInterval returns an Option that sets the interval at which the
// background loop flushes buffered checkpoints to the database.
func WithMaxInterval(maxInterval time.Duration) Option {
	apply := func(ck *Checkpoint) {
		ck.maxInterval = maxInterval
	}
	return apply
}
// Checkpoint buffers the most recent sequence number per stream/shard in
// memory and persists the buffer to a MySQL table, periodically from a
// background goroutine and once more on Shutdown.
type Checkpoint struct {
// appName is combined with the stream name to form the namespace column value.
appName string
// tableName is the table checkpoints are read from and written to.
tableName string
// conn is the database handle used for all queries.
conn *sql.DB
mu *sync.Mutex // protects the checkpoints
// done tells the background flush loop to stop.
done chan struct{}
// checkpoints holds the buffered sequence numbers keyed by stream and shard.
checkpoints map[key]string
// maxInterval is the period of the background flush ticker.
maxInterval time.Duration
}
// New returns a Checkpoint that uses MySQL for its underlying storage and
// starts the background goroutine that periodically flushes buffered
// checkpoints.
//
// appName and tableName must be non-empty. connectionStr is passed unchanged
// to sql.Open for the "mysql" driver, so callers keep full control over the
// driver-specific settings. opts are applied in order on top of the defaults
// (a flush interval of one minute).
//
// An error is returned when a name is empty, when sql.Open fails, or when the
// options leave a non-positive flush interval.
func New(appName, tableName, connectionStr string, opts ...Option) (*Checkpoint, error) {
	if appName == "" {
		return nil, errors.New("application name not defined")
	}
	if tableName == "" {
		return nil, errors.New("table name not defined")
	}

	conn, err := sql.Open("mysql", connectionStr)
	if err != nil {
		return nil, err
	}

	ck := &Checkpoint{
		conn:        conn,
		appName:     appName,
		tableName:   tableName,
		done:        make(chan struct{}),
		maxInterval: 1 * time.Minute,
		mu:          new(sync.Mutex),
		checkpoints: map[key]string{},
	}
	for _, opt := range opts {
		opt(ck)
	}

	// time.NewTicker panics on a non-positive duration, and that panic would
	// fire inside the background goroutine where the caller cannot recover
	// it. Reject the value here and release the handle opened above.
	if ck.maxInterval <= 0 {
		_ = ck.conn.Close() // the validation error is the one worth reporting
		return nil, errors.New("max interval must be positive")
	}

	go ck.loop()
	return ck, nil
}
// GetMaxInterval returns the configured interval between background flushes
// of the buffered checkpoints.
func (c *Checkpoint) GetMaxInterval() time.Duration {
return c.maxInterval
}
// Get looks up the persisted sequence number for the given stream and shard.
// The row is addressed by the namespace "<appName>-<streamName>" and the shard
// ID. The lookup goes to the database, not to the in-memory buffer filled by
// Set.
//
// It returns an empty string and a nil error when no row exists, which lets
// the caller decide between starting from TRIM_HORIZON and resuming with
// AFTER_SEQUENCE_NUMBER. Any other query error is returned as is.
func (c *Checkpoint) Get(streamName, shardID string) (string, error) {
	query := fmt.Sprintf(`SELECT sequence_number FROM %s WHERE namespace=? AND shard_id=?;`, c.tableName) //nolint: gas, it replaces only the table name
	row := c.conn.QueryRow(query, fmt.Sprintf("%s-%s", c.appName, streamName), shardID)

	var seq string
	switch err := row.Scan(&seq); err {
	case nil:
		return seq, nil
	case sql.ErrNoRows:
		// A missing checkpoint is not an error for the caller.
		return "", nil
	default:
		return "", err
	}
}
// Set records sequenceNumber as the latest checkpoint for the given stream and
// shard (e.g. the sequence number of the last record the application
// processed). The value is only buffered in memory here; save writes it to the
// database on the next tick or on Shutdown.
//
// An empty sequenceNumber is rejected with an error and leaves the buffer
// untouched.
func (c *Checkpoint) Set(streamName, shardID, sequenceNumber string) error {
	c.mu.Lock()
	defer c.mu.Unlock()

	if sequenceNumber == "" {
		return fmt.Errorf("sequence number should not be empty")
	}

	c.checkpoints[key{streamName: streamName, shardID: shardID}] = sequenceNumber
	return nil
}
// Shutdown stops the background flush loop, writes the buffered checkpoints
// one last time and closes the database handle. It returns the error from that
// final save; the error from closing the handle is discarded.
//
// Shutdown must be called at most once: the loop closes the done channel when
// it exits, so a second call would send on a closed channel and panic.
func (c *Checkpoint) Shutdown() error {
	db := c.conn
	defer func() { _ = db.Close() }()

	c.done <- struct{}{}
	return c.save()
}
// loop runs in its own goroutine (started by New). It saves the buffered
// checkpoints every maxInterval until a value arrives on done, then closes
// done and stops the ticker.
func (c *Checkpoint) loop() {
	ticker := time.NewTicker(c.maxInterval)
	defer func() {
		close(c.done)
		ticker.Stop()
	}()

	for {
		select {
		case <-c.done:
			return
		case <-ticker.C:
			// The periodic flush drops its error; the same entries are
			// written again on the next tick and on Shutdown.
			_ = c.save()
		}
	}
}
// save writes every buffered checkpoint to the table with a REPLACE statement
// while holding the mutex. It stops at the first failing statement and returns
// that error. Entries stay in the buffer after being written.
func (c *Checkpoint) save() error {
	c.mu.Lock()
	defer c.mu.Unlock()

	//nolint: gas, it replaces only the table name
	stmt := fmt.Sprintf(`REPLACE INTO %s (namespace, shard_id, sequence_number) VALUES (?, ?, ?)`, c.tableName)

	for k, seq := range c.checkpoints {
		namespace := fmt.Sprintf("%s-%s", c.appName, k.streamName)
		_, err := c.conn.Exec(stmt, namespace, k.shardID, seq)
		if err != nil {
			return err
		}
	}
	return nil
}

View file

@ -0,0 +1,7 @@
package mysql
import "database/sql"
// SetConn replaces the database handle used by the checkpoint. The tests use
// it to swap in a mocked connection after New.
// NOTE(review): the assignment takes no lock and does not close the previous
// handle; confirm that is acceptable for every caller.
func (c *Checkpoint) SetConn(conn *sql.DB) {
c.conn = conn
}

View file

@ -0,0 +1,304 @@
package mysql
import (
"database/sql"
"fmt"
"testing"
"time"
"github.com/pkg/errors"
sqlmock "gopkg.in/DATA-DOG/go-sqlmock.v1"
)
// TestNew verifies that New returns a usable checkpoint and no error when
// given valid arguments.
func TestNew(t *testing.T) {
	appName := "streamConsumer"
	tableName := "checkpoint"
	connString := "user:password@/dbname"

	ck, err := New(appName, tableName, connString)
	// Both checks are fatal: New returns a nil checkpoint on error, and the
	// Shutdown call below would panic on a nil receiver.
	if err != nil {
		t.Fatalf("expected error equals nil, but got %v", err)
	}
	if ck == nil {
		t.Fatalf("expected checkpointer not equal nil, but got %v", ck)
	}
	ck.Shutdown()
}
// TestNew_AppNameEmpty checks that New rejects an empty application name by
// returning a nil checkpoint together with an error.
func TestNew_AppNameEmpty(t *testing.T) {
	const (
		appName    = ""
		tableName  = "checkpoint"
		connString = ""
	)

	ck, err := New(appName, tableName, connString)

	if ck != nil {
		t.Errorf("expected checkpointer equal nil, but got %v", ck)
	}
	if err == nil {
		t.Errorf("expected error equals not nil, but got %v", err)
	}
}
// TestNew_TableNameEmpty checks that New rejects an empty table name by
// returning a nil checkpoint together with an error.
func TestNew_TableNameEmpty(t *testing.T) {
	const (
		appName    = "streamConsumer"
		tableName  = ""
		connString = ""
	)

	ck, err := New(appName, tableName, connString)

	if ck != nil {
		t.Errorf("expected checkpointer equal nil, but got %v", ck)
	}
	if err == nil {
		t.Errorf("expected error equals not nil, but got %v", err)
	}
}
// TestNew_WithMaxIntervalOption verifies that the WithMaxInterval option
// overrides the default flush interval reported by GetMaxInterval.
func TestNew_WithMaxIntervalOption(t *testing.T) {
	appName := "streamConsumer"
	tableName := "checkpoint"
	connString := "user:password@/dbname"
	maxInterval := time.Second

	ck, err := New(appName, tableName, connString, WithMaxInterval(maxInterval))
	// Both checks are fatal: New returns a nil checkpoint on error, and the
	// method calls below would panic on a nil receiver.
	if err != nil {
		t.Fatalf("expected error equals nil, but got %v", err)
	}
	if ck == nil {
		t.Fatalf("expected checkpointer not equal nil, but got %v", ck)
	}

	if ck.GetMaxInterval() != time.Second {
		t.Errorf("expected max interval equals %v, but got %v", maxInterval, ck.GetMaxInterval())
	}
	ck.Shutdown()
}
// TestCheckpoint_Get checks that Get returns the sequence number stored for
// the "<app>-<stream>" namespace and shard when the query yields a row.
func TestCheckpoint_Get(t *testing.T) {
	const (
		appName                = "streamConsumer"
		tableName              = "checkpoint"
		connString             = "user:password@/dbname"
		streamName             = "myStreamName"
		shardID                = "shardId-00000000"
		expectedSequenceNumber = "49578481031144599192696750682534686652010819674221576194"
		maxInterval            = time.Second
	)

	connMock, mock, err := sqlmock.New()
	if err != nil {
		t.Fatalf("error occurred during the sqlmock creation. cause: %v", err)
	}
	ck, err := New(appName, tableName, connString, WithMaxInterval(maxInterval))
	if err != nil {
		t.Fatalf("error occurred during the checkpoint creation. cause: %v", err)
	}
	ck.SetConn(connMock) // nolint: gotypex, the function available only in test

	// The mock answers the checkpoint lookup with a single row.
	mock.ExpectQuery(fmt.Sprintf(`SELECT sequence_number FROM %s WHERE namespace=\? AND shard_id=\?;`, tableName)).
		WithArgs(fmt.Sprintf("%s-%s", appName, streamName), shardID).
		WillReturnRows(sqlmock.NewRows([]string{"sequence_number"}).AddRow(expectedSequenceNumber))

	gotSequenceNumber, err := ck.Get(streamName, shardID)

	if gotSequenceNumber != expectedSequenceNumber {
		t.Errorf("expected sequence number equals %v, but got %v", expectedSequenceNumber, gotSequenceNumber)
	}
	if err != nil {
		t.Errorf("expected error equals nil, but got %v", err)
	}
	if unmet := mock.ExpectationsWereMet(); unmet != nil {
		t.Errorf("there were unfulfilled expectations: %s", unmet)
	}
	ck.Shutdown()
}
// TestCheckpoint_Get_NoRows checks that Get turns sql.ErrNoRows into an empty
// sequence number with a nil error.
func TestCheckpoint_Get_NoRows(t *testing.T) {
	const (
		appName     = "streamConsumer"
		tableName   = "checkpoint"
		connString  = "user:password@/dbname"
		streamName  = "myStreamName"
		shardID     = "shardId-00000000"
		maxInterval = time.Second
	)

	connMock, mock, err := sqlmock.New()
	if err != nil {
		t.Fatalf("error occurred during the sqlmock creation. cause: %v", err)
	}
	ck, err := New(appName, tableName, connString, WithMaxInterval(maxInterval))
	if err != nil {
		t.Fatalf("error occurred during the checkpoint creation. cause: %v", err)
	}
	ck.SetConn(connMock) // nolint: gotypex, the function available only in test

	// The mock reports that the lookup matched no row.
	mock.ExpectQuery(fmt.Sprintf(`SELECT sequence_number FROM %s WHERE namespace=\? AND shard_id=\?;`, tableName)).
		WithArgs(fmt.Sprintf("%s-%s", appName, streamName), shardID).
		WillReturnError(sql.ErrNoRows)

	gotSequenceNumber, err := ck.Get(streamName, shardID)

	if gotSequenceNumber != "" {
		t.Errorf("expected sequence number equals empty, but got %v", gotSequenceNumber)
	}
	if err != nil {
		t.Errorf("expected error equals nil, but got %v", err)
	}
	if unmet := mock.ExpectationsWereMet(); unmet != nil {
		t.Errorf("there were unfulfilled expectations: %s", unmet)
	}
	ck.Shutdown()
}
// TestCheckpoint_Get_QueryError checks that Get passes a query failure other
// than sql.ErrNoRows on to the caller, together with an empty sequence number.
func TestCheckpoint_Get_QueryError(t *testing.T) {
	const (
		appName     = "streamConsumer"
		tableName   = "checkpoint"
		connString  = "user:password@/dbname"
		streamName  = "myStreamName"
		shardID     = "shardId-00000000"
		maxInterval = time.Second
	)

	connMock, mock, err := sqlmock.New()
	if err != nil {
		t.Fatalf("error occurred during the sqlmock creation. cause: %v", err)
	}
	ck, err := New(appName, tableName, connString, WithMaxInterval(maxInterval))
	if err != nil {
		t.Fatalf("error occurred during the checkpoint creation. cause: %v", err)
	}
	ck.SetConn(connMock) // nolint: gotypex, the function available only in test

	// The mock fails the lookup with a generic error.
	mock.ExpectQuery(fmt.Sprintf(`SELECT sequence_number FROM %s WHERE namespace=\? AND shard_id=\?;`, tableName)).
		WithArgs(fmt.Sprintf("%s-%s", appName, streamName), shardID).
		WillReturnError(errors.New("an error"))

	gotSequenceNumber, err := ck.Get(streamName, shardID)

	if gotSequenceNumber != "" {
		t.Errorf("expected sequence number equals empty, but got %v", gotSequenceNumber)
	}
	if err == nil {
		t.Errorf("expected error equals not nil, but got %v", err)
	}
	if unmet := mock.ExpectationsWereMet(); unmet != nil {
		t.Errorf("there were unfulfilled expectations: %s", unmet)
	}
	ck.Shutdown()
}
// TestCheckpoint_Set checks that Set accepts a non-empty sequence number
// without returning an error.
func TestCheckpoint_Set(t *testing.T) {
	const (
		appName                = "streamConsumer"
		tableName              = "checkpoint"
		connString             = "user:password@/dbname"
		streamName             = "myStreamName"
		shardID                = "shardId-00000000"
		expectedSequenceNumber = "49578481031144599192696750682534686652010819674221576194"
		maxInterval            = time.Second
	)

	ck, err := New(appName, tableName, connString, WithMaxInterval(maxInterval))
	if err != nil {
		t.Fatalf("error occurred during the checkpoint creation. cause: %v", err)
	}

	if err = ck.Set(streamName, shardID, expectedSequenceNumber); err != nil {
		t.Errorf("expected error equals nil, but got %v", err)
	}
	ck.Shutdown()
}
// TestCheckpoint_Set_SequenceNumberEmpty checks that Set returns an error when
// the sequence number is empty.
func TestCheckpoint_Set_SequenceNumberEmpty(t *testing.T) {
	const (
		appName                = "streamConsumer"
		tableName              = "checkpoint"
		connString             = "user:password@/dbname"
		streamName             = "myStreamName"
		shardID                = "shardId-00000000"
		expectedSequenceNumber = ""
		maxInterval            = time.Second
	)

	ck, err := New(appName, tableName, connString, WithMaxInterval(maxInterval))
	if err != nil {
		t.Fatalf("error occurred during the checkpoint creation. cause: %v", err)
	}

	if err = ck.Set(streamName, shardID, expectedSequenceNumber); err == nil {
		t.Errorf("expected error equals not nil, but got %v", err)
	}
	ck.Shutdown()
}
// TestCheckpoint_Shutdown verifies that Shutdown writes the checkpoint
// buffered by Set with a REPLACE statement and returns nil when the statement
// succeeds.
func TestCheckpoint_Shutdown(t *testing.T) {
	appName := "streamConsumer"
	tableName := "checkpoint"
	connString := "user:password@/dbname"
	streamName := "myStreamName"
	shardID := "shardId-00000000"
	expectedSequenceNumber := "49578481031144599192696750682534686652010819674221576194"
	maxInterval := time.Second

	connMock, mock, err := sqlmock.New()
	if err != nil {
		t.Fatalf("error occurred during the sqlmock creation. cause: %v", err)
	}
	ck, err := New(appName, tableName, connString, WithMaxInterval(maxInterval))
	if err != nil {
		t.Fatalf("error occurred during the checkpoint creation. cause: %v", err)
	}
	ck.SetConn(connMock) // nolint: gotypex, the function available only in test

	// The mock accepts exactly one upsert for the buffered checkpoint.
	namespace := fmt.Sprintf("%s-%s", appName, streamName)
	expectedSQLRegexString := fmt.Sprintf(`REPLACE INTO %s \(namespace, shard_id, sequence_number\) VALUES \(\?, \?, \?\)`, tableName)
	result := sqlmock.NewResult(0, 1)
	mock.ExpectExec(expectedSQLRegexString).WithArgs(namespace, shardID, expectedSequenceNumber).WillReturnResult(result)

	err = ck.Set(streamName, shardID, expectedSequenceNumber)
	if err != nil {
		t.Fatalf("unable to set checkpoint for data initialization. cause: %v", err)
	}

	err = ck.Shutdown()
	if err != nil {
		// This branch fires on an unexpected error, so the message must say
		// that nil was expected.
		t.Errorf("expected error equals nil, but got %v", err)
	}
	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("there were unfulfilled expectations: %s", err)
	}
}
// TestCheckpoint_Shutdown_SaveError verifies that Shutdown returns the error
// produced when the final REPLACE statement fails.
func TestCheckpoint_Shutdown_SaveError(t *testing.T) {
	appName := "streamConsumer"
	tableName := "checkpoint"
	connString := "user:password@/dbname"
	streamName := "myStreamName"
	shardID := "shardId-00000000"
	expectedSequenceNumber := "49578481031144599192696750682534686652010819674221576194"
	maxInterval := time.Second

	connMock, mock, err := sqlmock.New()
	if err != nil {
		t.Fatalf("error occurred during the sqlmock creation. cause: %v", err)
	}
	ck, err := New(appName, tableName, connString, WithMaxInterval(maxInterval))
	if err != nil {
		t.Fatalf("error occurred during the checkpoint creation. cause: %v", err)
	}
	ck.SetConn(connMock) // nolint: gotypex, the function available only in test

	// The mock fails the upsert for the buffered checkpoint.
	namespace := fmt.Sprintf("%s-%s", appName, streamName)
	expectedSQLRegexString := fmt.Sprintf(`REPLACE INTO %s \(namespace, shard_id, sequence_number\) VALUES \(\?, \?, \?\)`, tableName)
	mock.ExpectExec(expectedSQLRegexString).WithArgs(namespace, shardID, expectedSequenceNumber).WillReturnError(errors.New("an error"))

	err = ck.Set(streamName, shardID, expectedSequenceNumber)
	if err != nil {
		t.Fatalf("unable to set checkpoint for data initialization. cause: %v", err)
	}

	err = ck.Shutdown()
	if err == nil {
		// This branch fires when the expected failure did not surface, so the
		// message must say that a non-nil error was expected.
		t.Errorf("expected error equals not nil, but got %v", err)
	}
	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("there were unfulfilled expectations: %s", err)
	}
}

View file

@ -0,0 +1,21 @@
# Consumer with MySQL checkpoint
Read records from the Kinesis stream, using MySQL to store the checkpoints.
## Environment Variables
Export the required environment vars for connecting to the Kinesis stream:
```shell
export AWS_ACCESS_KEY=
export AWS_REGION=
export AWS_SECRET_KEY=
```
## Run the consumer
go run main.go --app <appName> --stream <streamName> --table <tableName> --connection <connectionString>
Connection string should look something like
user:password@/dbname

View file

@ -0,0 +1,68 @@
package main
import (
"context"
"expvar"
"flag"
"fmt"
"log"
"os"
"os/signal"
consumer "github.com/harlow/kinesis-consumer"
checkpoint "github.com/harlow/kinesis-consumer/checkpoint/mysql"
)
// main runs an example consumer that prints every record of a Kinesis stream
// and stores its progress in the MySQL checkpoint.
//
// Flags: --app (consumer app name), --stream (stream name), --table
// (checkpoint table name) and --connection (MySQL connection string).
// SIGINT cancels the scan context; the checkpoint is then shut down so that
// buffered sequence numbers are written before the process exits.
func main() {
	var (
		app     = flag.String("app", "", "Consumer app name")
		stream  = flag.String("stream", "", "Stream name")
		table   = flag.String("table", "", "Table name")
		connStr = flag.String("connection", "", "Connection Str")
	)
	flag.Parse()

	// mysql checkpoint
	ck, err := checkpoint.New(*app, *table, *connStr)
	if err != nil {
		log.Fatalf("checkpoint error: %v", err)
	}

	counter := expvar.NewMap("counters")

	// consumer
	c, err := consumer.New(
		*stream,
		consumer.WithCheckpoint(ck),
		consumer.WithCounter(counter),
	)
	if err != nil {
		log.Fatalf("consumer error: %v", err)
	}

	// use cancel func to signal shutdown
	ctx, cancel := context.WithCancel(context.Background())

	// trap SIGINT, wait to trigger shutdown
	signals := make(chan os.Signal, 1)
	signal.Notify(signals, os.Interrupt)

	go func() {
		<-signals
		cancel()
	}()

	// scan stream
	scanErr := c.Scan(ctx, func(r *consumer.Record) error {
		fmt.Println(string(r.Data))
		return nil
	})
	if scanErr != nil {
		// Report the failure but do not exit yet: log.Fatalf here would skip
		// the checkpoint shutdown and lose the buffered sequence numbers.
		log.Printf("scan error: %v", scanErr)
	}

	if err := ck.Shutdown(); err != nil {
		log.Fatalf("checkpoint shutdown error: %v", err)
	}

	if scanErr != nil {
		os.Exit(1)
	}
}