diff --git a/Gopkg.lock b/Gopkg.lock index 5e10731..76fcb47 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -57,6 +57,17 @@ revision = "358ee7663966325963d4e8b2e1fbd570c5195153" version = "v1.38.1" +[[projects]] + name = "github.com/harlow/kinesis-consumer" + packages = [ + ".", + "checkpoint/ddb", + "checkpoint/postgres", + "checkpoint/redis" + ] + revision = "049445e259a2ab9146364bf60d6f5f71270a125b" + version = "v0.2.0" + [[projects]] name = "github.com/jmespath/go-jmespath" packages = ["."] @@ -77,6 +88,12 @@ revision = "645ef00459ed84a119197bfb8d8205042c6df63d" version = "v0.8.0" +[[projects]] + name = "gopkg.in/DATA-DOG/go-sqlmock.v1" + packages = ["."] + revision = "d76b18b42f285b792bf985118980ce9eacea9d10" + version = "v1.3.0" + [[projects]] name = "gopkg.in/redis.v5" packages = [ @@ -93,6 +110,6 @@ [solve-meta] analyzer-name = "dep" analyzer-version = 1 - inputs-digest = "6b3044ce1b075f919471f2457f32450efaa36518381fd84164641860c296de5a" + inputs-digest = "2588ee54549a76e93e2e65a289fccd8b636f85b124c5ccb0ab3d5f3529a3cbaa" solver-name = "gps-cdcl" solver-version = 1 diff --git a/Gopkg.toml b/Gopkg.toml index db2218d..07fa972 100644 --- a/Gopkg.toml +++ b/Gopkg.toml @@ -44,3 +44,7 @@ [prune] go-tests = true unused-packages = true + +[[constraint]] + name = "gopkg.in/DATA-DOG/go-sqlmock.v1" + version = "1.3.0" diff --git a/checkpoint/postgres/postgres.go b/checkpoint/postgres/postgres.go index a26c58f..b5a5bda 100644 --- a/checkpoint/postgres/postgres.go +++ b/checkpoint/postgres/postgres.go @@ -10,17 +10,6 @@ import ( _ "github.com/lib/pq" ) -var getCheckpointQuery = `SELECT sequence_number - FROM %s - WHERE namespace=$1 AND shard_id=$2` - -var upsertCheckpoint = `INSERT INTO %s (namespace, shard_id, sequence_number) - VALUES($1, $2, $3) - ON CONFLICT (namespace, shard_id) - DO - UPDATE - SET sequence_number= $3` - type key struct { streamName string shardID string @@ -36,9 +25,10 @@ func WithMaxInterval(maxInterval time.Duration) Option { } } -// Checkpoint stores and retreives the last evaluated key from a DDB scan +// Checkpoint stores and retrieves the last evaluated key from a DDB scan type Checkpoint struct { appName string + tableName string conn *sql.DB mu *sync.Mutex // protects the checkpoints done chan struct{} @@ -49,9 +39,12 @@ type Checkpoint struct { // New returns a checkpoint that uses PostgresDB for underlying storage // Using connectionStr turn it more flexible to use specific db configs func New(appName, tableName, connectionStr string, opts ...Option) (*Checkpoint, error) { + if appName == "" { + return nil, errors.New("application name not defined") + } if tableName == "" { - return nil, errors.New("Table name not defined") + return nil, errors.New("table name not defined") } conn, err := sql.Open("postgres", connectionStr) @@ -60,14 +53,12 @@ func New(appName, tableName, connectionStr string, opts ...Option) (*Checkpoint, return nil, err } - getCheckpointQuery = fmt.Sprintf(getCheckpointQuery, tableName) - upsertCheckpoint = fmt.Sprintf(upsertCheckpoint, tableName) - ck := &Checkpoint{ conn: conn, appName: appName, + tableName: tableName, done: make(chan struct{}), - maxInterval: time.Duration(1 * time.Minute), + maxInterval: 1 * time.Minute, mu: new(sync.Mutex), checkpoints: map[key]string{}, } @@ -81,6 +72,11 @@ func New(appName, tableName, connectionStr string, opts ...Option) (*Checkpoint, return ck, nil } +// GetMaxInterval returns the maximum interval before the checkpoint +func (c *Checkpoint) GetMaxInterval() time.Duration { + return c.maxInterval +} + // Get determines if a checkpoint for a particular Shard exists. // Typically used to determine whether we should start processing the shard with // TRIM_HORIZON or AFTER_SEQUENCE_NUMBER (if checkpoint exists). @@ -88,14 +84,13 @@ func (c *Checkpoint) Get(streamName, shardID string) (string, error) { namespace := fmt.Sprintf("%s-%s", c.appName, streamName) var sequenceNumber string - + getCheckpointQuery := fmt.Sprintf(`SELECT sequence_number FROM %s WHERE namespace=$1 AND shard_id=$2;`, c.tableName) //nolint: gas, it replaces only the table name err := c.conn.QueryRow(getCheckpointQuery, namespace, shardID).Scan(&sequenceNumber) if err != nil { if err == sql.ErrNoRows { return "", nil } - return "", err } @@ -150,8 +145,15 @@ func (c *Checkpoint) save() error { c.mu.Lock() defer c.mu.Unlock() - for key, sequenceNumber := range c.checkpoints { + //nolint: gas, it replaces only the table name + upsertCheckpoint := fmt.Sprintf(`INSERT INTO %s (namespace, shard_id, sequence_number) + VALUES($1, $2, $3) + ON CONFLICT (namespace, shard_id) + DO + UPDATE + SET sequence_number= $3;`, c.tableName) + for key, sequenceNumber := range c.checkpoints { if _, err := c.conn.Exec(upsertCheckpoint, fmt.Sprintf("%s-%s", c.appName, key.streamName), key.shardID, sequenceNumber); err != nil { return err } diff --git a/checkpoint/postgres/postgres_databaseutils_test.go b/checkpoint/postgres/postgres_databaseutils_test.go new file mode 100644 index 0000000..d0a86af --- /dev/null +++ b/checkpoint/postgres/postgres_databaseutils_test.go @@ -0,0 +1,7 @@ +package postgres + +import "database/sql" + +func (c *Checkpoint) SetConn(conn *sql.DB) { + c.conn = conn +} diff --git a/checkpoint/postgres/postgres_test.go b/checkpoint/postgres/postgres_test.go new file mode 100644 index 0000000..135fd0d --- /dev/null +++ b/checkpoint/postgres/postgres_test.go @@ -0,0 +1,308 @@ +package postgres_test + +import ( + "testing" + + "time" + + "fmt" + + "database/sql" + + "github.com/harlow/kinesis-consumer/checkpoint/postgres" + "github.com/pkg/errors" + "gopkg.in/DATA-DOG/go-sqlmock.v1" +) + +func TestNew(t *testing.T) { + appName := "streamConsumer" + tableName := "checkpoint" + connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;" + ck, err := postgres.New(appName, tableName, connString) + + if ck == nil { + t.Errorf("expected checkpointer not equal nil, but got %v", ck) + } + if err != nil { + t.Errorf("expected error equals nil, but got %v", err) + } + ck.Shutdown() +} + +func TestNew_AppNameEmpty(t *testing.T) { + appName := "" + tableName := "checkpoint" + connString := "" + ck, err := postgres.New(appName, tableName, connString) + + if ck != nil { + t.Errorf("expected checkpointer equal nil, but got %v", ck) + } + if err == nil { + t.Errorf("expected error equals not nil, but got %v", err) + } +} + +func TestNew_TableNameEmpty(t *testing.T) { + appName := "streamConsumer" + tableName := "" + connString := "" + ck, err := postgres.New(appName, tableName, connString) + + if ck != nil { + t.Errorf("expected checkpointer equal nil, but got %v", ck) + } + if err == nil { + t.Errorf("expected error equals not nil, but got %v", err) + } +} + +func TestNew_WithMaxIntervalOption(t *testing.T) { + appName := "streamConsumer" + tableName := "checkpoint" + connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;" + maxInterval := time.Second + ck, err := postgres.New(appName, tableName, connString, postgres.WithMaxInterval(maxInterval)) + + if ck == nil { + t.Errorf("expected checkpointer not equal nil, but got %v", ck) + } + if ck.GetMaxInterval() != time.Second { + t.Errorf("expected max interval equals %v, but got %v", maxInterval, ck.GetMaxInterval()) + } + if err != nil { + t.Errorf("expected error equals nil, but got %v", err) + } + ck.Shutdown() +} + +func TestCheckpoint_Get(t *testing.T) { + appName := "streamConsumer" + tableName := "checkpoint" + connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;" + streamName := "myStreamName" + shardID := "shardId-00000000" + expectedSequenceNumber := "49578481031144599192696750682534686652010819674221576194" + maxInterval := time.Second + connMock, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("error occurred during the sqlmock creation. cause: %v", err) + } + ck, err := postgres.New(appName, tableName, connString, postgres.WithMaxInterval(maxInterval)) + if err != nil { + t.Fatalf("error occurred during the checkpoint creation. cause: %v", err) + } + ck.SetConn(connMock) // nolint: gotypex, the function available only in test + + rows := []string{"sequence_number"} + namespace := fmt.Sprintf("%s-%s", appName, streamName) + expectedRows := sqlmock.NewRows(rows) + expectedRows.AddRow(expectedSequenceNumber) + expectedSQLRegexString := fmt.Sprintf(`SELECT sequence_number FROM %s WHERE namespace=\$1 AND shard_id=\$2;`, + tableName) + mock.ExpectQuery(expectedSQLRegexString).WithArgs(namespace, shardID).WillReturnRows(expectedRows) + + gotSequenceNumber, err := ck.Get(streamName, shardID) + + if gotSequenceNumber != expectedSequenceNumber { + t.Errorf("expected sequence number equals %v, but got %v", expectedSequenceNumber, gotSequenceNumber) + } + if err != nil { + t.Errorf("expected error equals nil, but got %v", err) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } + ck.Shutdown() +} + +func TestCheckpoint_Get_NoRows(t *testing.T) { + appName := "streamConsumer" + tableName := "checkpoint" + connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;" + streamName := "myStreamName" + shardID := "shardId-00000000" + maxInterval := time.Second + connMock, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("error occurred during the sqlmock creation. cause: %v", err) + } + ck, err := postgres.New(appName, tableName, connString, postgres.WithMaxInterval(maxInterval)) + if err != nil { + t.Fatalf("error occurred during the checkpoint creation. cause: %v", err) + } + ck.SetConn(connMock) // nolint: gotypex, the function available only in test + + namespace := fmt.Sprintf("%s-%s", appName, streamName) + expectedSQLRegexString := fmt.Sprintf(`SELECT sequence_number FROM %s WHERE namespace=\$1 AND shard_id=\$2;`, + tableName) + mock.ExpectQuery(expectedSQLRegexString).WithArgs(namespace, shardID).WillReturnError(sql.ErrNoRows) + + gotSequenceNumber, err := ck.Get(streamName, shardID) + + if gotSequenceNumber != "" { + t.Errorf("expected sequence number equals empty, but got %v", gotSequenceNumber) + } + if err != nil { + t.Errorf("expected error equals nil, but got %v", err) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } + ck.Shutdown() +} + +func TestCheckpoint_Get_QueryError(t *testing.T) { + appName := "streamConsumer" + tableName := "checkpoint" + connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;" + streamName := "myStreamName" + shardID := "shardId-00000000" + maxInterval := time.Second + connMock, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("error occurred during the sqlmock creation. cause: %v", err) + } + ck, err := postgres.New(appName, tableName, connString, postgres.WithMaxInterval(maxInterval)) + if err != nil { + t.Fatalf("error occurred during the checkpoint creation. cause: %v", err) + } + ck.SetConn(connMock) // nolint: gotypex, the function available only in test + + namespace := fmt.Sprintf("%s-%s", appName, streamName) + expectedSQLRegexString := fmt.Sprintf(`SELECT sequence_number FROM %s WHERE namespace=\$1 AND shard_id=\$2;`, + tableName) + mock.ExpectQuery(expectedSQLRegexString).WithArgs(namespace, shardID).WillReturnError(errors.New("an error")) + + gotSequenceNumber, err := ck.Get(streamName, shardID) + + if gotSequenceNumber != "" { + t.Errorf("expected sequence number equals empty, but got %v", gotSequenceNumber) + } + if err == nil { + t.Errorf("expected error equals not nil, but got %v", err) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } + ck.Shutdown() +} + +func TestCheckpoint_Set(t *testing.T) { + appName := "streamConsumer" + tableName := "checkpoint" + connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;" + streamName := "myStreamName" + shardID := "shardId-00000000" + expectedSequenceNumber := "49578481031144599192696750682534686652010819674221576194" + maxInterval := time.Second + ck, err := postgres.New(appName, tableName, connString, postgres.WithMaxInterval(maxInterval)) + if err != nil { + t.Fatalf("error occurred during the checkpoint creation. cause: %v", err) + } + + err = ck.Set(streamName, shardID, expectedSequenceNumber) + + if err != nil { + t.Errorf("expected error equals nil, but got %v", err) + } + ck.Shutdown() +} + +func TestCheckpoint_Set_SequenceNumberEmpty(t *testing.T) { + appName := "streamConsumer" + tableName := "checkpoint" + connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;" + streamName := "myStreamName" + shardID := "shardId-00000000" + expectedSequenceNumber := "" + maxInterval := time.Second + ck, err := postgres.New(appName, tableName, connString, postgres.WithMaxInterval(maxInterval)) + if err != nil { + t.Fatalf("error occurred during the checkpoint creation. cause: %v", err) + } + + err = ck.Set(streamName, shardID, expectedSequenceNumber) + + if err == nil { + t.Errorf("expected error equals not nil, but got %v", err) + } + ck.Shutdown() +} + +func TestCheckpoint_Shutdown(t *testing.T) { + appName := "streamConsumer" + tableName := "checkpoint" + connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;" + streamName := "myStreamName" + shardID := "shardId-00000000" + expectedSequenceNumber := "49578481031144599192696750682534686652010819674221576194" + maxInterval := time.Second + connMock, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("error occurred during the sqlmock creation. cause: %v", err) + } + ck, err := postgres.New(appName, tableName, connString, postgres.WithMaxInterval(maxInterval)) + if err != nil { + t.Fatalf("error occurred during the checkpoint creation. cause: %v", err) + } + ck.SetConn(connMock) // nolint: gotypex, the function available only in test + + namespace := fmt.Sprintf("%s-%s", appName, streamName) + expectedSQLRegexString := fmt.Sprintf(`INSERT INTO %s \(namespace, shard_id, sequence_number\) VALUES\(\$1, \$2, \$3\) ON CONFLICT \(namespace, shard_id\) DO UPDATE SET sequence_number= \$3;`, tableName) + result := sqlmock.NewResult(0, 1) + mock.ExpectExec(expectedSQLRegexString).WithArgs(namespace, shardID, expectedSequenceNumber).WillReturnResult(result) + + err = ck.Set(streamName, shardID, expectedSequenceNumber) + + if err != nil { + t.Fatalf("unable to set checkpoint for data initialization. cause: %v", err) + } + + err = ck.Shutdown() + + if err != nil { + t.Errorf("expected error equals not nil, but got %v", err) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func TestCheckpoint_Shutdown_SaveError(t *testing.T) { + appName := "streamConsumer" + tableName := "checkpoint" + connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;" + streamName := "myStreamName" + shardID := "shardId-00000000" + expectedSequenceNumber := "49578481031144599192696750682534686652010819674221576194" + maxInterval := time.Second + connMock, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("error occurred during the sqlmock creation. cause: %v", err) + } + ck, err := postgres.New(appName, tableName, connString, postgres.WithMaxInterval(maxInterval)) + if err != nil { + t.Fatalf("error occurred during the checkpoint creation. cause: %v", err) + } + ck.SetConn(connMock) // nolint: gotypex, the function available only in test + + namespace := fmt.Sprintf("%s-%s", appName, streamName) + expectedSQLRegexString := fmt.Sprintf(`INSERT INTO %s \(namespace, shard_id, sequence_number\) VALUES\(\$1, \$2, \$3\) ON CONFLICT \(namespace, shard_id\) DO UPDATE SET sequence_number= \$3;`, tableName) + mock.ExpectExec(expectedSQLRegexString).WithArgs(namespace, shardID, expectedSequenceNumber).WillReturnError(errors.New("an error")) + + err = ck.Set(streamName, shardID, expectedSequenceNumber) + + if err != nil { + t.Fatalf("unable to set checkpoint for data initialization. cause: %v", err) + } + + err = ck.Shutdown() + + if err == nil { + t.Errorf("expected error equals nil, but got %v", err) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +}