As we work towards introducing consumer groups to the repository we need a more generic name for the persistence layer for storing checkpoints and leases for given shards. * Rename `checkpoint` to `store`
158 lines
3.7 KiB
Go
158 lines
3.7 KiB
Go
package mysql
|
|
|
|
import (
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"sync"
|
|
"time"
|
|
|
|
_ "github.com/go-sql-driver/mysql"
|
|
)
|
|
|
|
type key struct {
|
|
streamName string
|
|
shardID string
|
|
}
|
|
|
|
// Option is used to override defaults when creating a new Checkpoint
|
|
type Option func(*Checkpoint)
|
|
|
|
// WithMaxInterval sets the flush interval
|
|
func WithMaxInterval(maxInterval time.Duration) Option {
|
|
return func(c *Checkpoint) {
|
|
c.maxInterval = maxInterval
|
|
}
|
|
}
|
|
|
|
// Checkpoint stores and retrieves the last evaluated key from a DDB scan
|
|
type Checkpoint struct {
|
|
appName string
|
|
tableName string
|
|
conn *sql.DB
|
|
mu *sync.Mutex // protects the checkpoints
|
|
done chan struct{}
|
|
checkpoints map[key]string
|
|
maxInterval time.Duration
|
|
}
|
|
|
|
// New returns a checkpoint that uses Mysql for underlying storage
|
|
// Using connectionStr turn it more flexible to use specific db configs
|
|
func New(appName, tableName, connectionStr string, opts ...Option) (*Checkpoint, error) {
|
|
if appName == "" {
|
|
return nil, errors.New("application name not defined")
|
|
}
|
|
|
|
if tableName == "" {
|
|
return nil, errors.New("table name not defined")
|
|
}
|
|
|
|
conn, err := sql.Open("mysql", connectionStr)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
ck := &Checkpoint{
|
|
conn: conn,
|
|
appName: appName,
|
|
tableName: tableName,
|
|
done: make(chan struct{}),
|
|
maxInterval: 1 * time.Minute,
|
|
mu: new(sync.Mutex),
|
|
checkpoints: map[key]string{},
|
|
}
|
|
|
|
for _, opt := range opts {
|
|
opt(ck)
|
|
}
|
|
|
|
go ck.loop()
|
|
|
|
return ck, nil
|
|
}
|
|
|
|
// GetMaxInterval returns the maximum interval before the checkpoint
|
|
func (c *Checkpoint) GetMaxInterval() time.Duration {
|
|
return c.maxInterval
|
|
}
|
|
|
|
// GetCheckpoint determines if a checkpoint for a particular Shard exists.
|
|
// Typically used to determine whether we should start processing the shard with
|
|
// TRIM_HORIZON or AFTER_SEQUENCE_NUMBER (if checkpoint exists).
|
|
func (c *Checkpoint) GetCheckpoint(streamName, shardID string) (string, error) {
|
|
namespace := fmt.Sprintf("%s-%s", c.appName, streamName)
|
|
|
|
var sequenceNumber string
|
|
getCheckpointQuery := fmt.Sprintf(`SELECT sequence_number FROM %s WHERE namespace=? AND shard_id=?;`, c.tableName) //nolint: gas, it replaces only the table name
|
|
err := c.conn.QueryRow(getCheckpointQuery, namespace, shardID).Scan(&sequenceNumber)
|
|
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return "", nil
|
|
}
|
|
return "", err
|
|
}
|
|
|
|
return sequenceNumber, nil
|
|
}
|
|
|
|
// SetCheckpoint stores a checkpoint for a shard (e.g. sequence number of last record processed by application).
|
|
// Upon failover, record processing is resumed from this point.
|
|
func (c *Checkpoint) SetCheckpoint(streamName, shardID, sequenceNumber string) error {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
if sequenceNumber == "" {
|
|
return fmt.Errorf("sequence number should not be empty")
|
|
}
|
|
|
|
key := key{
|
|
streamName: streamName,
|
|
shardID: shardID,
|
|
}
|
|
|
|
c.checkpoints[key] = sequenceNumber
|
|
|
|
return nil
|
|
}
|
|
|
|
// Shutdown the checkpoint. Save any in-flight data.
|
|
func (c *Checkpoint) Shutdown() error {
|
|
defer c.conn.Close()
|
|
|
|
c.done <- struct{}{}
|
|
|
|
return c.save()
|
|
}
|
|
|
|
func (c *Checkpoint) loop() {
|
|
tick := time.NewTicker(c.maxInterval)
|
|
defer tick.Stop()
|
|
defer close(c.done)
|
|
|
|
for {
|
|
select {
|
|
case <-tick.C:
|
|
c.save()
|
|
case <-c.done:
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *Checkpoint) save() error {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
//nolint: gas, it replaces only the table name
|
|
upsertCheckpoint := fmt.Sprintf(`REPLACE INTO %s (namespace, shard_id, sequence_number) VALUES (?, ?, ?)`, c.tableName)
|
|
|
|
for key, sequenceNumber := range c.checkpoints {
|
|
if _, err := c.conn.Exec(upsertCheckpoint, fmt.Sprintf("%s-%s", c.appName, key.streamName), key.shardID, sequenceNumber); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|