Add Mysql support for checkpointing (#87)

This commit is contained in:
James Greenhill 2019-04-12 22:15:49 -07:00 committed by Harlow Ward
parent f7f98a4bc6
commit b48acfa5d4
6 changed files with 586 additions and 0 deletions

View file

@ -197,6 +197,34 @@ CREATE TABLE kinesis_consumer (
The table name has to be the same one that you specify when creating the checkpoint. The primary key composed of namespace and shard_id is mandatory for the checkpoint to run without issues and to ensure data integrity.
### Mysql Checkpoint
The Mysql checkpoint requires Table Name, App Name, Stream Name and ConnectionString (just like the Postgres checkpoint!):
```go
import checkpoint "github.com/harlow/kinesis-consumer/checkpoint/mysql"
// mysql checkpoint
ck, err := checkpoint.New(app, table, connStr)
if err != nil {
log.Fatalf("new checkpoint error: %v", err)
}
```
To leverage the Mysql checkpoint we'll also need to create a table:
```sql
CREATE TABLE kinesis_consumer (
namespace varchar(255) NOT NULL,
shard_id varchar(255) NOT NULL,
sequence_number numeric(65,0) NOT NULL,
CONSTRAINT kinesis_consumer_pk PRIMARY KEY (namespace, shard_id)
);
```
The table name has to be the same one that you specify when creating the checkpoint. The primary key composed of namespace and shard_id is mandatory for the checkpoint to run without issues and to ensure data integrity.
## Options
The consumer allows the following optional overrides.

158
checkpoint/mysql/mysql.go Normal file
View file

@ -0,0 +1,158 @@
package mysql
import (
"database/sql"
"errors"
"fmt"
"sync"
"time"
_ "github.com/go-sql-driver/mysql"
)
// key identifies one buffered checkpoint by the stream and the shard it
// belongs to. It is the map key type of Checkpoint.checkpoints.
type key struct {
streamName string
shardID string
}
// Option is a functional option: New applies each one to the freshly built
// Checkpoint to override its defaults.
type Option func(*Checkpoint)
// WithMaxInterval returns an Option that overrides the interval at which the
// buffered checkpoints are flushed to the database.
func WithMaxInterval(maxInterval time.Duration) Option {
	override := func(ck *Checkpoint) {
		ck.maxInterval = maxInterval
	}
	return override
}
// Checkpoint buffers the latest sequence number per stream/shard in memory
// and persists the buffered values to a MySQL table, both periodically and on
// Shutdown.
type Checkpoint struct {
// appName is combined with the stream name to build the namespace column.
appName string
// tableName is the table the checkpoints are read from and written to.
tableName string
// conn is the database handle used for every query.
conn *sql.DB
mu *sync.Mutex // protects the checkpoints
// done tells the background loop to stop.
done chan struct{}
// checkpoints holds the sequence numbers recorded by Set until save writes them.
checkpoints map[key]string
// maxInterval is the period of the ticker that triggers a save.
maxInterval time.Duration
}
// New returns a Checkpoint that uses MySQL for underlying storage and starts
// the background goroutine that periodically flushes buffered checkpoints.
//
// appName and tableName must be non-empty. connectionStr is passed to sql.Open
// as-is, which keeps the database configuration entirely in the caller's
// hands. The flush interval defaults to one minute and can be changed with
// WithMaxInterval.
//
// An error is returned when a required name is empty, when sql.Open rejects
// the connection string, or when the resulting flush interval is not positive.
func New(appName, tableName, connectionStr string, opts ...Option) (*Checkpoint, error) {
	if appName == "" {
		return nil, errors.New("application name not defined")
	}
	if tableName == "" {
		return nil, errors.New("table name not defined")
	}

	conn, err := sql.Open("mysql", connectionStr)
	if err != nil {
		return nil, err
	}

	ck := &Checkpoint{
		conn:        conn,
		appName:     appName,
		tableName:   tableName,
		done:        make(chan struct{}),
		maxInterval: 1 * time.Minute,
		mu:          new(sync.Mutex),
		checkpoints: map[key]string{},
	}
	for _, opt := range opts {
		opt(ck)
	}

	// time.NewTicker panics on a non-positive duration, and that panic would be
	// raised inside the background goroutine where the caller cannot recover
	// it. Reject the value here instead.
	if ck.maxInterval <= 0 {
		_ = conn.Close() // best effort: the handle has not been used yet
		return nil, errors.New("max interval must be greater than zero")
	}

	go ck.loop()
	return ck, nil
}
// GetMaxInterval returns the configured flush interval, i.e. the period of
// the ticker that triggers the periodic save.
func (c *Checkpoint) GetMaxInterval() time.Duration {
return c.maxInterval
}
// Get looks up the stored sequence number for a shard of the given stream.
// It returns an empty string and a nil error when no row exists, which lets
// the caller decide between starting at TRIM_HORIZON and resuming with
// AFTER_SEQUENCE_NUMBER. Any other query error is returned unchanged.
func (c *Checkpoint) Get(streamName, shardID string) (string, error) {
	query := fmt.Sprintf(`SELECT sequence_number FROM %s WHERE namespace=? AND shard_id=?;`, c.tableName) //nolint: gas, it replaces only the table name
	row := c.conn.QueryRow(query, fmt.Sprintf("%s-%s", c.appName, streamName), shardID)

	var stored string
	switch err := row.Scan(&stored); err {
	case nil:
		return stored, nil
	case sql.ErrNoRows:
		// A missing row is not a failure: the shard has no checkpoint yet.
		return "", nil
	default:
		return "", err
	}
}
// Set records the sequence number of the last record processed for a shard.
// The value is only buffered in memory here; it reaches the database on the
// next save. An empty sequence number is rejected with an error.
func (c *Checkpoint) Set(streamName, shardID, sequenceNumber string) error {
	c.mu.Lock()
	defer c.mu.Unlock()

	if len(sequenceNumber) == 0 {
		return fmt.Errorf("sequence number should not be empty")
	}

	c.checkpoints[key{streamName: streamName, shardID: shardID}] = sequenceNumber
	return nil
}
// Shutdown stops the background flush loop, writes any buffered checkpoints
// to the database and closes the connection.
//
// It returns the error of the final save, if any; the error from closing the
// connection is discarded. The send on done blocks until loop receives it, and
// loop closes done as it exits, so Shutdown must be called at most once: a
// second call would send on a closed channel and panic.
func (c *Checkpoint) Shutdown() error {
// Deferred so the connection stays open for the final save below.
defer c.conn.Close()
c.done <- struct{}{}
return c.save()
}
// loop is the background flusher started by New. It calls save on every tick
// of a ticker whose period is maxInterval and returns once a value arrives on
// done, closing done and stopping the ticker on the way out. Errors from the
// periodic save are discarded.
func (c *Checkpoint) loop() {
	flushTicker := time.NewTicker(c.maxInterval)
	defer flushTicker.Stop()
	defer close(c.done)

	for {
		select {
		case <-c.done:
			return
		case <-flushTicker.C:
			c.save()
		}
	}
}
// save writes every buffered checkpoint to the table with a REPLACE INTO
// statement while holding the mutex. It stops at, and returns, the first
// database error; buffered entries are kept in the map either way.
func (c *Checkpoint) save() error {
	c.mu.Lock()
	defer c.mu.Unlock()

	//nolint: gas, it replaces only the table name
	stmt := fmt.Sprintf(`REPLACE INTO %s (namespace, shard_id, sequence_number) VALUES (?, ?, ?)`, c.tableName)

	for k, seq := range c.checkpoints {
		namespace := fmt.Sprintf("%s-%s", c.appName, k.streamName)
		_, err := c.conn.Exec(stmt, namespace, k.shardID, seq)
		if err != nil {
			return err
		}
	}
	return nil
}

View file

@ -0,0 +1,7 @@
package mysql
import "database/sql"
// SetConn replaces the database handle used by the checkpoint with conn.
// The previous handle is not closed. The tests use it to inject a sqlmock
// connection.
func (c *Checkpoint) SetConn(conn *sql.DB) {
c.conn = conn
}

View file

@ -0,0 +1,304 @@
package mysql
import (
"database/sql"
"fmt"
"testing"
"time"
"github.com/pkg/errors"
sqlmock "gopkg.in/DATA-DOG/go-sqlmock.v1"
)
// TestNew verifies that New succeeds with valid arguments and that the
// resulting checkpoint can be shut down cleanly.
func TestNew(t *testing.T) {
	appName := "streamConsumer"
	tableName := "checkpoint"
	connString := "user:password@/dbname"

	ck, err := New(appName, tableName, connString)
	if err != nil {
		t.Fatalf("expected error equals nil, but got %v", err)
	}
	// Fatalf rather than Errorf: continuing with a nil checkpoint would
	// dereference it in Shutdown and panic instead of reporting a failure.
	if ck == nil {
		t.Fatalf("expected checkpointer not equal nil, but got %v", ck)
	}

	if err := ck.Shutdown(); err != nil {
		t.Errorf("expected shutdown error equals nil, but got %v", err)
	}
}
// TestNew_AppNameEmpty checks that New refuses an empty application name and
// returns no checkpoint.
func TestNew_AppNameEmpty(t *testing.T) {
	const (
		appName    = ""
		tableName  = "checkpoint"
		connString = ""
	)

	ck, err := New(appName, tableName, connString)

	if ck != nil {
		t.Errorf("expected checkpointer equal nil, but got %v", ck)
	}
	if err == nil {
		t.Errorf("expected error equals not nil, but got %v", err)
	}
}
// TestNew_TableNameEmpty checks that New refuses an empty table name and
// returns no checkpoint.
func TestNew_TableNameEmpty(t *testing.T) {
	const (
		appName    = "streamConsumer"
		tableName  = ""
		connString = ""
	)

	ck, err := New(appName, tableName, connString)

	if ck != nil {
		t.Errorf("expected checkpointer equal nil, but got %v", ck)
	}
	if err == nil {
		t.Errorf("expected error equals not nil, but got %v", err)
	}
}
// TestNew_WithMaxIntervalOption verifies that the WithMaxInterval option
// overrides the default flush interval.
func TestNew_WithMaxIntervalOption(t *testing.T) {
	appName := "streamConsumer"
	tableName := "checkpoint"
	connString := "user:password@/dbname"
	maxInterval := time.Second

	ck, err := New(appName, tableName, connString, WithMaxInterval(maxInterval))
	if err != nil {
		t.Fatalf("expected error equals nil, but got %v", err)
	}
	// Fatalf rather than Errorf: the calls below dereference ck and would
	// panic on a nil checkpoint instead of reporting a failure.
	if ck == nil {
		t.Fatalf("expected checkpointer not equal nil, but got %v", ck)
	}

	if ck.GetMaxInterval() != time.Second {
		t.Errorf("expected max interval equals %v, but got %v", maxInterval, ck.GetMaxInterval())
	}
	ck.Shutdown()
}
// TestCheckpoint_Get checks that Get issues the expected SELECT for the
// namespace/shard pair and returns the sequence number found in the row.
func TestCheckpoint_Get(t *testing.T) {
	const (
		appName    = "streamConsumer"
		tableName  = "checkpoint"
		connString = "user:password@/dbname"
		streamName = "myStreamName"
		shardID    = "shardId-00000000"
		wantSeq    = "49578481031144599192696750682534686652010819674221576194"
	)

	db, mock, err := sqlmock.New()
	if err != nil {
		t.Fatalf("error occurred during the sqlmock creation. cause: %v", err)
	}
	ck, err := New(appName, tableName, connString, WithMaxInterval(time.Second))
	if err != nil {
		t.Fatalf("error occurred during the checkpoint creation. cause: %v", err)
	}
	ck.SetConn(db) // nolint: gotypex, the function available only in test

	// The mock matches queries as regular expressions, hence the escaped "?".
	query := fmt.Sprintf(`SELECT sequence_number FROM %s WHERE namespace=\? AND shard_id=\?;`,
		tableName)
	result := sqlmock.NewRows([]string{"sequence_number"})
	result.AddRow(wantSeq)
	mock.ExpectQuery(query).
		WithArgs(fmt.Sprintf("%s-%s", appName, streamName), shardID).
		WillReturnRows(result)

	gotSeq, err := ck.Get(streamName, shardID)

	if gotSeq != wantSeq {
		t.Errorf("expected sequence number equals %v, but got %v", wantSeq, gotSeq)
	}
	if err != nil {
		t.Errorf("expected error equals nil, but got %v", err)
	}
	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("there were unfulfilled expectations: %s", err)
	}
	ck.Shutdown()
}
// TestCheckpoint_Get_NoRows checks that a missing row is reported by Get as an
// empty sequence number with a nil error.
func TestCheckpoint_Get_NoRows(t *testing.T) {
	const (
		appName    = "streamConsumer"
		tableName  = "checkpoint"
		connString = "user:password@/dbname"
		streamName = "myStreamName"
		shardID    = "shardId-00000000"
	)

	db, mock, err := sqlmock.New()
	if err != nil {
		t.Fatalf("error occurred during the sqlmock creation. cause: %v", err)
	}
	ck, err := New(appName, tableName, connString, WithMaxInterval(time.Second))
	if err != nil {
		t.Fatalf("error occurred during the checkpoint creation. cause: %v", err)
	}
	ck.SetConn(db) // nolint: gotypex, the function available only in test

	// The mock matches queries as regular expressions, hence the escaped "?".
	query := fmt.Sprintf(`SELECT sequence_number FROM %s WHERE namespace=\? AND shard_id=\?;`,
		tableName)
	mock.ExpectQuery(query).
		WithArgs(fmt.Sprintf("%s-%s", appName, streamName), shardID).
		WillReturnError(sql.ErrNoRows)

	gotSeq, err := ck.Get(streamName, shardID)

	if gotSeq != "" {
		t.Errorf("expected sequence number equals empty, but got %v", gotSeq)
	}
	if err != nil {
		t.Errorf("expected error equals nil, but got %v", err)
	}
	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("there were unfulfilled expectations: %s", err)
	}
	ck.Shutdown()
}
// TestCheckpoint_Get_QueryError checks that a query failure other than a
// missing row is returned by Get together with an empty sequence number.
func TestCheckpoint_Get_QueryError(t *testing.T) {
	const (
		appName    = "streamConsumer"
		tableName  = "checkpoint"
		connString = "user:password@/dbname"
		streamName = "myStreamName"
		shardID    = "shardId-00000000"
	)

	db, mock, err := sqlmock.New()
	if err != nil {
		t.Fatalf("error occurred during the sqlmock creation. cause: %v", err)
	}
	ck, err := New(appName, tableName, connString, WithMaxInterval(time.Second))
	if err != nil {
		t.Fatalf("error occurred during the checkpoint creation. cause: %v", err)
	}
	ck.SetConn(db) // nolint: gotypex, the function available only in test

	// The mock matches queries as regular expressions, hence the escaped "?".
	query := fmt.Sprintf(`SELECT sequence_number FROM %s WHERE namespace=\? AND shard_id=\?;`,
		tableName)
	mock.ExpectQuery(query).
		WithArgs(fmt.Sprintf("%s-%s", appName, streamName), shardID).
		WillReturnError(errors.New("an error"))

	gotSeq, err := ck.Get(streamName, shardID)

	if gotSeq != "" {
		t.Errorf("expected sequence number equals empty, but got %v", gotSeq)
	}
	if err == nil {
		t.Errorf("expected error equals not nil, but got %v", err)
	}
	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("there were unfulfilled expectations: %s", err)
	}
	ck.Shutdown()
}
// TestCheckpoint_Set checks that Set accepts a non-empty sequence number.
func TestCheckpoint_Set(t *testing.T) {
	const (
		appName    = "streamConsumer"
		tableName  = "checkpoint"
		connString = "user:password@/dbname"
		streamName = "myStreamName"
		shardID    = "shardId-00000000"
		seq        = "49578481031144599192696750682534686652010819674221576194"
	)

	ck, err := New(appName, tableName, connString, WithMaxInterval(time.Second))
	if err != nil {
		t.Fatalf("error occurred during the checkpoint creation. cause: %v", err)
	}

	if err := ck.Set(streamName, shardID, seq); err != nil {
		t.Errorf("expected error equals nil, but got %v", err)
	}
	ck.Shutdown()
}
// TestCheckpoint_Set_SequenceNumberEmpty checks that Set rejects an empty
// sequence number with an error.
func TestCheckpoint_Set_SequenceNumberEmpty(t *testing.T) {
	const (
		appName    = "streamConsumer"
		tableName  = "checkpoint"
		connString = "user:password@/dbname"
		streamName = "myStreamName"
		shardID    = "shardId-00000000"
		emptySeq   = ""
	)

	ck, err := New(appName, tableName, connString, WithMaxInterval(time.Second))
	if err != nil {
		t.Fatalf("error occurred during the checkpoint creation. cause: %v", err)
	}

	if err := ck.Set(streamName, shardID, emptySeq); err == nil {
		t.Errorf("expected error equals not nil, but got %v", err)
	}
	ck.Shutdown()
}
// TestCheckpoint_Shutdown verifies that Shutdown flushes a buffered checkpoint
// with the expected REPLACE statement and returns no error when the write
// succeeds.
func TestCheckpoint_Shutdown(t *testing.T) {
	appName := "streamConsumer"
	tableName := "checkpoint"
	connString := "user:password@/dbname"
	streamName := "myStreamName"
	shardID := "shardId-00000000"
	expectedSequenceNumber := "49578481031144599192696750682534686652010819674221576194"
	maxInterval := time.Second

	connMock, mock, err := sqlmock.New()
	if err != nil {
		t.Fatalf("error occurred during the sqlmock creation. cause: %v", err)
	}
	ck, err := New(appName, tableName, connString, WithMaxInterval(maxInterval))
	if err != nil {
		t.Fatalf("error occurred during the checkpoint creation. cause: %v", err)
	}
	ck.SetConn(connMock) // nolint: gotypex, the function available only in test

	namespace := fmt.Sprintf("%s-%s", appName, streamName)
	// The mock matches statements as regular expressions, hence the escaping.
	expectedSQLRegexString := fmt.Sprintf(`REPLACE INTO %s \(namespace, shard_id, sequence_number\) VALUES \(\?, \?, \?\)`, tableName)
	result := sqlmock.NewResult(0, 1)
	mock.ExpectExec(expectedSQLRegexString).WithArgs(namespace, shardID, expectedSequenceNumber).WillReturnResult(result)

	err = ck.Set(streamName, shardID, expectedSequenceNumber)
	if err != nil {
		t.Fatalf("unable to set checkpoint for data initialization. cause: %v", err)
	}

	err = ck.Shutdown()
	if err != nil {
		// The write is mocked to succeed, so any error here is unexpected.
		t.Errorf("expected error equals nil, but got %v", err)
	}
	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("there were unfulfilled expectations: %s", err)
	}
}
// TestCheckpoint_Shutdown_SaveError verifies that Shutdown returns the error
// produced when flushing a buffered checkpoint fails.
func TestCheckpoint_Shutdown_SaveError(t *testing.T) {
	appName := "streamConsumer"
	tableName := "checkpoint"
	connString := "user:password@/dbname"
	streamName := "myStreamName"
	shardID := "shardId-00000000"
	expectedSequenceNumber := "49578481031144599192696750682534686652010819674221576194"
	maxInterval := time.Second

	connMock, mock, err := sqlmock.New()
	if err != nil {
		t.Fatalf("error occurred during the sqlmock creation. cause: %v", err)
	}
	ck, err := New(appName, tableName, connString, WithMaxInterval(maxInterval))
	if err != nil {
		t.Fatalf("error occurred during the checkpoint creation. cause: %v", err)
	}
	ck.SetConn(connMock) // nolint: gotypex, the function available only in test

	namespace := fmt.Sprintf("%s-%s", appName, streamName)
	// The mock matches statements as regular expressions, hence the escaping.
	expectedSQLRegexString := fmt.Sprintf(`REPLACE INTO %s \(namespace, shard_id, sequence_number\) VALUES \(\?, \?, \?\)`, tableName)
	mock.ExpectExec(expectedSQLRegexString).WithArgs(namespace, shardID, expectedSequenceNumber).WillReturnError(errors.New("an error"))

	err = ck.Set(streamName, shardID, expectedSequenceNumber)
	if err != nil {
		t.Fatalf("unable to set checkpoint for data initialization. cause: %v", err)
	}

	err = ck.Shutdown()
	if err == nil {
		// The write is mocked to fail, so Shutdown must surface an error.
		t.Errorf("expected error equals not nil, but got %v", err)
	}
	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("there were unfulfilled expectations: %s", err)
	}
}

