Write unit tests for Postgres package (#69)
Changes: * I add postgres_databaseutils_test.go for containing all utilities for mocking the database. * I put appName as a required parameter passed to the New() since it is needed to form the namespace column name. * I add GetMaxInterval() since it might be needed by the consumer and the test. * I move the SQL string package variables directly into the function that need it so we could avoid maintenance nightmare in the future.
This commit is contained in:
parent
d3b76346f5
commit
cb35697903
5 changed files with 359 additions and 21 deletions
19
Gopkg.lock
generated
19
Gopkg.lock
generated
|
|
@ -57,6 +57,17 @@
|
|||
revision = "358ee7663966325963d4e8b2e1fbd570c5195153"
|
||||
version = "v1.38.1"
|
||||
|
||||
[[projects]]
|
||||
name = "github.com/harlow/kinesis-consumer"
|
||||
packages = [
|
||||
".",
|
||||
"checkpoint/ddb",
|
||||
"checkpoint/postgres",
|
||||
"checkpoint/redis"
|
||||
]
|
||||
revision = "049445e259a2ab9146364bf60d6f5f71270a125b"
|
||||
version = "v0.2.0"
|
||||
|
||||
[[projects]]
|
||||
name = "github.com/jmespath/go-jmespath"
|
||||
packages = ["."]
|
||||
|
|
@ -77,6 +88,12 @@
|
|||
revision = "645ef00459ed84a119197bfb8d8205042c6df63d"
|
||||
version = "v0.8.0"
|
||||
|
||||
[[projects]]
|
||||
name = "gopkg.in/DATA-DOG/go-sqlmock.v1"
|
||||
packages = ["."]
|
||||
revision = "d76b18b42f285b792bf985118980ce9eacea9d10"
|
||||
version = "v1.3.0"
|
||||
|
||||
[[projects]]
|
||||
name = "gopkg.in/redis.v5"
|
||||
packages = [
|
||||
|
|
@ -93,6 +110,6 @@
|
|||
[solve-meta]
|
||||
analyzer-name = "dep"
|
||||
analyzer-version = 1
|
||||
inputs-digest = "6b3044ce1b075f919471f2457f32450efaa36518381fd84164641860c296de5a"
|
||||
inputs-digest = "2588ee54549a76e93e2e65a289fccd8b636f85b124c5ccb0ab3d5f3529a3cbaa"
|
||||
solver-name = "gps-cdcl"
|
||||
solver-version = 1
|
||||
|
|
|
|||
|
|
@ -44,3 +44,7 @@
|
|||
[prune]
|
||||
go-tests = true
|
||||
unused-packages = true
|
||||
|
||||
[[constraint]]
|
||||
name = "gopkg.in/DATA-DOG/go-sqlmock.v1"
|
||||
version = "1.3.0"
|
||||
|
|
|
|||
|
|
@ -10,17 +10,6 @@ import (
|
|||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
var getCheckpointQuery = `SELECT sequence_number
|
||||
FROM %s
|
||||
WHERE namespace=$1 AND shard_id=$2`
|
||||
|
||||
var upsertCheckpoint = `INSERT INTO %s (namespace, shard_id, sequence_number)
|
||||
VALUES($1, $2, $3)
|
||||
ON CONFLICT (namespace, shard_id)
|
||||
DO
|
||||
UPDATE
|
||||
SET sequence_number= $3`
|
||||
|
||||
type key struct {
|
||||
streamName string
|
||||
shardID string
|
||||
|
|
@ -36,9 +25,10 @@ func WithMaxInterval(maxInterval time.Duration) Option {
|
|||
}
|
||||
}
|
||||
|
||||
// Checkpoint stores and retreives the last evaluated key from a DDB scan
|
||||
// Checkpoint stores and retrieves the last evaluated key from a DDB scan
|
||||
type Checkpoint struct {
|
||||
appName string
|
||||
tableName string
|
||||
conn *sql.DB
|
||||
mu *sync.Mutex // protects the checkpoints
|
||||
done chan struct{}
|
||||
|
|
@ -49,9 +39,12 @@ type Checkpoint struct {
|
|||
// New returns a checkpoint that uses PostgresDB for underlying storage
|
||||
// Using connectionStr turn it more flexible to use specific db configs
|
||||
func New(appName, tableName, connectionStr string, opts ...Option) (*Checkpoint, error) {
|
||||
if appName == "" {
|
||||
return nil, errors.New("application name not defined")
|
||||
}
|
||||
|
||||
if tableName == "" {
|
||||
return nil, errors.New("Table name not defined")
|
||||
return nil, errors.New("table name not defined")
|
||||
}
|
||||
|
||||
conn, err := sql.Open("postgres", connectionStr)
|
||||
|
|
@ -60,14 +53,12 @@ func New(appName, tableName, connectionStr string, opts ...Option) (*Checkpoint,
|
|||
return nil, err
|
||||
}
|
||||
|
||||
getCheckpointQuery = fmt.Sprintf(getCheckpointQuery, tableName)
|
||||
upsertCheckpoint = fmt.Sprintf(upsertCheckpoint, tableName)
|
||||
|
||||
ck := &Checkpoint{
|
||||
conn: conn,
|
||||
appName: appName,
|
||||
tableName: tableName,
|
||||
done: make(chan struct{}),
|
||||
maxInterval: time.Duration(1 * time.Minute),
|
||||
maxInterval: 1 * time.Minute,
|
||||
mu: new(sync.Mutex),
|
||||
checkpoints: map[key]string{},
|
||||
}
|
||||
|
|
@ -81,6 +72,11 @@ func New(appName, tableName, connectionStr string, opts ...Option) (*Checkpoint,
|
|||
return ck, nil
|
||||
}
|
||||
|
||||
// GetMaxInterval returns the maximum interval before the checkpoint
|
||||
func (c *Checkpoint) GetMaxInterval() time.Duration {
|
||||
return c.maxInterval
|
||||
}
|
||||
|
||||
// Get determines if a checkpoint for a particular Shard exists.
|
||||
// Typically used to determine whether we should start processing the shard with
|
||||
// TRIM_HORIZON or AFTER_SEQUENCE_NUMBER (if checkpoint exists).
|
||||
|
|
@ -88,14 +84,13 @@ func (c *Checkpoint) Get(streamName, shardID string) (string, error) {
|
|||
namespace := fmt.Sprintf("%s-%s", c.appName, streamName)
|
||||
|
||||
var sequenceNumber string
|
||||
|
||||
getCheckpointQuery := fmt.Sprintf(`SELECT sequence_number FROM %s WHERE namespace=$1 AND shard_id=$2;`, c.tableName) //nolint: gas, it replaces only the table name
|
||||
err := c.conn.QueryRow(getCheckpointQuery, namespace, shardID).Scan(&sequenceNumber)
|
||||
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
return "", err
|
||||
}
|
||||
|
||||
|
|
@ -150,8 +145,15 @@ func (c *Checkpoint) save() error {
|
|||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
for key, sequenceNumber := range c.checkpoints {
|
||||
//nolint: gas, it replaces only the table name
|
||||
upsertCheckpoint := fmt.Sprintf(`INSERT INTO %s (namespace, shard_id, sequence_number)
|
||||
VALUES($1, $2, $3)
|
||||
ON CONFLICT (namespace, shard_id)
|
||||
DO
|
||||
UPDATE
|
||||
SET sequence_number= $3;`, c.tableName)
|
||||
|
||||
for key, sequenceNumber := range c.checkpoints {
|
||||
if _, err := c.conn.Exec(upsertCheckpoint, fmt.Sprintf("%s-%s", c.appName, key.streamName), key.shardID, sequenceNumber); err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
7
checkpoint/postgres/postgres_databaseutils_test.go
Normal file
7
checkpoint/postgres/postgres_databaseutils_test.go
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
package postgres
|
||||
|
||||
import "database/sql"
|
||||
|
||||
func (c *Checkpoint) SetConn(conn *sql.DB) {
|
||||
c.conn = conn
|
||||
}
|
||||
308
checkpoint/postgres/postgres_test.go
Normal file
308
checkpoint/postgres/postgres_test.go
Normal file
|
|
@ -0,0 +1,308 @@
|
|||
package postgres_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"time"
|
||||
|
||||
"fmt"
|
||||
|
||||
"database/sql"
|
||||
|
||||
"github.com/harlow/kinesis-consumer/checkpoint/postgres"
|
||||
"github.com/pkg/errors"
|
||||
"gopkg.in/DATA-DOG/go-sqlmock.v1"
|
||||
)
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
appName := "streamConsumer"
|
||||
tableName := "checkpoint"
|
||||
connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;"
|
||||
ck, err := postgres.New(appName, tableName, connString)
|
||||
|
||||
if ck == nil {
|
||||
t.Errorf("expected checkpointer not equal nil, but got %v", ck)
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("expected error equals nil, but got %v", err)
|
||||
}
|
||||
ck.Shutdown()
|
||||
}
|
||||
|
||||
func TestNew_AppNameEmpty(t *testing.T) {
|
||||
appName := ""
|
||||
tableName := "checkpoint"
|
||||
connString := ""
|
||||
ck, err := postgres.New(appName, tableName, connString)
|
||||
|
||||
if ck != nil {
|
||||
t.Errorf("expected checkpointer equal nil, but got %v", ck)
|
||||
}
|
||||
if err == nil {
|
||||
t.Errorf("expected error equals not nil, but got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew_TableNameEmpty(t *testing.T) {
|
||||
appName := "streamConsumer"
|
||||
tableName := ""
|
||||
connString := ""
|
||||
ck, err := postgres.New(appName, tableName, connString)
|
||||
|
||||
if ck != nil {
|
||||
t.Errorf("expected checkpointer equal nil, but got %v", ck)
|
||||
}
|
||||
if err == nil {
|
||||
t.Errorf("expected error equals not nil, but got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew_WithMaxIntervalOption(t *testing.T) {
|
||||
appName := "streamConsumer"
|
||||
tableName := "checkpoint"
|
||||
connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;"
|
||||
maxInterval := time.Second
|
||||
ck, err := postgres.New(appName, tableName, connString, postgres.WithMaxInterval(maxInterval))
|
||||
|
||||
if ck == nil {
|
||||
t.Errorf("expected checkpointer not equal nil, but got %v", ck)
|
||||
}
|
||||
if ck.GetMaxInterval() != time.Second {
|
||||
t.Errorf("expected max interval equals %v, but got %v", maxInterval, ck.GetMaxInterval())
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("expected error equals nil, but got %v", err)
|
||||
}
|
||||
ck.Shutdown()
|
||||
}
|
||||
|
||||
func TestCheckpoint_Get(t *testing.T) {
|
||||
appName := "streamConsumer"
|
||||
tableName := "checkpoint"
|
||||
connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;"
|
||||
streamName := "myStreamName"
|
||||
shardID := "shardId-00000000"
|
||||
expectedSequenceNumber := "49578481031144599192696750682534686652010819674221576194"
|
||||
maxInterval := time.Second
|
||||
connMock, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("error occurred during the sqlmock creation. cause: %v", err)
|
||||
}
|
||||
ck, err := postgres.New(appName, tableName, connString, postgres.WithMaxInterval(maxInterval))
|
||||
if err != nil {
|
||||
t.Fatalf("error occurred during the checkpoint creation. cause: %v", err)
|
||||
}
|
||||
ck.SetConn(connMock) // nolint: gotypex, the function available only in test
|
||||
|
||||
rows := []string{"sequence_number"}
|
||||
namespace := fmt.Sprintf("%s-%s", appName, streamName)
|
||||
expectedRows := sqlmock.NewRows(rows)
|
||||
expectedRows.AddRow(expectedSequenceNumber)
|
||||
expectedSQLRegexString := fmt.Sprintf(`SELECT sequence_number FROM %s WHERE namespace=\$1 AND shard_id=\$2;`,
|
||||
tableName)
|
||||
mock.ExpectQuery(expectedSQLRegexString).WithArgs(namespace, shardID).WillReturnRows(expectedRows)
|
||||
|
||||
gotSequenceNumber, err := ck.Get(streamName, shardID)
|
||||
|
||||
if gotSequenceNumber != expectedSequenceNumber {
|
||||
t.Errorf("expected sequence number equals %v, but got %v", expectedSequenceNumber, gotSequenceNumber)
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("expected error equals nil, but got %v", err)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("there were unfulfilled expectations: %s", err)
|
||||
}
|
||||
ck.Shutdown()
|
||||
}
|
||||
|
||||
func TestCheckpoint_Get_NoRows(t *testing.T) {
|
||||
appName := "streamConsumer"
|
||||
tableName := "checkpoint"
|
||||
connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;"
|
||||
streamName := "myStreamName"
|
||||
shardID := "shardId-00000000"
|
||||
maxInterval := time.Second
|
||||
connMock, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("error occurred during the sqlmock creation. cause: %v", err)
|
||||
}
|
||||
ck, err := postgres.New(appName, tableName, connString, postgres.WithMaxInterval(maxInterval))
|
||||
if err != nil {
|
||||
t.Fatalf("error occurred during the checkpoint creation. cause: %v", err)
|
||||
}
|
||||
ck.SetConn(connMock) // nolint: gotypex, the function available only in test
|
||||
|
||||
namespace := fmt.Sprintf("%s-%s", appName, streamName)
|
||||
expectedSQLRegexString := fmt.Sprintf(`SELECT sequence_number FROM %s WHERE namespace=\$1 AND shard_id=\$2;`,
|
||||
tableName)
|
||||
mock.ExpectQuery(expectedSQLRegexString).WithArgs(namespace, shardID).WillReturnError(sql.ErrNoRows)
|
||||
|
||||
gotSequenceNumber, err := ck.Get(streamName, shardID)
|
||||
|
||||
if gotSequenceNumber != "" {
|
||||
t.Errorf("expected sequence number equals empty, but got %v", gotSequenceNumber)
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("expected error equals nil, but got %v", err)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("there were unfulfilled expectations: %s", err)
|
||||
}
|
||||
ck.Shutdown()
|
||||
}
|
||||
|
||||
func TestCheckpoint_Get_QueryError(t *testing.T) {
|
||||
appName := "streamConsumer"
|
||||
tableName := "checkpoint"
|
||||
connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;"
|
||||
streamName := "myStreamName"
|
||||
shardID := "shardId-00000000"
|
||||
maxInterval := time.Second
|
||||
connMock, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("error occurred during the sqlmock creation. cause: %v", err)
|
||||
}
|
||||
ck, err := postgres.New(appName, tableName, connString, postgres.WithMaxInterval(maxInterval))
|
||||
if err != nil {
|
||||
t.Fatalf("error occurred during the checkpoint creation. cause: %v", err)
|
||||
}
|
||||
ck.SetConn(connMock) // nolint: gotypex, the function available only in test
|
||||
|
||||
namespace := fmt.Sprintf("%s-%s", appName, streamName)
|
||||
expectedSQLRegexString := fmt.Sprintf(`SELECT sequence_number FROM %s WHERE namespace=\$1 AND shard_id=\$2;`,
|
||||
tableName)
|
||||
mock.ExpectQuery(expectedSQLRegexString).WithArgs(namespace, shardID).WillReturnError(errors.New("an error"))
|
||||
|
||||
gotSequenceNumber, err := ck.Get(streamName, shardID)
|
||||
|
||||
if gotSequenceNumber != "" {
|
||||
t.Errorf("expected sequence number equals empty, but got %v", gotSequenceNumber)
|
||||
}
|
||||
if err == nil {
|
||||
t.Errorf("expected error equals not nil, but got %v", err)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("there were unfulfilled expectations: %s", err)
|
||||
}
|
||||
ck.Shutdown()
|
||||
}
|
||||
|
||||
func TestCheckpoint_Set(t *testing.T) {
|
||||
appName := "streamConsumer"
|
||||
tableName := "checkpoint"
|
||||
connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;"
|
||||
streamName := "myStreamName"
|
||||
shardID := "shardId-00000000"
|
||||
expectedSequenceNumber := "49578481031144599192696750682534686652010819674221576194"
|
||||
maxInterval := time.Second
|
||||
ck, err := postgres.New(appName, tableName, connString, postgres.WithMaxInterval(maxInterval))
|
||||
if err != nil {
|
||||
t.Fatalf("error occurred during the checkpoint creation. cause: %v", err)
|
||||
}
|
||||
|
||||
err = ck.Set(streamName, shardID, expectedSequenceNumber)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("expected error equals nil, but got %v", err)
|
||||
}
|
||||
ck.Shutdown()
|
||||
}
|
||||
|
||||
func TestCheckpoint_Set_SequenceNumberEmpty(t *testing.T) {
|
||||
appName := "streamConsumer"
|
||||
tableName := "checkpoint"
|
||||
connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;"
|
||||
streamName := "myStreamName"
|
||||
shardID := "shardId-00000000"
|
||||
expectedSequenceNumber := ""
|
||||
maxInterval := time.Second
|
||||
ck, err := postgres.New(appName, tableName, connString, postgres.WithMaxInterval(maxInterval))
|
||||
if err != nil {
|
||||
t.Fatalf("error occurred during the checkpoint creation. cause: %v", err)
|
||||
}
|
||||
|
||||
err = ck.Set(streamName, shardID, expectedSequenceNumber)
|
||||
|
||||
if err == nil {
|
||||
t.Errorf("expected error equals not nil, but got %v", err)
|
||||
}
|
||||
ck.Shutdown()
|
||||
}
|
||||
|
||||
func TestCheckpoint_Shutdown(t *testing.T) {
|
||||
appName := "streamConsumer"
|
||||
tableName := "checkpoint"
|
||||
connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;"
|
||||
streamName := "myStreamName"
|
||||
shardID := "shardId-00000000"
|
||||
expectedSequenceNumber := "49578481031144599192696750682534686652010819674221576194"
|
||||
maxInterval := time.Second
|
||||
connMock, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("error occurred during the sqlmock creation. cause: %v", err)
|
||||
}
|
||||
ck, err := postgres.New(appName, tableName, connString, postgres.WithMaxInterval(maxInterval))
|
||||
if err != nil {
|
||||
t.Fatalf("error occurred during the checkpoint creation. cause: %v", err)
|
||||
}
|
||||
ck.SetConn(connMock) // nolint: gotypex, the function available only in test
|
||||
|
||||
namespace := fmt.Sprintf("%s-%s", appName, streamName)
|
||||
expectedSQLRegexString := fmt.Sprintf(`INSERT INTO %s \(namespace, shard_id, sequence_number\) VALUES\(\$1, \$2, \$3\) ON CONFLICT \(namespace, shard_id\) DO UPDATE SET sequence_number= \$3;`, tableName)
|
||||
result := sqlmock.NewResult(0, 1)
|
||||
mock.ExpectExec(expectedSQLRegexString).WithArgs(namespace, shardID, expectedSequenceNumber).WillReturnResult(result)
|
||||
|
||||
err = ck.Set(streamName, shardID, expectedSequenceNumber)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unable to set checkpoint for data initialization. cause: %v", err)
|
||||
}
|
||||
|
||||
err = ck.Shutdown()
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("expected error equals not nil, but got %v", err)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("there were unfulfilled expectations: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckpoint_Shutdown_SaveError(t *testing.T) {
|
||||
appName := "streamConsumer"
|
||||
tableName := "checkpoint"
|
||||
connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;"
|
||||
streamName := "myStreamName"
|
||||
shardID := "shardId-00000000"
|
||||
expectedSequenceNumber := "49578481031144599192696750682534686652010819674221576194"
|
||||
maxInterval := time.Second
|
||||
connMock, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("error occurred during the sqlmock creation. cause: %v", err)
|
||||
}
|
||||
ck, err := postgres.New(appName, tableName, connString, postgres.WithMaxInterval(maxInterval))
|
||||
if err != nil {
|
||||
t.Fatalf("error occurred during the checkpoint creation. cause: %v", err)
|
||||
}
|
||||
ck.SetConn(connMock) // nolint: gotypex, the function available only in test
|
||||
|
||||
namespace := fmt.Sprintf("%s-%s", appName, streamName)
|
||||
expectedSQLRegexString := fmt.Sprintf(`INSERT INTO %s \(namespace, shard_id, sequence_number\) VALUES\(\$1, \$2, \$3\) ON CONFLICT \(namespace, shard_id\) DO UPDATE SET sequence_number= \$3;`, tableName)
|
||||
mock.ExpectExec(expectedSQLRegexString).WithArgs(namespace, shardID, expectedSequenceNumber).WillReturnError(errors.New("an error"))
|
||||
|
||||
err = ck.Set(streamName, shardID, expectedSequenceNumber)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unable to set checkpoint for data initialization. cause: %v", err)
|
||||
}
|
||||
|
||||
err = ck.Shutdown()
|
||||
|
||||
if err == nil {
|
||||
t.Errorf("expected error equals nil, but got %v", err)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("there were unfulfilled expectations: %s", err)
|
||||
}
|
||||
}
|
||||
Loading…
Reference in a new issue