2018-06-18 02:27:10 +00:00
package postgres
import (
2018-11-07 23:45:13 +00:00
"context"
2018-06-18 02:27:10 +00:00
"database/sql"
"errors"
"fmt"
"sync"
"time"
// Imported for its side effect only: it provides the "postgres" driver that
// this package passes to sql.Open.
_ "github.com/lib/pq"
)
// key identifies one shard of one stream. It is the map key under which
// Checkpoint buffers the latest sequence number for that shard.
type key struct {
// streamName is the stream the shard belongs to.
streamName string
// shardID is the identifier of the shard.
shardID string
}
// Option is a function that mutates a Checkpoint. New applies each supplied
// Option after setting the defaults, so an Option can override them.
type Option func ( * Checkpoint )
2018-07-13 14:26:21 +00:00
// WithMaxInterval returns an Option that replaces the Checkpoint's flush
// interval, i.e. the period of the ticker that drives the background save.
func WithMaxInterval(maxInterval time.Duration) Option {
	apply := func(ck *Checkpoint) {
		ck.maxInterval = maxInterval
	}
	return apply
}
2018-10-14 17:23:37 +00:00
// Checkpoint stores and retrieves the last evaluated key from a DDB scan
2018-06-18 02:27:10 +00:00
type Checkpoint struct {
appName string
2018-10-14 17:23:37 +00:00
tableName string
2018-06-18 02:27:10 +00:00
conn * sql . DB
mu * sync . Mutex // protects the checkpoints
done chan struct { }
checkpoints map [ key ] string
maxInterval time . Duration
}
// New returns a checkpoint that uses PostgresDB for underlying storage
// Using connectionStr turn it more flexible to use specific db configs
func New ( appName , tableName , connectionStr string , opts ... Option ) ( * Checkpoint , error ) {
2018-10-14 17:23:37 +00:00
if appName == "" {
return nil , errors . New ( "application name not defined" )
}
2018-06-18 02:27:10 +00:00
if tableName == "" {
2018-10-14 17:23:37 +00:00
return nil , errors . New ( "table name not defined" )
2018-06-18 02:27:10 +00:00
}
conn , err := sql . Open ( "postgres" , connectionStr )
if err != nil {
return nil , err
}
ck := & Checkpoint {
conn : conn ,
appName : appName ,
2018-10-14 17:23:37 +00:00
tableName : tableName ,
2018-06-18 02:27:10 +00:00
done : make ( chan struct { } ) ,
2018-10-14 17:23:37 +00:00
maxInterval : 1 * time . Minute ,
2018-06-18 02:27:10 +00:00
mu : new ( sync . Mutex ) ,
checkpoints : map [ key ] string { } ,
}
for _ , opt := range opts {
opt ( ck )
}
go ck . loop ( )
return ck , nil
}
2018-10-14 17:23:37 +00:00
// GetMaxInterval reports the configured flush interval, i.e. the longest time
// a buffered checkpoint waits before the background loop writes it out.
func (c *Checkpoint) GetMaxInterval() time.Duration {
	interval := c.maxInterval
	return interval
}
2018-06-18 02:27:10 +00:00
// Get determines if a checkpoint for a particular Shard exists.
// Typically used to determine whether we should start processing the shard with
// TRIM_HORIZON or AFTER_SEQUENCE_NUMBER (if checkpoint exists).
2018-11-07 23:45:13 +00:00
func ( c * Checkpoint ) Get ( ctx context . Context , streamName , shardID string ) ( string , error ) {
2018-06-18 02:27:10 +00:00
namespace := fmt . Sprintf ( "%s-%s" , c . appName , streamName )
var sequenceNumber string
2018-10-14 17:23:37 +00:00
getCheckpointQuery := fmt . Sprintf ( ` SELECT sequence_number FROM %s WHERE namespace=$1 AND shard_id=$2; ` , c . tableName ) //nolint: gas, it replaces only the table name
2018-06-18 02:27:10 +00:00
err := c . conn . QueryRow ( getCheckpointQuery , namespace , shardID ) . Scan ( & sequenceNumber )
if err != nil {
if err == sql . ErrNoRows {
return "" , nil
}
return "" , err
}
return sequenceNumber , nil
}
// Set stores a checkpoint for a shard (e.g. sequence number of last record processed by application).
// Upon failover, record processing is resumed from this point.
2018-11-07 23:45:13 +00:00
func ( c * Checkpoint ) Set ( ctx context . Context , streamName , shardID , sequenceNumber string ) error {
2018-06-18 02:27:10 +00:00
c . mu . Lock ( )
defer c . mu . Unlock ( )
if sequenceNumber == "" {
return fmt . Errorf ( "sequence number should not be empty" )
}
key := key {
streamName : streamName ,
shardID : shardID ,
}
c . checkpoints [ key ] = sequenceNumber
return nil
}
// Shutdown stops the background flush loop, writes the buffered checkpoints to
// the database one last time and closes the database handle.
//
// It returns the error of the final flush if there is one, otherwise the error
// of closing the handle. Shutdown must be called at most once: the loop closes
// the done channel when it exits, so a second call would send on a closed
// channel and panic.
func (c *Checkpoint) Shutdown() (err error) {
	// Close on every exit path, and report the close error unless the flush
	// already failed (the flush error is the more useful one to the caller).
	defer func() {
		if closeErr := c.conn.Close(); err == nil {
			err = closeErr
		}
	}()

	// Unbuffered send: blocks until the loop has received the stop signal.
	c.done <- struct{}{}
	return c.save()
}
// loop runs in its own goroutine and flushes the buffered checkpoints every
// maxInterval until a value arrives on the done channel. On exit it closes the
// done channel and stops the ticker.
func (c *Checkpoint) loop() {
	ticker := time.NewTicker(c.maxInterval)
	defer ticker.Stop()
	defer close(c.done)

	for {
		select {
		case <-c.done:
			return
		case <-ticker.C:
			// Periodic flushes are best effort: a failed write is retried on
			// the next tick because the entries stay in the buffer.
			_ = c.save()
		}
	}
}
func ( c * Checkpoint ) save ( ) error {
c . mu . Lock ( )
defer c . mu . Unlock ( )
2018-10-14 17:23:37 +00:00
//nolint: gas, it replaces only the table name
upsertCheckpoint := fmt . Sprintf ( ` INSERT INTO % s ( namespace , shard_id , sequence_number )
VALUES ( $ 1 , $ 2 , $ 3 )
ON CONFLICT ( namespace , shard_id )
DO
UPDATE
SET sequence_number = $ 3 ; ` , c . tableName )
2018-06-18 02:27:10 +00:00
2018-10-14 17:23:37 +00:00
for key , sequenceNumber := range c . checkpoints {
2018-06-18 02:27:10 +00:00
if _ , err := c . conn . Exec ( upsertCheckpoint , fmt . Sprintf ( "%s-%s" , c . appName , key . streamName ) , key . shardID , sequenceNumber ) ; err != nil {
return err
}
}
return nil
}