View file

@ -0,0 +1,21 @@
# Consumer with mysql checkpoint
Read records from the Kinesis stream using mysql as checkpoint
## Environment Variables
Export the required environment vars for connecting to the Kinesis stream:
```shell
export AWS_ACCESS_KEY=
export AWS_REGION=
export AWS_SECRET_KEY=
```
## Run the consumer
go run main.go --app <appName> --stream <streamName> --table <tableName> --connection <connectionString>
Connection string should look something like
user:password@/dbname

View file

@ -0,0 +1,68 @@
package main
import (
"context"
"expvar"
"flag"
"fmt"
"log"
"os"
"os/signal"
consumer "github.com/harlow/kinesis-consumer"
checkpoint "github.com/harlow/kinesis-consumer/checkpoint/mysql"
)
// main runs an example consumer that prints every record read from a Kinesis
// stream and tracks its position with the MySQL checkpoint.
//
// The app name, stream name, table name and connection string come from the
// --app, --stream, --table and --connection flags. An interrupt signal cancels
// the scan; the checkpoint is then shut down so buffered sequence numbers are
// written before the process exits.
func main() {
	var (
		app     = flag.String("app", "", "Consumer app name")
		stream  = flag.String("stream", "", "Stream name")
		table   = flag.String("table", "", "Table name")
		connStr = flag.String("connection", "", "Connection Str")
	)
	flag.Parse()

	// mysql checkpoint
	ck, err := checkpoint.New(*app, *table, *connStr)
	if err != nil {
		log.Fatalf("checkpoint error: %v", err)
	}

	var counter = expvar.NewMap("counters")

	// consumer
	c, err := consumer.New(
		*stream,
		consumer.WithCheckpoint(ck),
		consumer.WithCounter(counter),
	)
	if err != nil {
		log.Fatalf("consumer error: %v", err)
	}

	// use cancel func to signal shutdown
	ctx, cancel := context.WithCancel(context.Background())

	// trap SIGINT, wait to trigger shutdown
	signals := make(chan os.Signal, 1)
	signal.Notify(signals, os.Interrupt)

	go func() {
		<-signals
		cancel()
	}()

	// scan stream
	scanErr := c.Scan(ctx, func(r *consumer.Record) error {
		fmt.Println(string(r.Data))
		return nil
	})

	// Flush the checkpoint even when the scan failed: log.Fatalf exits the
	// process immediately, so bailing out first would drop every sequence
	// number that is still buffered in memory.
	shutdownErr := ck.Shutdown()

	if scanErr != nil {
		if shutdownErr != nil {
			log.Printf("checkpoint shutdown error: %v", shutdownErr)
		}
		log.Fatalf("scan error: %v", scanErr)
	}
	if shutdownErr != nil {
		log.Fatalf("checkpoint shutdown error: %v", shutdownErr)
	}
}