Add context as the first parameter and pass it through; fix tests and README

This commit is contained in:
Edward Tsang 2018-11-07 15:45:13 -08:00
parent 0528067bbe
commit a6eaeb2bfa
13 changed files with 253 additions and 105 deletions

Gopkg.lock (generated, 118 lines changed)
View file

@ -2,15 +2,18 @@
[[projects]] [[projects]]
digest = "1:5bbabe0c3c7e7f524b4c38193b80bf24624e67c0f3a036c4244c85c9a80579fd"
name = "github.com/apex/log" name = "github.com/apex/log"
packages = [ packages = [
".", ".",
"handlers/text" "handlers/text",
] ]
pruneopts = "UT"
revision = "0296d6eb16bb28f8a0c55668affcf4876dc269be" revision = "0296d6eb16bb28f8a0c55668affcf4876dc269be"
version = "v1.0.0" version = "v1.0.0"
[[projects]] [[projects]]
digest = "1:430a0049ba9e5652a778f1bb2a755b456ef8de588d94093f0b02a63cb885fbca"
name = "github.com/aws/aws-sdk-go" name = "github.com/aws/aws-sdk-go"
packages = [ packages = [
"aws", "aws",
@ -46,55 +49,118 @@
"service/dynamodb/dynamodbiface", "service/dynamodb/dynamodbiface",
"service/kinesis", "service/kinesis",
"service/kinesis/kinesisiface", "service/kinesis/kinesisiface",
"service/sts" "service/sts",
] ]
pruneopts = "UT"
revision = "8475c414b1bd58b8cc214873a8854e3a621e67d7" revision = "8475c414b1bd58b8cc214873a8854e3a621e67d7"
version = "v1.15.0" version = "v1.15.0"
[[projects]] [[projects]]
branch = "master"
digest = "1:4c4c33075b704791d6a7f09dfb55c66769e8a1dc6adf87026292d274fe8ad113"
name = "github.com/codahale/hdrhistogram"
packages = ["."]
pruneopts = "UT"
revision = "3a0bb77429bd3a61596f5e8a3172445844342120"
[[projects]]
digest = "1:fe8a03a8222d5b913f256972933d26d24ad7c8286692a42943bc01633cc8fce3"
name = "github.com/go-ini/ini" name = "github.com/go-ini/ini"
packages = ["."] packages = ["."]
pruneopts = "UT"
revision = "358ee7663966325963d4e8b2e1fbd570c5195153" revision = "358ee7663966325963d4e8b2e1fbd570c5195153"
version = "v1.38.1" version = "v1.38.1"
[[projects]] [[projects]]
name = "github.com/harlow/kinesis-consumer" digest = "1:e22af8c7518e1eab6f2eab2b7d7558927f816262586cd6ed9f349c97a6c285c4"
packages = [
".",
"checkpoint/ddb",
"checkpoint/postgres",
"checkpoint/redis"
]
revision = "049445e259a2ab9146364bf60d6f5f71270a125b"
version = "v0.2.0"
[[projects]]
name = "github.com/jmespath/go-jmespath" name = "github.com/jmespath/go-jmespath"
packages = ["."] packages = ["."]
pruneopts = "UT"
revision = "0b12d6b5" revision = "0b12d6b5"
[[projects]] [[projects]]
branch = "master" branch = "master"
digest = "1:37ce7d7d80531b227023331002c0d42b4b4b291a96798c82a049d03a54ba79e4"
name = "github.com/lib/pq" name = "github.com/lib/pq"
packages = [ packages = [
".", ".",
"oid" "oid",
] ]
pruneopts = "UT"
revision = "90697d60dd844d5ef6ff15135d0203f65d2f53b8" revision = "90697d60dd844d5ef6ff15135d0203f65d2f53b8"
[[projects]] [[projects]]
digest = "1:450b7623b185031f3a456801155c8320209f75d0d4c4e633c6b1e59d44d6e392"
name = "github.com/opentracing/opentracing-go"
packages = [
".",
"ext",
"log",
]
pruneopts = "UT"
revision = "1949ddbfd147afd4d964a9f00b24eb291e0e7c38"
version = "v1.0.2"
[[projects]]
digest = "1:40e195917a951a8bf867cd05de2a46aaf1806c50cf92eebf4c16f78cd196f747"
name = "github.com/pkg/errors" name = "github.com/pkg/errors"
packages = ["."] packages = ["."]
pruneopts = "UT"
revision = "645ef00459ed84a119197bfb8d8205042c6df63d" revision = "645ef00459ed84a119197bfb8d8205042c6df63d"
version = "v0.8.0" version = "v0.8.0"
[[projects]] [[projects]]
digest = "1:ac6f26e917fd2fb3194a7ebe2baf6fb32de2f2fbfed130c18aac0e758a6e1d22"
name = "github.com/uber/jaeger-client-go"
packages = [
".",
"config",
"internal/baggage",
"internal/baggage/remote",
"internal/spanlog",
"internal/throttler",
"internal/throttler/remote",
"log",
"rpcmetrics",
"thrift",
"thrift-gen/agent",
"thrift-gen/baggage",
"thrift-gen/jaeger",
"thrift-gen/sampling",
"thrift-gen/zipkincore",
"transport",
"utils",
]
pruneopts = "UT"
revision = "1a782e2da844727691fef1757c72eb190c2909f0"
version = "v2.15.0"
[[projects]]
digest = "1:0f09db8429e19d57c8346ad76fbbc679341fa86073d3b8fb5ac919f0357d8f4c"
name = "github.com/uber/jaeger-lib"
packages = ["metrics"]
pruneopts = "UT"
revision = "ed3a127ec5fef7ae9ea95b01b542c47fbd999ce5"
version = "v1.5.0"
[[projects]]
branch = "master"
digest = "1:76ee51c3f468493aff39dbacc401e8831fbb765104cbf613b89bef01cf4bad70"
name = "golang.org/x/net"
packages = ["context"]
pruneopts = "UT"
revision = "a544f70c90f196e50d198126db0c4cb2b562fec0"
[[projects]]
digest = "1:04aea75705cb453e24bf8c1506a24a5a9036537dbc61ddf71d20900d6c7c3ab9"
name = "gopkg.in/DATA-DOG/go-sqlmock.v1" name = "gopkg.in/DATA-DOG/go-sqlmock.v1"
packages = ["."] packages = ["."]
pruneopts = "UT"
revision = "d76b18b42f285b792bf985118980ce9eacea9d10" revision = "d76b18b42f285b792bf985118980ce9eacea9d10"
version = "v1.3.0" version = "v1.3.0"
[[projects]] [[projects]]
digest = "1:e5a1379b4f0cad2aabd75580598c3b8e19a027e8eed806e7b76b0ec949df4599"
name = "gopkg.in/redis.v5" name = "gopkg.in/redis.v5"
packages = [ packages = [
".", ".",
@ -102,14 +168,34 @@
"internal/consistenthash", "internal/consistenthash",
"internal/hashtag", "internal/hashtag",
"internal/pool", "internal/pool",
"internal/proto" "internal/proto",
] ]
pruneopts = "UT"
revision = "a16aeec10ff407b1e7be6dd35797ccf5426ef0f0" revision = "a16aeec10ff407b1e7be6dd35797ccf5426ef0f0"
version = "v5.2.9" version = "v5.2.9"
[solve-meta] [solve-meta]
analyzer-name = "dep" analyzer-name = "dep"
analyzer-version = 1 analyzer-version = 1
inputs-digest = "2588ee54549a76e93e2e65a289fccd8b636f85b124c5ccb0ab3d5f3529a3cbaa" input-imports = [
"github.com/apex/log",
"github.com/apex/log/handlers/text",
"github.com/aws/aws-sdk-go/aws",
"github.com/aws/aws-sdk-go/aws/awserr",
"github.com/aws/aws-sdk-go/aws/request",
"github.com/aws/aws-sdk-go/aws/session",
"github.com/aws/aws-sdk-go/service/dynamodb",
"github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute",
"github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface",
"github.com/aws/aws-sdk-go/service/kinesis",
"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface",
"github.com/lib/pq",
"github.com/opentracing/opentracing-go",
"github.com/opentracing/opentracing-go/ext",
"github.com/pkg/errors",
"github.com/uber/jaeger-client-go/config",
"gopkg.in/DATA-DOG/go-sqlmock.v1",
"gopkg.in/redis.v5",
]
solver-name = "gps-cdcl" solver-name = "gps-cdcl"
solver-version = 1 solver-version = 1

View file

@ -258,6 +258,9 @@ func main() {
} }
``` ```
### OpenTracing
To enable integration with OpenTracing, Checkpoint and Consumer methods now take a context as their first parameter. The context carries the tracing context and must be passed down through each layer. Another change, which should be invisible to users, is that all AWS SDK for Go calls now use their WithContext variants; for example, a call to GetID(...) is replaced with GetIDWithContext(ctx, ...). This links the spans created for AWS calls to the spans created upstream in application code.
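A minimal end-to-end sketch of the new call pattern (the app, stream, and table names below are placeholders, the constructor options mirror the existing README examples, and whatever tracer was registered with `opentracing.InitGlobalTracer` is used):

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/opentracing/opentracing-go"

	consumer "github.com/harlow/kinesis-consumer"
	checkpoint "github.com/harlow/kinesis-consumer/checkpoint/ddb"
)

func main() {
	// Root span for the whole run; spans started inside the library
	// (checkpoint.ddb.Get/Set/save, consumer.scanshard, ...) become its children.
	span := opentracing.StartSpan("consumer.main")
	defer span.Finish()
	ctx := opentracing.ContextWithSpan(context.Background(), span)

	// The checkpoint now takes ctx first so its background save loop can
	// create spans linked to the root span.
	ck, err := checkpoint.New(ctx, "myApp", "checkpointTable")
	if err != nil {
		log.Fatalf("checkpoint error: %v", err)
	}

	c, err := consumer.New("myStream", consumer.WithCheckpoint(ck))
	if err != nil {
		log.Fatalf("consumer error: %v", err)
	}

	// Scan already accepted a ctx; it is now also used for tracing and for
	// the *WithContext AWS SDK calls made under the hood.
	err = c.Scan(ctx, func(r *consumer.Record) consumer.ScanStatus {
		fmt.Println(string(r.Data))
		return consumer.ScanStatus{} // continue scanning
	})
	if err != nil {
		log.Fatalf("scan error: %v", err)
	}

	if err := ck.Shutdown(ctx); err != nil {
		log.Fatalf("checkpoint shutdown error: %v", err)
	}
}
```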
## Contributing ## Contributing
Please see [CONTRIBUTING.md] for more information. Thank you, [contributors]! Please see [CONTRIBUTING.md] for more information. Thank you, [contributors]!

View file

@ -1,13 +1,17 @@
package consumer package consumer
import (
"context"
)
// Checkpoint interface used track consumer progress in the stream // Checkpoint interface used track consumer progress in the stream
type Checkpoint interface { type Checkpoint interface {
Get(streamName, shardID string) (string, error) Get(ctx context.Context, streamName, shardID string) (string, error)
Set(streamName, shardID, sequenceNumber string) error Set(ctx context.Context, streamName, shardID, sequenceNumber string) error
} }
// noopCheckpoint implements the checkpoint interface with discard // noopCheckpoint implements the checkpoint interface with discard
type noopCheckpoint struct{} type noopCheckpoint struct{}
func (n noopCheckpoint) Set(string, string, string) error { return nil } func (n noopCheckpoint) Set(context.Context, string, string, string) error { return nil }
func (n noopCheckpoint) Get(string, string) (string, error) { return "", nil } func (n noopCheckpoint) Get(context.Context, string, string) (string, error) { return "", nil }
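Any custom checkpoint has to add the context parameter as well; the tracing context simply rides along even if the implementation ignores it. A minimal in-memory sketch (hypothetical, for illustration only) that satisfies the updated interface:

```go
package memcheckpoint

import (
	"context"
	"sync"
)

// Checkpoint is a hypothetical in-memory implementation showing the new
// context-aware Get/Set signatures.
type Checkpoint struct {
	mu   sync.Mutex
	vals map[string]string
}

// Get returns the last stored sequence number for the shard (empty if none).
func (c *Checkpoint) Get(ctx context.Context, streamName, shardID string) (string, error) {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.vals[streamName+"-"+shardID], nil
}

// Set stores the sequence number for the shard.
func (c *Checkpoint) Set(ctx context.Context, streamName, shardID, sequenceNumber string) error {
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.vals == nil {
		c.vals = make(map[string]string)
	}
	c.vals[streamName+"-"+shardID] = sequenceNumber
	return nil
}
```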

View file

@ -1,6 +1,7 @@
package ddb package ddb
import ( import (
"context"
"fmt" "fmt"
"log" "log"
"sync" "sync"
@ -11,6 +12,8 @@ import (
"github.com/aws/aws-sdk-go/service/dynamodb" "github.com/aws/aws-sdk-go/service/dynamodb"
"github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute" "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute"
"github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface" "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface"
"github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/ext"
) )
// Option is used to override defaults when creating a new Checkpoint // Option is used to override defaults when creating a new Checkpoint
@ -38,7 +41,7 @@ func WithRetryer(r Retryer) Option {
} }
// New returns a checkpoint that uses DynamoDB for underlying storage // New returns a checkpoint that uses DynamoDB for underlying storage
func New(appName, tableName string, opts ...Option) (*Checkpoint, error) { func New(ctx context.Context, appName, tableName string, opts ...Option) (*Checkpoint, error) {
client := dynamodb.New(session.New(aws.NewConfig())) client := dynamodb.New(session.New(aws.NewConfig()))
ck := &Checkpoint{ ck := &Checkpoint{
@ -56,7 +59,7 @@ func New(appName, tableName string, opts ...Option) (*Checkpoint, error) {
opt(ck) opt(ck)
} }
go ck.loop() go ck.loop(ctx)
return ck, nil return ck, nil
} }
@ -87,9 +90,13 @@ type item struct {
// Get determines if a checkpoint for a particular Shard exists. // Get determines if a checkpoint for a particular Shard exists.
// Typically used to determine whether we should start processing the shard with // Typically used to determine whether we should start processing the shard with
// TRIM_HORIZON or AFTER_SEQUENCE_NUMBER (if checkpoint exists). // TRIM_HORIZON or AFTER_SEQUENCE_NUMBER (if checkpoint exists).
func (c *Checkpoint) Get(streamName, shardID string) (string, error) { func (c *Checkpoint) Get(ctx context.Context, streamName, shardID string) (string, error) {
namespace := fmt.Sprintf("%s-%s", c.appName, streamName) namespace := fmt.Sprintf("%s-%s", c.appName, streamName)
span, ctx := opentracing.StartSpanFromContext(ctx, "checkpoint.ddb.Get",
opentracing.Tag{Key: "namespace", Value: namespace},
opentracing.Tag{Key: "shardID", Value: shardID},
)
defer span.Finish()
params := &dynamodb.GetItemInput{ params := &dynamodb.GetItemInput{
TableName: aws.String(c.tableName), TableName: aws.String(c.tableName),
ConsistentRead: aws.Bool(true), ConsistentRead: aws.Bool(true),
@ -103,11 +110,13 @@ func (c *Checkpoint) Get(streamName, shardID string) (string, error) {
}, },
} }
resp, err := c.client.GetItem(params) resp, err := c.client.GetItemWithContext(ctx, params)
if err != nil { if err != nil {
if c.retryer.ShouldRetry(err) { if c.retryer.ShouldRetry(err) {
return c.Get(streamName, shardID) return c.Get(ctx, streamName, shardID)
} }
span.LogKV("checkpoint get item error", err.Error())
ext.Error.Set(span, true)
return "", err return "", err
} }
@ -118,10 +127,14 @@ func (c *Checkpoint) Get(streamName, shardID string) (string, error) {
// Set stores a checkpoint for a shard (e.g. sequence number of last record processed by application). // Set stores a checkpoint for a shard (e.g. sequence number of last record processed by application).
// Upon failover, record processing is resumed from this point. // Upon failover, record processing is resumed from this point.
func (c *Checkpoint) Set(streamName, shardID, sequenceNumber string) error { func (c *Checkpoint) Set(ctx context.Context, streamName, shardID, sequenceNumber string) error {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
span, ctx := opentracing.StartSpanFromContext(ctx, "checkpoint.ddb.Set",
opentracing.Tag{Key: "stream.name", Value: streamName},
opentracing.Tag{Key: "shardID", Value: shardID},
)
defer span.Finish()
if sequenceNumber == "" { if sequenceNumber == "" {
return fmt.Errorf("sequence number should not be empty") return fmt.Errorf("sequence number should not be empty")
} }
@ -136,12 +149,12 @@ func (c *Checkpoint) Set(streamName, shardID, sequenceNumber string) error {
} }
// Shutdown the checkpoint. Save any in-flight data. // Shutdown the checkpoint. Save any in-flight data.
func (c *Checkpoint) Shutdown() error { func (c *Checkpoint) Shutdown(ctx context.Context) error {
c.done <- struct{}{} c.done <- struct{}{}
return c.save() return c.save(ctx)
} }
func (c *Checkpoint) loop() { func (c *Checkpoint) loop(ctx context.Context) {
tick := time.NewTicker(c.maxInterval) tick := time.NewTicker(c.maxInterval)
defer tick.Stop() defer tick.Stop()
defer close(c.done) defer close(c.done)
@ -149,16 +162,18 @@ func (c *Checkpoint) loop() {
for { for {
select { select {
case <-tick.C: case <-tick.C:
c.save() c.save(ctx)
case <-c.done: case <-c.done:
return return
} }
} }
} }
func (c *Checkpoint) save() error { func (c *Checkpoint) save(ctx context.Context) error {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
span, ctx := opentracing.StartSpanFromContext(ctx, "checkpoint.ddb.save")
defer span.Finish()
for key, sequenceNumber := range c.checkpoints { for key, sequenceNumber := range c.checkpoints {
item, err := dynamodbattribute.MarshalMap(item{ item, err := dynamodbattribute.MarshalMap(item{
@ -168,10 +183,12 @@ func (c *Checkpoint) save() error {
}) })
if err != nil { if err != nil {
log.Printf("marshal map error: %v", err) log.Printf("marshal map error: %v", err)
span.LogKV("marshal map error", err.Error())
ext.Error.Set(span, true)
return nil return nil
} }
_, err = c.client.PutItem(&dynamodb.PutItemInput{ _, err = c.client.PutItemWithContext(ctx, &dynamodb.PutItemInput{
TableName: aws.String(c.tableName), TableName: aws.String(c.tableName),
Item: item, Item: item,
}) })
@ -179,7 +196,9 @@ func (c *Checkpoint) save() error {
if !c.retryer.ShouldRetry(err) { if !c.retryer.ShouldRetry(err) {
return err return err
} }
return c.save() span.LogKV("checkpoint put item error", err.Error())
ext.Error.Set(span, true)
return c.save(ctx)
} }
} }
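Because Get, Set, and the periodic save now start child spans from the supplied context, a caller only attaches its span to the context once and everything the DynamoDB checkpoint does afterwards is linked to it. A standalone sketch under that assumption (the app, table, stream, and span names are placeholders):

```go
package main

import (
	"context"
	"log"

	"github.com/opentracing/opentracing-go"

	checkpoint "github.com/harlow/kinesis-consumer/checkpoint/ddb"
)

func main() {
	// Root span; the background save loop keeps using this ctx, so its
	// checkpoint.ddb.save spans become children of app.main.
	root := opentracing.StartSpan("app.main")
	defer root.Finish()
	ctx := opentracing.ContextWithSpan(context.Background(), root)

	ck, err := checkpoint.New(ctx, "myApp", "checkpointTable")
	if err != nil {
		log.Fatalf("checkpoint error: %v", err)
	}

	// Per-record work gets its own span; checkpoint.ddb.Get/Set started from
	// this ctx nest under app.handleRecord instead of app.main.
	span, ctx := opentracing.StartSpanFromContext(ctx, "app.handleRecord")
	defer span.Finish()

	if err := ck.Set(ctx, "myStream", "shardId-000000000000", "testSeqNum"); err != nil {
		log.Printf("set error: %v", err)
	}
	seq, err := ck.Get(ctx, "myStream", "shardId-000000000000")
	if err != nil {
		log.Printf("get error: %v", err)
	}
	log.Printf("last sequence number: %q", seq)

	if err := ck.Shutdown(ctx); err != nil {
		log.Printf("checkpoint shutdown error: %v", err)
	}
}
```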

View file

@ -1,6 +1,7 @@
package postgres package postgres
import ( import (
"context"
"database/sql" "database/sql"
"errors" "errors"
"fmt" "fmt"
@ -80,7 +81,7 @@ func (c *Checkpoint) GetMaxInterval() time.Duration {
// Get determines if a checkpoint for a particular Shard exists. // Get determines if a checkpoint for a particular Shard exists.
// Typically used to determine whether we should start processing the shard with // Typically used to determine whether we should start processing the shard with
// TRIM_HORIZON or AFTER_SEQUENCE_NUMBER (if checkpoint exists). // TRIM_HORIZON or AFTER_SEQUENCE_NUMBER (if checkpoint exists).
func (c *Checkpoint) Get(streamName, shardID string) (string, error) { func (c *Checkpoint) Get(ctx context.Context, streamName, shardID string) (string, error) {
namespace := fmt.Sprintf("%s-%s", c.appName, streamName) namespace := fmt.Sprintf("%s-%s", c.appName, streamName)
var sequenceNumber string var sequenceNumber string
@ -99,7 +100,7 @@ func (c *Checkpoint) Get(streamName, shardID string) (string, error) {
// Set stores a checkpoint for a shard (e.g. sequence number of last record processed by application). // Set stores a checkpoint for a shard (e.g. sequence number of last record processed by application).
// Upon failover, record processing is resumed from this point. // Upon failover, record processing is resumed from this point.
func (c *Checkpoint) Set(streamName, shardID, sequenceNumber string) error { func (c *Checkpoint) Set(ctx context.Context, streamName, shardID, sequenceNumber string) error {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()

View file

@ -1,17 +1,16 @@
package postgres_test package postgres_test
import ( import (
"context"
"database/sql"
"fmt"
"testing" "testing"
"time" "time"
"fmt" "gopkg.in/DATA-DOG/go-sqlmock.v1"
"database/sql"
"github.com/harlow/kinesis-consumer/checkpoint/postgres" "github.com/harlow/kinesis-consumer/checkpoint/postgres"
"github.com/pkg/errors" "github.com/pkg/errors"
"gopkg.in/DATA-DOG/go-sqlmock.v1"
) )
func TestNew(t *testing.T) { func TestNew(t *testing.T) {
@ -77,6 +76,7 @@ func TestNew_WithMaxIntervalOption(t *testing.T) {
} }
func TestCheckpoint_Get(t *testing.T) { func TestCheckpoint_Get(t *testing.T) {
ctx := context.TODO()
appName := "streamConsumer" appName := "streamConsumer"
tableName := "checkpoint" tableName := "checkpoint"
connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;" connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;"
@ -102,7 +102,7 @@ func TestCheckpoint_Get(t *testing.T) {
tableName) tableName)
mock.ExpectQuery(expectedSQLRegexString).WithArgs(namespace, shardID).WillReturnRows(expectedRows) mock.ExpectQuery(expectedSQLRegexString).WithArgs(namespace, shardID).WillReturnRows(expectedRows)
gotSequenceNumber, err := ck.Get(streamName, shardID) gotSequenceNumber, err := ck.Get(ctx, streamName, shardID)
if gotSequenceNumber != expectedSequenceNumber { if gotSequenceNumber != expectedSequenceNumber {
t.Errorf("expected sequence number equals %v, but got %v", expectedSequenceNumber, gotSequenceNumber) t.Errorf("expected sequence number equals %v, but got %v", expectedSequenceNumber, gotSequenceNumber)
@ -117,6 +117,7 @@ func TestCheckpoint_Get(t *testing.T) {
} }
func TestCheckpoint_Get_NoRows(t *testing.T) { func TestCheckpoint_Get_NoRows(t *testing.T) {
ctx := context.TODO()
appName := "streamConsumer" appName := "streamConsumer"
tableName := "checkpoint" tableName := "checkpoint"
connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;" connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;"
@ -138,7 +139,7 @@ func TestCheckpoint_Get_NoRows(t *testing.T) {
tableName) tableName)
mock.ExpectQuery(expectedSQLRegexString).WithArgs(namespace, shardID).WillReturnError(sql.ErrNoRows) mock.ExpectQuery(expectedSQLRegexString).WithArgs(namespace, shardID).WillReturnError(sql.ErrNoRows)
gotSequenceNumber, err := ck.Get(streamName, shardID) gotSequenceNumber, err := ck.Get(ctx, streamName, shardID)
if gotSequenceNumber != "" { if gotSequenceNumber != "" {
t.Errorf("expected sequence number equals empty, but got %v", gotSequenceNumber) t.Errorf("expected sequence number equals empty, but got %v", gotSequenceNumber)
@ -153,6 +154,7 @@ func TestCheckpoint_Get_NoRows(t *testing.T) {
} }
func TestCheckpoint_Get_QueryError(t *testing.T) { func TestCheckpoint_Get_QueryError(t *testing.T) {
ctx := context.TODO()
appName := "streamConsumer" appName := "streamConsumer"
tableName := "checkpoint" tableName := "checkpoint"
connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;" connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;"
@ -174,7 +176,7 @@ func TestCheckpoint_Get_QueryError(t *testing.T) {
tableName) tableName)
mock.ExpectQuery(expectedSQLRegexString).WithArgs(namespace, shardID).WillReturnError(errors.New("an error")) mock.ExpectQuery(expectedSQLRegexString).WithArgs(namespace, shardID).WillReturnError(errors.New("an error"))
gotSequenceNumber, err := ck.Get(streamName, shardID) gotSequenceNumber, err := ck.Get(ctx, streamName, shardID)
if gotSequenceNumber != "" { if gotSequenceNumber != "" {
t.Errorf("expected sequence number equals empty, but got %v", gotSequenceNumber) t.Errorf("expected sequence number equals empty, but got %v", gotSequenceNumber)
@ -189,6 +191,7 @@ func TestCheckpoint_Get_QueryError(t *testing.T) {
} }
func TestCheckpoint_Set(t *testing.T) { func TestCheckpoint_Set(t *testing.T) {
ctx := context.TODO()
appName := "streamConsumer" appName := "streamConsumer"
tableName := "checkpoint" tableName := "checkpoint"
connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;" connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;"
@ -201,7 +204,7 @@ func TestCheckpoint_Set(t *testing.T) {
t.Fatalf("error occurred during the checkpoint creation. cause: %v", err) t.Fatalf("error occurred during the checkpoint creation. cause: %v", err)
} }
err = ck.Set(streamName, shardID, expectedSequenceNumber) err = ck.Set(ctx, streamName, shardID, expectedSequenceNumber)
if err != nil { if err != nil {
t.Errorf("expected error equals nil, but got %v", err) t.Errorf("expected error equals nil, but got %v", err)
@ -210,6 +213,7 @@ func TestCheckpoint_Set(t *testing.T) {
} }
func TestCheckpoint_Set_SequenceNumberEmpty(t *testing.T) { func TestCheckpoint_Set_SequenceNumberEmpty(t *testing.T) {
ctx := context.TODO()
appName := "streamConsumer" appName := "streamConsumer"
tableName := "checkpoint" tableName := "checkpoint"
connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;" connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;"
@ -222,7 +226,7 @@ func TestCheckpoint_Set_SequenceNumberEmpty(t *testing.T) {
t.Fatalf("error occurred during the checkpoint creation. cause: %v", err) t.Fatalf("error occurred during the checkpoint creation. cause: %v", err)
} }
err = ck.Set(streamName, shardID, expectedSequenceNumber) err = ck.Set(ctx, streamName, shardID, expectedSequenceNumber)
if err == nil { if err == nil {
t.Errorf("expected error equals not nil, but got %v", err) t.Errorf("expected error equals not nil, but got %v", err)
@ -231,6 +235,7 @@ func TestCheckpoint_Set_SequenceNumberEmpty(t *testing.T) {
} }
func TestCheckpoint_Shutdown(t *testing.T) { func TestCheckpoint_Shutdown(t *testing.T) {
ctx := context.TODO()
appName := "streamConsumer" appName := "streamConsumer"
tableName := "checkpoint" tableName := "checkpoint"
connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;" connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;"
@ -253,7 +258,7 @@ func TestCheckpoint_Shutdown(t *testing.T) {
result := sqlmock.NewResult(0, 1) result := sqlmock.NewResult(0, 1)
mock.ExpectExec(expectedSQLRegexString).WithArgs(namespace, shardID, expectedSequenceNumber).WillReturnResult(result) mock.ExpectExec(expectedSQLRegexString).WithArgs(namespace, shardID, expectedSequenceNumber).WillReturnResult(result)
err = ck.Set(streamName, shardID, expectedSequenceNumber) err = ck.Set(ctx, streamName, shardID, expectedSequenceNumber)
if err != nil { if err != nil {
t.Fatalf("unable to set checkpoint for data initialization. cause: %v", err) t.Fatalf("unable to set checkpoint for data initialization. cause: %v", err)
@ -270,6 +275,7 @@ func TestCheckpoint_Shutdown(t *testing.T) {
} }
func TestCheckpoint_Shutdown_SaveError(t *testing.T) { func TestCheckpoint_Shutdown_SaveError(t *testing.T) {
ctx := context.TODO()
appName := "streamConsumer" appName := "streamConsumer"
tableName := "checkpoint" tableName := "checkpoint"
connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;" connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;"
@ -291,7 +297,7 @@ func TestCheckpoint_Shutdown_SaveError(t *testing.T) {
expectedSQLRegexString := fmt.Sprintf(`INSERT INTO %s \(namespace, shard_id, sequence_number\) VALUES\(\$1, \$2, \$3\) ON CONFLICT \(namespace, shard_id\) DO UPDATE SET sequence_number= \$3;`, tableName) expectedSQLRegexString := fmt.Sprintf(`INSERT INTO %s \(namespace, shard_id, sequence_number\) VALUES\(\$1, \$2, \$3\) ON CONFLICT \(namespace, shard_id\) DO UPDATE SET sequence_number= \$3;`, tableName)
mock.ExpectExec(expectedSQLRegexString).WithArgs(namespace, shardID, expectedSequenceNumber).WillReturnError(errors.New("an error")) mock.ExpectExec(expectedSQLRegexString).WithArgs(namespace, shardID, expectedSequenceNumber).WillReturnError(errors.New("an error"))
err = ck.Set(streamName, shardID, expectedSequenceNumber) err = ck.Set(ctx, streamName, shardID, expectedSequenceNumber)
if err != nil { if err != nil {
t.Fatalf("unable to set checkpoint for data initialization. cause: %v", err) t.Fatalf("unable to set checkpoint for data initialization. cause: %v", err)

View file

@ -1,6 +1,7 @@
package redis package redis
import ( import (
"context"
"fmt" "fmt"
"os" "os"
@ -37,14 +38,14 @@ type Checkpoint struct {
} }
// Get fetches the checkpoint for a particular Shard. // Get fetches the checkpoint for a particular Shard.
func (c *Checkpoint) Get(streamName, shardID string) (string, error) { func (c *Checkpoint) Get(ctx context.Context, streamName, shardID string) (string, error) {
val, _ := c.client.Get(c.key(streamName, shardID)).Result() val, _ := c.client.Get(c.key(streamName, shardID)).Result()
return val, nil return val, nil
} }
// Set stores a checkpoint for a shard (e.g. sequence number of last record processed by application). // Set stores a checkpoint for a shard (e.g. sequence number of last record processed by application).
// Upon failover, record processing is resumed from this point. // Upon failover, record processing is resumed from this point.
func (c *Checkpoint) Set(streamName, shardID, sequenceNumber string) error { func (c *Checkpoint) Set(ctx context.Context, streamName, shardID, sequenceNumber string) error {
if sequenceNumber == "" { if sequenceNumber == "" {
return fmt.Errorf("sequence number should not be empty") return fmt.Errorf("sequence number should not be empty")
} }

View file

@ -1,21 +1,23 @@
package redis package redis
import ( import (
"context"
"testing" "testing"
) )
func Test_CheckpointLifecycle(t *testing.T) { func Test_CheckpointLifecycle(t *testing.T) {
// new // new
ctx := context.TODO()
c, err := New("app") c, err := New("app")
if err != nil { if err != nil {
t.Fatalf("new checkpoint error: %v", err) t.Fatalf("new checkpoint error: %v", err)
} }
// set // set
c.Set("streamName", "shardID", "testSeqNum") c.Set(ctx, "streamName", "shardID", "testSeqNum")
// get // get
val, err := c.Get("streamName", "shardID") val, err := c.Get(ctx, "streamName", "shardID")
if err != nil { if err != nil {
t.Fatalf("get checkpoint error: %v", err) t.Fatalf("get checkpoint error: %v", err)
} }
@ -25,12 +27,13 @@ func Test_CheckpointLifecycle(t *testing.T) {
} }
func Test_SetEmptySeqNum(t *testing.T) { func Test_SetEmptySeqNum(t *testing.T) {
ctx := context.TODO()
c, err := New("app") c, err := New("app")
if err != nil { if err != nil {
t.Fatalf("new checkpoint error: %v", err) t.Fatalf("new checkpoint error: %v", err)
} }
err = c.Set("streamName", "shardID", "") err = c.Set(ctx, "streamName", "shardID", "")
if err == nil { if err == nil {
t.Fatalf("should not allow empty sequence number") t.Fatalf("should not allow empty sequence number")
} }

View file

@ -112,6 +112,8 @@ func (c *Consumer) Scan(ctx context.Context, fn func(*Record) ScanStatus) error
// get shard ids // get shard ids
shardIDs, err := c.getShardIDs(ctx, c.streamName) shardIDs, err := c.getShardIDs(ctx, c.streamName)
span.SetTag("stream.name", c.streamName)
span.SetTag("shard.count", len(shardIDs))
if err != nil { if err != nil {
span.LogKV("get shardID error", err.Error(), "stream.name", c.streamName) span.LogKV("get shardID error", err.Error(), "stream.name", c.streamName)
ext.Error.Set(span, true) ext.Error.Set(span, true)
@ -119,7 +121,7 @@ func (c *Consumer) Scan(ctx context.Context, fn func(*Record) ScanStatus) error
} }
if len(shardIDs) == 0 { if len(shardIDs) == 0 {
span.LogKV("get shardID error", err.Error(), "stream.name", c.streamName, "shards.count", len(shardIDs)) span.LogKV("stream.name", c.streamName, "shards.count", len(shardIDs))
ext.Error.Set(span, true) ext.Error.Set(span, true)
return fmt.Errorf("no shards available") return fmt.Errorf("no shards available")
} }
@ -138,6 +140,7 @@ func (c *Consumer) Scan(ctx context.Context, fn func(*Record) ScanStatus) error
if err := c.ScanShard(ctx, shardID, fn); err != nil { if err := c.ScanShard(ctx, shardID, fn); err != nil {
span.LogKV("scan shard error", err.Error(), "shardID", shardID) span.LogKV("scan shard error", err.Error(), "shardID", shardID)
ext.Error.Set(span, true) ext.Error.Set(span, true)
span.Finish()
select { select {
case errc <- fmt.Errorf("shard %s error: %v", shardID, err): case errc <- fmt.Errorf("shard %s error: %v", shardID, err):
// first error to occur // first error to occur
@ -166,7 +169,7 @@ func (c *Consumer) ScanShard(
span, ctx := opentracing.StartSpanFromContext(ctx, "consumer.scanshard") span, ctx := opentracing.StartSpanFromContext(ctx, "consumer.scanshard")
defer span.Finish() defer span.Finish()
// get checkpoint // get checkpoint
lastSeqNum, err := c.checkpoint.Get(c.streamName, shardID) lastSeqNum, err := c.checkpoint.Get(ctx, c.streamName, shardID)
if err != nil { if err != nil {
span.LogKV("checkpoint error", err.Error(), "shardID", shardID) span.LogKV("checkpoint error", err.Error(), "shardID", shardID)
ext.Error.Set(span, true) ext.Error.Set(span, true)
@ -195,7 +198,7 @@ func (c *Consumer) scanPagesOfShard(ctx context.Context, shardID, lastSeqNum str
return nil return nil
default: default:
span.SetTag("scan", "on") span.SetTag("scan", "on")
resp, err := c.client.GetRecords(&kinesis.GetRecordsInput{ resp, err := c.client.GetRecordsWithContext(ctx, &kinesis.GetRecordsInput{
ShardIterator: shardIterator, ShardIterator: shardIterator,
}) })
@ -203,6 +206,7 @@ func (c *Consumer) scanPagesOfShard(ctx context.Context, shardID, lastSeqNum str
shardIterator, err = c.getShardIterator(ctx, c.streamName, shardID, lastSeqNum) shardIterator, err = c.getShardIterator(ctx, c.streamName, shardID, lastSeqNum)
if err != nil { if err != nil {
ext.Error.Set(span, true) ext.Error.Set(span, true)
span.LogKV("get shard iterator error", err.Error())
return fmt.Errorf("get shard iterator error: %v", err) return fmt.Errorf("get shard iterator error: %v", err)
} }
continue continue
@ -243,7 +247,7 @@ func (c *Consumer) handleRecord(ctx context.Context, shardID string, r *Record,
status := fn(r) status := fn(r)
if !status.SkipCheckpoint { if !status.SkipCheckpoint {
span.LogKV("scan.state", status) span.LogKV("scan.state", status)
if err := c.checkpoint.Set(c.streamName, shardID, *r.SequenceNumber); err != nil { if err := c.checkpoint.Set(ctx, c.streamName, shardID, *r.SequenceNumber); err != nil {
span.LogKV("checkpoint error", err.Error(), "stream.name", c.streamName, "shardID", shardID, "sequenceNumber", *r.SequenceNumber) span.LogKV("checkpoint error", err.Error(), "stream.name", c.streamName, "shardID", shardID, "sequenceNumber", *r.SequenceNumber)
ext.Error.Set(span, true) ext.Error.Set(span, true)
return false, err return false, err
@ -269,7 +273,7 @@ func (c *Consumer) getShardIDs(ctx context.Context, streamName string) ([]string
span, ctx := opentracing.StartSpanFromContext(ctx, "consumer.getShardIDs") span, ctx := opentracing.StartSpanFromContext(ctx, "consumer.getShardIDs")
defer span.Finish() defer span.Finish()
resp, err := c.client.DescribeStream( resp, err := c.client.DescribeStreamWithContext(ctx,
&kinesis.DescribeStreamInput{ &kinesis.DescribeStreamInput{
StreamName: aws.String(streamName), StreamName: aws.String(streamName),
}, },
@ -304,7 +308,7 @@ func (c *Consumer) getShardIterator(ctx context.Context, streamName, shardID, la
params.ShardIteratorType = aws.String("TRIM_HORIZON") params.ShardIteratorType = aws.String("TRIM_HORIZON")
} }
resp, err := c.client.GetShardIterator(params) resp, err := c.client.GetShardIteratorWithContext(ctx, params)
if err != nil { if err != nil {
span.LogKV("get shard error", err.Error()) span.LogKV("get shard error", err.Error())
ext.Error.Set(span, true) ext.Error.Set(span, true)

View file

@ -7,6 +7,7 @@ import (
"testing" "testing"
"github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/kinesis" "github.com/aws/aws-sdk-go/service/kinesis"
"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface" "github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
) )
@ -19,6 +20,7 @@ func TestNew(t *testing.T) {
} }
func TestConsumer_Scan(t *testing.T) { func TestConsumer_Scan(t *testing.T) {
ctx := context.TODO()
records := []*kinesis.Record{ records := []*kinesis.Record{
{ {
Data: []byte("firstData"), Data: []byte("firstData"),
@ -30,18 +32,18 @@ func TestConsumer_Scan(t *testing.T) {
}, },
} }
client := &kinesisClientMock{ client := &kinesisClientMock{
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) { getShardIteratorMock: func(a aws.Context, input *kinesis.GetShardIteratorInput, o ...request.Option) (*kinesis.GetShardIteratorOutput, error) {
return &kinesis.GetShardIteratorOutput{ return &kinesis.GetShardIteratorOutput{
ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"), ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
}, nil }, nil
}, },
getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) { getRecordsMock: func(a aws.Context, input *kinesis.GetRecordsInput, o ...request.Option) (*kinesis.GetRecordsOutput, error) {
return &kinesis.GetRecordsOutput{ return &kinesis.GetRecordsOutput{
NextShardIterator: nil, NextShardIterator: nil,
Records: records, Records: records,
}, nil }, nil
}, },
describeStreamMock: func(input *kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) { describeStreamMock: func(a aws.Context, input *kinesis.DescribeStreamInput, o ...request.Option) (*kinesis.DescribeStreamOutput, error) {
return &kinesis.DescribeStreamOutput{ return &kinesis.DescribeStreamOutput{
StreamDescription: &kinesis.StreamDescription{ StreamDescription: &kinesis.StreamDescription{
Shards: []*kinesis.Shard{ Shards: []*kinesis.Shard{
@ -73,7 +75,7 @@ func TestConsumer_Scan(t *testing.T) {
return ScanStatus{} return ScanStatus{}
} }
if err := c.Scan(context.Background(), fn); err != nil { if err := c.Scan(ctx, fn); err != nil {
t.Errorf("scan shard error expected nil. got %v", err) t.Errorf("scan shard error expected nil. got %v", err)
} }
@ -87,15 +89,16 @@ func TestConsumer_Scan(t *testing.T) {
t.Errorf("counter error expected %d, got %d", 2, val) t.Errorf("counter error expected %d, got %d", 2, val)
} }
val, err := cp.Get("myStreamName", "myShard") val, err := cp.Get(ctx, "myStreamName", "myShard")
if err != nil && val != "lastSeqNum" { if err != nil && val != "lastSeqNum" {
t.Errorf("checkout error expected %s, got %s", "lastSeqNum", val) t.Errorf("checkout error expected %s, got %s", "lastSeqNum", val)
} }
} }
func TestConsumer_Scan_NoShardsAvailable(t *testing.T) { func TestConsumer_Scan_NoShardsAvailable(t *testing.T) {
ctx := context.TODO()
client := &kinesisClientMock{ client := &kinesisClientMock{
describeStreamMock: func(input *kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) { describeStreamMock: func(a aws.Context, input *kinesis.DescribeStreamInput, o ...request.Option) (*kinesis.DescribeStreamOutput, error) {
return &kinesis.DescribeStreamOutput{ return &kinesis.DescribeStreamOutput{
StreamDescription: &kinesis.StreamDescription{ StreamDescription: &kinesis.StreamDescription{
Shards: make([]*kinesis.Shard, 0), Shards: make([]*kinesis.Shard, 0),
@ -123,7 +126,7 @@ func TestConsumer_Scan_NoShardsAvailable(t *testing.T) {
return ScanStatus{} return ScanStatus{}
} }
if err := c.Scan(context.Background(), fn); err == nil { if err := c.Scan(ctx, fn); err == nil {
t.Errorf("scan shard error expected not nil. got %v", err) t.Errorf("scan shard error expected not nil. got %v", err)
} }
@ -133,13 +136,14 @@ func TestConsumer_Scan_NoShardsAvailable(t *testing.T) {
if val := ctr.counter; val != 0 { if val := ctr.counter; val != 0 {
t.Errorf("counter error expected %d, got %d", 0, val) t.Errorf("counter error expected %d, got %d", 0, val)
} }
val, err := cp.Get("myStreamName", "myShard") val, err := cp.Get(ctx, "myStreamName", "myShard")
if err != nil && val != "" { if err != nil && val != "" {
t.Errorf("checkout error expected %s, got %s", "", val) t.Errorf("checkout error expected %s, got %s", "", val)
} }
} }
func TestScanShard(t *testing.T) { func TestScanShard(t *testing.T) {
ctx := context.TODO()
var records = []*kinesis.Record{ var records = []*kinesis.Record{
{ {
Data: []byte("firstData"), Data: []byte("firstData"),
@ -152,12 +156,12 @@ func TestScanShard(t *testing.T) {
} }
var client = &kinesisClientMock{ var client = &kinesisClientMock{
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) { getShardIteratorMock: func(a aws.Context, input *kinesis.GetShardIteratorInput, o ...request.Option) (*kinesis.GetShardIteratorOutput, error) {
return &kinesis.GetShardIteratorOutput{ return &kinesis.GetShardIteratorOutput{
ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"), ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
}, nil }, nil
}, },
getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) { getRecordsMock: func(a aws.Context, input *kinesis.GetRecordsInput, o ...request.Option) (*kinesis.GetRecordsOutput, error) {
return &kinesis.GetRecordsOutput{ return &kinesis.GetRecordsOutput{
NextShardIterator: nil, NextShardIterator: nil,
Records: records, Records: records,
@ -187,7 +191,7 @@ func TestScanShard(t *testing.T) {
} }
// scan shard // scan shard
if err := c.ScanShard(context.Background(), "myShard", fn); err != nil { if err := c.ScanShard(ctx, "myShard", fn); err != nil {
t.Fatalf("scan shard error: %v", err) t.Fatalf("scan shard error: %v", err)
} }
@ -202,13 +206,14 @@ func TestScanShard(t *testing.T) {
} }
// sets checkpoint // sets checkpoint
val, err := cp.Get("myStreamName", "myShard") val, err := cp.Get(ctx, "myStreamName", "myShard")
if err != nil && val != "lastSeqNum" { if err != nil && val != "lastSeqNum" {
t.Fatalf("checkout error expected %s, got %s", "lastSeqNum", val) t.Fatalf("checkout error expected %s, got %s", "lastSeqNum", val)
} }
} }
func TestScanShard_StopScan(t *testing.T) { func TestScanShard_StopScan(t *testing.T) {
ctx := context.TODO()
var records = []*kinesis.Record{ var records = []*kinesis.Record{
{ {
Data: []byte("firstData"), Data: []byte("firstData"),
@ -221,12 +226,12 @@ func TestScanShard_StopScan(t *testing.T) {
} }
var client = &kinesisClientMock{ var client = &kinesisClientMock{
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) { getShardIteratorMock: func(a aws.Context, input *kinesis.GetShardIteratorInput, o ...request.Option) (*kinesis.GetShardIteratorOutput, error) {
return &kinesis.GetShardIteratorOutput{ return &kinesis.GetShardIteratorOutput{
ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"), ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
}, nil }, nil
}, },
getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) { getRecordsMock: func(a aws.Context, input *kinesis.GetRecordsInput, o ...request.Option) (*kinesis.GetRecordsOutput, error) {
return &kinesis.GetRecordsOutput{ return &kinesis.GetRecordsOutput{
NextShardIterator: nil, NextShardIterator: nil,
Records: records, Records: records,
@ -246,7 +251,7 @@ func TestScanShard_StopScan(t *testing.T) {
return ScanStatus{StopScan: true} return ScanStatus{StopScan: true}
} }
if err := c.ScanShard(context.Background(), "myShard", fn); err != nil { if err := c.ScanShard(ctx, "myShard", fn); err != nil {
t.Fatalf("scan shard error: %v", err) t.Fatalf("scan shard error: %v", err)
} }
@ -256,13 +261,14 @@ func TestScanShard_StopScan(t *testing.T) {
} }
func TestScanShard_ShardIsClosed(t *testing.T) { func TestScanShard_ShardIsClosed(t *testing.T) {
ctx := context.TODO()
var client = &kinesisClientMock{ var client = &kinesisClientMock{
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) { getShardIteratorMock: func(a aws.Context, input *kinesis.GetShardIteratorInput, o ...request.Option) (*kinesis.GetShardIteratorOutput, error) {
return &kinesis.GetShardIteratorOutput{ return &kinesis.GetShardIteratorOutput{
ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"), ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
}, nil }, nil
}, },
getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) { getRecordsMock: func(a aws.Context, input *kinesis.GetRecordsInput, o ...request.Option) (*kinesis.GetRecordsOutput, error) {
return &kinesis.GetRecordsOutput{ return &kinesis.GetRecordsOutput{
NextShardIterator: nil, NextShardIterator: nil,
Records: make([]*Record, 0), Records: make([]*Record, 0),
@ -279,28 +285,28 @@ func TestScanShard_ShardIsClosed(t *testing.T) {
return ScanStatus{} return ScanStatus{}
} }
if err := c.ScanShard(context.Background(), "myShard", fn); err != nil { if err := c.ScanShard(ctx, "myShard", fn); err != nil {
t.Fatalf("scan shard error: %v", err) t.Fatalf("scan shard error: %v", err)
} }
} }
type kinesisClientMock struct { type kinesisClientMock struct {
kinesisiface.KinesisAPI kinesisiface.KinesisAPI
getShardIteratorMock func(*kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) getShardIteratorMock func(aws.Context, *kinesis.GetShardIteratorInput, ...request.Option) (*kinesis.GetShardIteratorOutput, error)
getRecordsMock func(*kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) getRecordsMock func(aws.Context, *kinesis.GetRecordsInput, ...request.Option) (*kinesis.GetRecordsOutput, error)
describeStreamMock func(*kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) describeStreamMock func(aws.Context, *kinesis.DescribeStreamInput, ...request.Option) (*kinesis.DescribeStreamOutput, error)
} }
func (c *kinesisClientMock) GetRecords(in *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) { func (c *kinesisClientMock) GetRecordsWithContext(a aws.Context, in *kinesis.GetRecordsInput, o ...request.Option) (*kinesis.GetRecordsOutput, error) {
return c.getRecordsMock(in) return c.getRecordsMock(a, in, o...)
} }
func (c *kinesisClientMock) GetShardIterator(in *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) { func (c *kinesisClientMock) GetShardIteratorWithContext(a aws.Context, in *kinesis.GetShardIteratorInput, o ...request.Option) (*kinesis.GetShardIteratorOutput, error) {
return c.getShardIteratorMock(in) return c.getShardIteratorMock(a, in, o...)
} }
func (c *kinesisClientMock) DescribeStream(in *kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) { func (c *kinesisClientMock) DescribeStreamWithContext(a aws.Context, in *kinesis.DescribeStreamInput, o ...request.Option) (*kinesis.DescribeStreamOutput, error) {
return c.describeStreamMock(in) return c.describeStreamMock(a, in, o...)
} }
// implementation of checkpoint // implementation of checkpoint
@ -309,7 +315,7 @@ type fakeCheckpoint struct {
mu sync.Mutex mu sync.Mutex
} }
func (fc *fakeCheckpoint) Set(streamName, shardID, sequenceNumber string) error { func (fc *fakeCheckpoint) Set(ctx context.Context, streamName, shardID, sequenceNumber string) error {
fc.mu.Lock() fc.mu.Lock()
defer fc.mu.Unlock() defer fc.mu.Unlock()
@ -318,7 +324,7 @@ func (fc *fakeCheckpoint) Set(streamName, shardID, sequenceNumber string) error
return nil return nil
} }
func (fc *fakeCheckpoint) Get(streamName, shardID string) (string, error) { func (fc *fakeCheckpoint) Get(ctx context.Context, streamName, shardID string) (string, error) {
fc.mu.Lock() fc.mu.Lock()
defer fc.mu.Unlock() defer fc.mu.Unlock()

View file

@ -16,14 +16,19 @@ import (
"github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/dynamodb" "github.com/aws/aws-sdk-go/service/dynamodb"
"github.com/aws/aws-sdk-go/service/kinesis" "github.com/aws/aws-sdk-go/service/kinesis"
"github.com/opentracing/opentracing-go"
alog "github.com/apex/log" alog "github.com/apex/log"
"github.com/apex/log/handlers/text" "github.com/apex/log/handlers/text"
consumer "github.com/harlow/kinesis-consumer" consumer "github.com/harlow/kinesis-consumer"
checkpoint "github.com/harlow/kinesis-consumer/checkpoint/ddb" checkpoint "github.com/harlow/kinesis-consumer/checkpoint/ddb"
"github.com/harlow/kinesis-consumer/examples/distributed-tracing/utility"
) )
const serviceName = "checkpoint.dynamodb"
// kick off a server for exposing scan metrics // kick off a server for exposing scan metrics
func init() { func init() {
sock, err := net.Listen("tcp", "localhost:8080") sock, err := net.Listen("tcp", "localhost:8080")
@ -62,6 +67,12 @@ func main() {
) )
flag.Parse() flag.Parse()
tracer, closer := utility.NewTracer(serviceName)
defer closer.Close()
opentracing.InitGlobalTracer(tracer)
span := tracer.StartSpan("consumer.main")
ctx := opentracing.ContextWithSpan(context.Background(), span)
// Following will overwrite the default dynamodb client // Following will overwrite the default dynamodb client
// Older versions of aws sdk does not picking up aws config properly. // Older versions of aws sdk does not picking up aws config properly.
// You probably need to update aws sdk verison. Tested the following with 1.13.59 // You probably need to update aws sdk verison. Tested the following with 1.13.59
@ -72,7 +83,7 @@ func main() {
) )
// ddb checkpoint // ddb checkpoint
ck, err := checkpoint.New(*app, *table, checkpoint.WithDynamoClient(myDynamoDbClient), checkpoint.WithRetryer(&MyRetryer{})) ck, err := checkpoint.New(ctx, *app, *table, checkpoint.WithDynamoClient(myDynamoDbClient), checkpoint.WithRetryer(&MyRetryer{}))
if err != nil { if err != nil {
log.Log("checkpoint error: %v", err) log.Log("checkpoint error: %v", err)
} }
@ -121,7 +132,7 @@ func main() {
log.Log("scan error: %v", err) log.Log("scan error: %v", err)
} }
if err := ck.Shutdown(); err != nil { if err := ck.Shutdown(ctx); err != nil {
log.Log("checkpoint shutdown error: %v", err) log.Log("checkpoint shutdown error: %v", err)
} }
} }

View file

@ -47,17 +47,17 @@ func main() {
span := tracer.StartSpan("consumer.main") span := tracer.StartSpan("consumer.main")
defer span.Finish() defer span.Finish()
var ( app := flag.String("app", "", "App name")
app = flag.String("app", "", "App name") stream := flag.String("stream", "", "Stream name")
stream = flag.String("stream", "", "Stream name") table := flag.String("table", "", "Checkpoint table name")
table = flag.String("table", "", "Checkpoint table name")
)
flag.Parse() flag.Parse()
span.SetTag("app.name", app) span.SetTag("app.name", app)
span.SetTag("stream.name", stream) span.SetTag("stream.name", stream)
span.SetTag("table.name", table) span.SetTag("table.name", table)
fmt.Println("set tag....")
// Following will overwrite the default dynamodb client // Following will overwrite the default dynamodb client
// Older versions of aws sdk does not picking up aws config properly. // Older versions of aws sdk does not picking up aws config properly.
// You probably need to update aws sdk verison. Tested the following with 1.13.59 // You probably need to update aws sdk verison. Tested the following with 1.13.59
@ -67,8 +67,9 @@ func main() {
myDynamoDbClient := dynamodb.New(sess) myDynamoDbClient := dynamodb.New(sess)
// ddb checkpoint // ddb checkpoint
ctx := opentracing.ContextWithSpan(context.Background(), span)
retryer := utility.NewRetryer() retryer := utility.NewRetryer()
ck, err := checkpoint.New(*app, *table, checkpoint.WithDynamoClient(myDynamoDbClient), checkpoint.WithRetryer(retryer)) ck, err := checkpoint.New(ctx, *app, *table, checkpoint.WithDynamoClient(myDynamoDbClient), checkpoint.WithRetryer(retryer))
if err != nil { if err != nil {
span.LogKV("checkpoint error", err.Error()) span.LogKV("checkpoint error", err.Error())
span.SetTag("consumer.retry.count", retryer.Count()) span.SetTag("consumer.retry.count", retryer.Count())
@ -97,7 +98,8 @@ func main() {
} }
// use cancel func to signal shutdown // use cancel func to signal shutdown
ctx, cancel := context.WithCancel(context.Background()) ctx = opentracing.ContextWithSpan(ctx, span)
ctx, cancel := context.WithCancel(ctx)
// trap SIGINT, wait to trigger shutdown // trap SIGINT, wait to trigger shutdown
signals := make(chan os.Signal, 1) signals := make(chan os.Signal, 1)
@ -105,23 +107,25 @@ func main() {
go func() { go func() {
<-signals <-signals
span.Finish()
closer.Close()
cancel() cancel()
}() }()
// scan stream // scan stream
err = c.Scan(ctx, func(r *consumer.Record) consumer.ScanStatus { err = c.Scan(ctx, func(r *consumer.Record) consumer.ScanStatus {
fmt.Println(string(r.Data)) fmt.Println(string(r.Data))
// continue scanning // continue scanning
return consumer.ScanStatus{} return consumer.ScanStatus{}
}) })
if err != nil { if err != nil {
span.LogKV("consumer scan error", err.Error())
ext.Error.Set(span, true) //span.LogKV("consumer scan error", err.Error())
//ext.Error.Set(span, true)
log.Log("consumer scan error", "error", err.Error()) log.Log("consumer scan error", "error", err.Error())
} }
if err := ck.Shutdown(); err != nil { if err := ck.Shutdown(ctx); err != nil {
span.LogKV("consumer shutdown error", err.Error()) span.LogKV("consumer shutdown error", err.Error())
ext.Error.Set(span, true) ext.Error.Set(span, true)
log.Log("checkpoint shutdown error", "error", err.Error()) log.Log("checkpoint shutdown error", "error", err.Error())

View file

@ -56,7 +56,7 @@ func main() {
// Need to end span here, since Fatalf calls os.Exit // Need to end span here, since Fatalf calls os.Exit
span.Finish() span.Finish()
closer.Close() closer.Close()
log.Fatal(fmt.Sprintf("Cannot open %s file"), dataFile) log.Fatal(fmt.Sprintf("Cannot open %s file", dataFile))
} }
defer f.Close() defer f.Close()
span.SetTag("producer.file.name", f.Name()) span.SetTag("producer.file.name", f.Name())
@ -88,11 +88,11 @@ func main() {
func putRecords(ctx context.Context, streamName *string, records []*kinesis.PutRecordsRequestEntry) { func putRecords(ctx context.Context, streamName *string, records []*kinesis.PutRecordsRequestEntry) {
// I am assuming each new AWS call is a new Span // I am assuming each new AWS call is a new Span
span, _ := opentracing.StartSpanFromContext(ctx, "producer.putRecords") span, ctx := opentracing.StartSpanFromContext(ctx, "producer.putRecords")
defer span.Finish() defer span.Finish()
span.SetTag("producer.records.count", len(records)) span.SetTag("producer.records.count", len(records))
ctx = opentracing.ContextWithSpan(ctx, span) ctx = opentracing.ContextWithSpan(ctx, span)
_, err := svc.PutRecordsWithContext(&kinesis.PutRecordsInput{ _, err := svc.PutRecordsWithContext(ctx, &kinesis.PutRecordsInput{
StreamName: streamName, StreamName: streamName,
Records: records, Records: records,
}) })