diff --git a/Gopkg.lock b/Gopkg.lock index 76fcb47..84b2794 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -2,15 +2,18 @@ [[projects]] + digest = "1:5bbabe0c3c7e7f524b4c38193b80bf24624e67c0f3a036c4244c85c9a80579fd" name = "github.com/apex/log" packages = [ ".", - "handlers/text" + "handlers/text", ] + pruneopts = "UT" revision = "0296d6eb16bb28f8a0c55668affcf4876dc269be" version = "v1.0.0" [[projects]] + digest = "1:430a0049ba9e5652a778f1bb2a755b456ef8de588d94093f0b02a63cb885fbca" name = "github.com/aws/aws-sdk-go" packages = [ "aws", @@ -46,55 +49,118 @@ "service/dynamodb/dynamodbiface", "service/kinesis", "service/kinesis/kinesisiface", - "service/sts" + "service/sts", ] + pruneopts = "UT" revision = "8475c414b1bd58b8cc214873a8854e3a621e67d7" version = "v1.15.0" [[projects]] + branch = "master" + digest = "1:4c4c33075b704791d6a7f09dfb55c66769e8a1dc6adf87026292d274fe8ad113" + name = "github.com/codahale/hdrhistogram" + packages = ["."] + pruneopts = "UT" + revision = "3a0bb77429bd3a61596f5e8a3172445844342120" + +[[projects]] + digest = "1:fe8a03a8222d5b913f256972933d26d24ad7c8286692a42943bc01633cc8fce3" name = "github.com/go-ini/ini" packages = ["."] + pruneopts = "UT" revision = "358ee7663966325963d4e8b2e1fbd570c5195153" version = "v1.38.1" [[projects]] - name = "github.com/harlow/kinesis-consumer" - packages = [ - ".", - "checkpoint/ddb", - "checkpoint/postgres", - "checkpoint/redis" - ] - revision = "049445e259a2ab9146364bf60d6f5f71270a125b" - version = "v0.2.0" - -[[projects]] + digest = "1:e22af8c7518e1eab6f2eab2b7d7558927f816262586cd6ed9f349c97a6c285c4" name = "github.com/jmespath/go-jmespath" packages = ["."] + pruneopts = "UT" revision = "0b12d6b5" [[projects]] branch = "master" + digest = "1:37ce7d7d80531b227023331002c0d42b4b4b291a96798c82a049d03a54ba79e4" name = "github.com/lib/pq" packages = [ ".", - "oid" + "oid", ] + pruneopts = "UT" revision = "90697d60dd844d5ef6ff15135d0203f65d2f53b8" [[projects]] + digest = "1:450b7623b185031f3a456801155c8320209f75d0d4c4e633c6b1e59d44d6e392" + name = "github.com/opentracing/opentracing-go" + packages = [ + ".", + "ext", + "log", + ] + pruneopts = "UT" + revision = "1949ddbfd147afd4d964a9f00b24eb291e0e7c38" + version = "v1.0.2" + +[[projects]] + digest = "1:40e195917a951a8bf867cd05de2a46aaf1806c50cf92eebf4c16f78cd196f747" name = "github.com/pkg/errors" packages = ["."] + pruneopts = "UT" revision = "645ef00459ed84a119197bfb8d8205042c6df63d" version = "v0.8.0" [[projects]] + digest = "1:ac6f26e917fd2fb3194a7ebe2baf6fb32de2f2fbfed130c18aac0e758a6e1d22" + name = "github.com/uber/jaeger-client-go" + packages = [ + ".", + "config", + "internal/baggage", + "internal/baggage/remote", + "internal/spanlog", + "internal/throttler", + "internal/throttler/remote", + "log", + "rpcmetrics", + "thrift", + "thrift-gen/agent", + "thrift-gen/baggage", + "thrift-gen/jaeger", + "thrift-gen/sampling", + "thrift-gen/zipkincore", + "transport", + "utils", + ] + pruneopts = "UT" + revision = "1a782e2da844727691fef1757c72eb190c2909f0" + version = "v2.15.0" + +[[projects]] + digest = "1:0f09db8429e19d57c8346ad76fbbc679341fa86073d3b8fb5ac919f0357d8f4c" + name = "github.com/uber/jaeger-lib" + packages = ["metrics"] + pruneopts = "UT" + revision = "ed3a127ec5fef7ae9ea95b01b542c47fbd999ce5" + version = "v1.5.0" + +[[projects]] + branch = "master" + digest = "1:76ee51c3f468493aff39dbacc401e8831fbb765104cbf613b89bef01cf4bad70" + name = "golang.org/x/net" + packages = ["context"] + pruneopts = "UT" + revision = "a544f70c90f196e50d198126db0c4cb2b562fec0" + +[[projects]] + 
digest = "1:04aea75705cb453e24bf8c1506a24a5a9036537dbc61ddf71d20900d6c7c3ab9" name = "gopkg.in/DATA-DOG/go-sqlmock.v1" packages = ["."] + pruneopts = "UT" revision = "d76b18b42f285b792bf985118980ce9eacea9d10" version = "v1.3.0" [[projects]] + digest = "1:e5a1379b4f0cad2aabd75580598c3b8e19a027e8eed806e7b76b0ec949df4599" name = "gopkg.in/redis.v5" packages = [ ".", @@ -102,14 +168,34 @@ "internal/consistenthash", "internal/hashtag", "internal/pool", - "internal/proto" + "internal/proto", ] + pruneopts = "UT" revision = "a16aeec10ff407b1e7be6dd35797ccf5426ef0f0" version = "v5.2.9" [solve-meta] analyzer-name = "dep" analyzer-version = 1 - inputs-digest = "2588ee54549a76e93e2e65a289fccd8b636f85b124c5ccb0ab3d5f3529a3cbaa" + input-imports = [ + "github.com/apex/log", + "github.com/apex/log/handlers/text", + "github.com/aws/aws-sdk-go/aws", + "github.com/aws/aws-sdk-go/aws/awserr", + "github.com/aws/aws-sdk-go/aws/request", + "github.com/aws/aws-sdk-go/aws/session", + "github.com/aws/aws-sdk-go/service/dynamodb", + "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute", + "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface", + "github.com/aws/aws-sdk-go/service/kinesis", + "github.com/aws/aws-sdk-go/service/kinesis/kinesisiface", + "github.com/lib/pq", + "github.com/opentracing/opentracing-go", + "github.com/opentracing/opentracing-go/ext", + "github.com/pkg/errors", + "github.com/uber/jaeger-client-go/config", + "gopkg.in/DATA-DOG/go-sqlmock.v1", + "gopkg.in/redis.v5", + ] solver-name = "gps-cdcl" solver-version = 1 diff --git a/README.md b/README.md index 4658cd5..69d2151 100644 --- a/README.md +++ b/README.md @@ -258,6 +258,9 @@ func main() { } ``` +### Opentracing +To enable integraton with Opentracing. Checkpoint, Consumer are now required to pass in context as first parameter. Context object wraps tracing context within and is required to pass down to other layer. Another change, that should be invisible from user is that, all AWS SDK GO call are now using the version WithContext, e.g. if codebase is using GetID(...), now they are replaced with GetIDWithContext(ctx,...). This is done so we can link the span created for AWS call to spans created upstream within application code. + ## Contributing Please see [CONTRIBUTING.md] for more information. Thank you, [contributors]! 
diff --git a/checkpoint.go b/checkpoint.go index 383d4c1..c55e554 100644 --- a/checkpoint.go +++ b/checkpoint.go @@ -1,13 +1,17 @@ package consumer +import ( + "context" +) + // Checkpoint interface used to track consumer progress in the stream type Checkpoint interface { - Get(streamName, shardID string) (string, error) - Set(streamName, shardID, sequenceNumber string) error + Get(ctx context.Context, streamName, shardID string) (string, error) + Set(ctx context.Context, streamName, shardID, sequenceNumber string) error } // noopCheckpoint implements the checkpoint interface with discard type noopCheckpoint struct{} -func (n noopCheckpoint) Set(string, string, string) error { return nil } -func (n noopCheckpoint) Get(string, string) (string, error) { return "", nil } +func (n noopCheckpoint) Set(context.Context, string, string, string) error { return nil } +func (n noopCheckpoint) Get(context.Context, string, string) (string, error) { return "", nil } diff --git a/checkpoint/ddb/ddb.go b/checkpoint/ddb/ddb.go index 170ced4..e4ff6a4 100644 --- a/checkpoint/ddb/ddb.go +++ b/checkpoint/ddb/ddb.go @@ -1,6 +1,7 @@ package ddb import ( + "context" "fmt" "log" "sync" @@ -11,6 +12,8 @@ import ( "github.com/aws/aws-sdk-go/service/dynamodb" "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute" "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface" + "github.com/opentracing/opentracing-go" + "github.com/opentracing/opentracing-go/ext" ) // Option is used to override defaults when creating a new Checkpoint @@ -38,7 +41,7 @@ func WithRetryer(r Retryer) Option { } // New returns a checkpoint that uses DynamoDB for underlying storage -func New(appName, tableName string, opts ...Option) (*Checkpoint, error) { +func New(ctx context.Context, appName, tableName string, opts ...Option) (*Checkpoint, error) { client := dynamodb.New(session.New(aws.NewConfig())) ck := &Checkpoint{ @@ -56,7 +59,7 @@ func New(appName, tableName string, opts ...Option) (*Checkpoint, error) { opt(ck) } - go ck.loop() + go ck.loop(ctx) return ck, nil } @@ -87,9 +90,13 @@ type item struct { // Get determines if a checkpoint for a particular Shard exists. // Typically used to determine whether we should start processing the shard with // TRIM_HORIZON or AFTER_SEQUENCE_NUMBER (if checkpoint exists). -func (c *Checkpoint) Get(streamName, shardID string) (string, error) { +func (c *Checkpoint) Get(ctx context.Context, streamName, shardID string) (string, error) { namespace := fmt.Sprintf("%s-%s", c.appName, streamName) - + span, ctx := opentracing.StartSpanFromContext(ctx, "checkpoint.ddb.Get", + opentracing.Tag{Key: "namespace", Value: namespace}, + opentracing.Tag{Key: "shardID", Value: shardID}, + ) + defer span.Finish() params := &dynamodb.GetItemInput{ TableName: aws.String(c.tableName), ConsistentRead: aws.Bool(true), @@ -103,11 +110,13 @@ func (c *Checkpoint) Get(streamName, shardID string) (string, error) { }, } - resp, err := c.client.GetItem(params) + resp, err := c.client.GetItemWithContext(ctx, params) if err != nil { if c.retryer.ShouldRetry(err) { - return c.Get(streamName, shardID) + return c.Get(ctx, streamName, shardID) } + span.LogKV("checkpoint get item error", err.Error()) + ext.Error.Set(span, true) return "", err } @@ -118,10 +127,14 @@ func (c *Checkpoint) Get(streamName, shardID string) (string, error) { // Set stores a checkpoint for a shard (e.g. sequence number of last record processed by application). // Upon failover, record processing is resumed from this point. 
-func (c *Checkpoint) Set(streamName, shardID, sequenceNumber string) error { +func (c *Checkpoint) Set(ctx context.Context, streamName, shardID, sequenceNumber string) error { c.mu.Lock() defer c.mu.Unlock() - + span, ctx := opentracing.StartSpanFromContext(ctx, "checkpoint.ddb.Set", + opentracing.Tag{Key: "stream.name", Value: streamName}, + opentracing.Tag{Key: "shardID", Value: shardID}, + ) + defer span.Finish() if sequenceNumber == "" { return fmt.Errorf("sequence number should not be empty") } @@ -136,12 +149,12 @@ func (c *Checkpoint) Set(streamName, shardID, sequenceNumber string) error { } // Shutdown the checkpoint. Save any in-flight data. -func (c *Checkpoint) Shutdown() error { +func (c *Checkpoint) Shutdown(ctx context.Context) error { c.done <- struct{}{} - return c.save() + return c.save(ctx) } -func (c *Checkpoint) loop() { +func (c *Checkpoint) loop(ctx context.Context) { tick := time.NewTicker(c.maxInterval) defer tick.Stop() defer close(c.done) @@ -149,16 +162,18 @@ func (c *Checkpoint) loop() { for { select { case <-tick.C: - c.save() + c.save(ctx) case <-c.done: return } } } -func (c *Checkpoint) save() error { +func (c *Checkpoint) save(ctx context.Context) error { c.mu.Lock() defer c.mu.Unlock() + span, ctx := opentracing.StartSpanFromContext(ctx, "checkpoint.ddb.save") + defer span.Finish() for key, sequenceNumber := range c.checkpoints { item, err := dynamodbattribute.MarshalMap(item{ @@ -168,10 +183,12 @@ func (c *Checkpoint) save() error { }) if err != nil { log.Printf("marshal map error: %v", err) + span.LogKV("marshal map error", err.Error()) + ext.Error.Set(span, true) return nil } - _, err = c.client.PutItem(&dynamodb.PutItemInput{ + _, err = c.client.PutItemWithContext(ctx, &dynamodb.PutItemInput{ TableName: aws.String(c.tableName), Item: item, }) @@ -179,7 +196,9 @@ func (c *Checkpoint) save() error { if !c.retryer.ShouldRetry(err) { return err } - return c.save() + span.LogKV("checkpoint put item error", err.Error()) + ext.Error.Set(span, true) + return c.save(ctx) } } diff --git a/checkpoint/postgres/postgres.go b/checkpoint/postgres/postgres.go index b5a5bda..ad75996 100644 --- a/checkpoint/postgres/postgres.go +++ b/checkpoint/postgres/postgres.go @@ -1,6 +1,7 @@ package postgres import ( + "context" "database/sql" "errors" "fmt" @@ -80,7 +81,7 @@ func (c *Checkpoint) GetMaxInterval() time.Duration { // Get determines if a checkpoint for a particular Shard exists. // Typically used to determine whether we should start processing the shard with // TRIM_HORIZON or AFTER_SEQUENCE_NUMBER (if checkpoint exists). -func (c *Checkpoint) Get(streamName, shardID string) (string, error) { +func (c *Checkpoint) Get(ctx context.Context, streamName, shardID string) (string, error) { namespace := fmt.Sprintf("%s-%s", c.appName, streamName) var sequenceNumber string @@ -99,7 +100,7 @@ func (c *Checkpoint) Get(streamName, shardID string) (string, error) { // Set stores a checkpoint for a shard (e.g. sequence number of last record processed by application). // Upon failover, record processing is resumed from this point. 
-func (c *Checkpoint) Set(streamName, shardID, sequenceNumber string) error { +func (c *Checkpoint) Set(ctx context.Context, streamName, shardID, sequenceNumber string) error { c.mu.Lock() defer c.mu.Unlock() diff --git a/checkpoint/postgres/postgres_test.go b/checkpoint/postgres/postgres_test.go index 135fd0d..29ba5d3 100644 --- a/checkpoint/postgres/postgres_test.go +++ b/checkpoint/postgres/postgres_test.go @@ -1,17 +1,16 @@ package postgres_test import ( + "context" + "database/sql" + "fmt" "testing" - "time" - "fmt" - - "database/sql" + "gopkg.in/DATA-DOG/go-sqlmock.v1" "github.com/harlow/kinesis-consumer/checkpoint/postgres" "github.com/pkg/errors" - "gopkg.in/DATA-DOG/go-sqlmock.v1" ) func TestNew(t *testing.T) { @@ -77,6 +76,7 @@ func TestNew_WithMaxIntervalOption(t *testing.T) { } func TestCheckpoint_Get(t *testing.T) { + ctx := context.TODO() appName := "streamConsumer" tableName := "checkpoint" connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;" @@ -102,7 +102,7 @@ func TestCheckpoint_Get(t *testing.T) { tableName) mock.ExpectQuery(expectedSQLRegexString).WithArgs(namespace, shardID).WillReturnRows(expectedRows) - gotSequenceNumber, err := ck.Get(streamName, shardID) + gotSequenceNumber, err := ck.Get(ctx, streamName, shardID) if gotSequenceNumber != expectedSequenceNumber { t.Errorf("expected sequence number equals %v, but got %v", expectedSequenceNumber, gotSequenceNumber) @@ -117,6 +117,7 @@ func TestCheckpoint_Get(t *testing.T) { } func TestCheckpoint_Get_NoRows(t *testing.T) { + ctx := context.TODO() appName := "streamConsumer" tableName := "checkpoint" connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;" @@ -138,7 +139,7 @@ func TestCheckpoint_Get_NoRows(t *testing.T) { tableName) mock.ExpectQuery(expectedSQLRegexString).WithArgs(namespace, shardID).WillReturnError(sql.ErrNoRows) - gotSequenceNumber, err := ck.Get(streamName, shardID) + gotSequenceNumber, err := ck.Get(ctx, streamName, shardID) if gotSequenceNumber != "" { t.Errorf("expected sequence number equals empty, but got %v", gotSequenceNumber) @@ -153,6 +154,7 @@ func TestCheckpoint_Get_NoRows(t *testing.T) { } func TestCheckpoint_Get_QueryError(t *testing.T) { + ctx := context.TODO() appName := "streamConsumer" tableName := "checkpoint" connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;" @@ -174,7 +176,7 @@ func TestCheckpoint_Get_QueryError(t *testing.T) { tableName) mock.ExpectQuery(expectedSQLRegexString).WithArgs(namespace, shardID).WillReturnError(errors.New("an error")) - gotSequenceNumber, err := ck.Get(streamName, shardID) + gotSequenceNumber, err := ck.Get(ctx, streamName, shardID) if gotSequenceNumber != "" { t.Errorf("expected sequence number equals empty, but got %v", gotSequenceNumber) @@ -189,6 +191,7 @@ func TestCheckpoint_Get_QueryError(t *testing.T) { } func TestCheckpoint_Set(t *testing.T) { + ctx := context.TODO() appName := "streamConsumer" tableName := "checkpoint" connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;" @@ -201,7 +204,7 @@ func TestCheckpoint_Set(t *testing.T) { t.Fatalf("error occurred during the checkpoint creation. 
cause: %v", err) } - err = ck.Set(streamName, shardID, expectedSequenceNumber) + err = ck.Set(ctx, streamName, shardID, expectedSequenceNumber) if err != nil { t.Errorf("expected error equals nil, but got %v", err) @@ -210,6 +213,7 @@ func TestCheckpoint_Set(t *testing.T) { } func TestCheckpoint_Set_SequenceNumberEmpty(t *testing.T) { + ctx := context.TODO() appName := "streamConsumer" tableName := "checkpoint" connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;" @@ -222,7 +226,7 @@ func TestCheckpoint_Set_SequenceNumberEmpty(t *testing.T) { t.Fatalf("error occurred during the checkpoint creation. cause: %v", err) } - err = ck.Set(streamName, shardID, expectedSequenceNumber) + err = ck.Set(ctx, streamName, shardID, expectedSequenceNumber) if err == nil { t.Errorf("expected error equals not nil, but got %v", err) @@ -231,6 +235,7 @@ func TestCheckpoint_Set_SequenceNumberEmpty(t *testing.T) { } func TestCheckpoint_Shutdown(t *testing.T) { + ctx := context.TODO() appName := "streamConsumer" tableName := "checkpoint" connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;" @@ -253,7 +258,7 @@ func TestCheckpoint_Shutdown(t *testing.T) { result := sqlmock.NewResult(0, 1) mock.ExpectExec(expectedSQLRegexString).WithArgs(namespace, shardID, expectedSequenceNumber).WillReturnResult(result) - err = ck.Set(streamName, shardID, expectedSequenceNumber) + err = ck.Set(ctx, streamName, shardID, expectedSequenceNumber) if err != nil { t.Fatalf("unable to set checkpoint for data initialization. cause: %v", err) @@ -270,6 +275,7 @@ func TestCheckpoint_Shutdown(t *testing.T) { } func TestCheckpoint_Shutdown_SaveError(t *testing.T) { + ctx := context.TODO() appName := "streamConsumer" tableName := "checkpoint" connString := "UserID=root;Password=myPassword;Host=localhost;Port=5432;Database=myDataBase;" @@ -291,7 +297,7 @@ func TestCheckpoint_Shutdown_SaveError(t *testing.T) { expectedSQLRegexString := fmt.Sprintf(`INSERT INTO %s \(namespace, shard_id, sequence_number\) VALUES\(\$1, \$2, \$3\) ON CONFLICT \(namespace, shard_id\) DO UPDATE SET sequence_number= \$3;`, tableName) mock.ExpectExec(expectedSQLRegexString).WithArgs(namespace, shardID, expectedSequenceNumber).WillReturnError(errors.New("an error")) - err = ck.Set(streamName, shardID, expectedSequenceNumber) + err = ck.Set(ctx, streamName, shardID, expectedSequenceNumber) if err != nil { t.Fatalf("unable to set checkpoint for data initialization. cause: %v", err) diff --git a/checkpoint/redis/redis.go b/checkpoint/redis/redis.go index e3f7e51..60f1387 100644 --- a/checkpoint/redis/redis.go +++ b/checkpoint/redis/redis.go @@ -1,6 +1,7 @@ package redis import ( + "context" "fmt" "os" @@ -37,14 +38,14 @@ type Checkpoint struct { } // Get fetches the checkpoint for a particular Shard. -func (c *Checkpoint) Get(streamName, shardID string) (string, error) { +func (c *Checkpoint) Get(ctx context.Context, streamName, shardID string) (string, error) { val, _ := c.client.Get(c.key(streamName, shardID)).Result() return val, nil } // Set stores a checkpoint for a shard (e.g. sequence number of last record processed by application). // Upon failover, record processing is resumed from this point. 
-func (c *Checkpoint) Set(streamName, shardID, sequenceNumber string) error { +func (c *Checkpoint) Set(ctx context.Context, streamName, shardID, sequenceNumber string) error { if sequenceNumber == "" { return fmt.Errorf("sequence number should not be empty") } diff --git a/checkpoint/redis/redis_test.go b/checkpoint/redis/redis_test.go index 5c0d02d..578cfae 100644 --- a/checkpoint/redis/redis_test.go +++ b/checkpoint/redis/redis_test.go @@ -1,21 +1,23 @@ package redis import ( + "context" "testing" ) func Test_CheckpointLifecycle(t *testing.T) { // new + ctx := context.TODO() c, err := New("app") if err != nil { t.Fatalf("new checkpoint error: %v", err) } // set - c.Set("streamName", "shardID", "testSeqNum") + c.Set(ctx, "streamName", "shardID", "testSeqNum") // get - val, err := c.Get("streamName", "shardID") + val, err := c.Get(ctx, "streamName", "shardID") if err != nil { t.Fatalf("get checkpoint error: %v", err) } @@ -25,12 +27,13 @@ func Test_CheckpointLifecycle(t *testing.T) { } func Test_SetEmptySeqNum(t *testing.T) { + ctx := context.TODO() c, err := New("app") if err != nil { t.Fatalf("new checkpoint error: %v", err) } - err = c.Set("streamName", "shardID", "") + err = c.Set(ctx, "streamName", "shardID", "") if err == nil { t.Fatalf("should not allow empty sequence number") } diff --git a/consumer.go b/consumer.go index 8137f1f..4ba3809 100644 --- a/consumer.go +++ b/consumer.go @@ -112,6 +112,8 @@ func (c *Consumer) Scan(ctx context.Context, fn func(*Record) ScanStatus) error // get shard ids shardIDs, err := c.getShardIDs(ctx, c.streamName) + span.SetTag("stream.name", c.streamName) + span.SetTag("shard.count", len(shardIDs)) if err != nil { span.LogKV("get shardID error", err.Error(), "stream.name", c.streamName) ext.Error.Set(span, true) @@ -119,7 +121,7 @@ func (c *Consumer) Scan(ctx context.Context, fn func(*Record) ScanStatus) error } if len(shardIDs) == 0 { - span.LogKV("get shardID error", err.Error(), "stream.name", c.streamName, "shards.count", len(shardIDs)) + span.LogKV("stream.name", c.streamName, "shards.count", len(shardIDs)) ext.Error.Set(span, true) return fmt.Errorf("no shards available") } @@ -138,6 +140,7 @@ func (c *Consumer) Scan(ctx context.Context, fn func(*Record) ScanStatus) error if err := c.ScanShard(ctx, shardID, fn); err != nil { span.LogKV("scan shard error", err.Error(), "shardID", shardID) ext.Error.Set(span, true) + span.Finish() select { case errc <- fmt.Errorf("shard %s error: %v", shardID, err): // first error to occur @@ -166,7 +169,7 @@ func (c *Consumer) ScanShard( span, ctx := opentracing.StartSpanFromContext(ctx, "consumer.scanshard") defer span.Finish() // get checkpoint - lastSeqNum, err := c.checkpoint.Get(c.streamName, shardID) + lastSeqNum, err := c.checkpoint.Get(ctx, c.streamName, shardID) if err != nil { span.LogKV("checkpoint error", err.Error(), "shardID", shardID) ext.Error.Set(span, true) @@ -195,7 +198,7 @@ func (c *Consumer) scanPagesOfShard(ctx context.Context, shardID, lastSeqNum str return nil default: span.SetTag("scan", "on") - resp, err := c.client.GetRecords(&kinesis.GetRecordsInput{ + resp, err := c.client.GetRecordsWithContext(ctx, &kinesis.GetRecordsInput{ ShardIterator: shardIterator, }) @@ -203,6 +206,7 @@ func (c *Consumer) scanPagesOfShard(ctx context.Context, shardID, lastSeqNum str shardIterator, err = c.getShardIterator(ctx, c.streamName, shardID, lastSeqNum) if err != nil { ext.Error.Set(span, true) + span.LogKV("get shard iterator error", err.Error()) return fmt.Errorf("get shard iterator error: %v", 
err) } continue @@ -243,7 +247,7 @@ func (c *Consumer) handleRecord(ctx context.Context, shardID string, r *Record, status := fn(r) if !status.SkipCheckpoint { span.LogKV("scan.state", status) - if err := c.checkpoint.Set(c.streamName, shardID, *r.SequenceNumber); err != nil { + if err := c.checkpoint.Set(ctx, c.streamName, shardID, *r.SequenceNumber); err != nil { span.LogKV("checkpoint error", err.Error(), "stream.name", c.streamName, "shardID", shardID, "sequenceNumber", *r.SequenceNumber) ext.Error.Set(span, true) return false, err @@ -269,7 +273,7 @@ func (c *Consumer) getShardIDs(ctx context.Context, streamName string) ([]string span, ctx := opentracing.StartSpanFromContext(ctx, "consumer.getShardIDs") defer span.Finish() - resp, err := c.client.DescribeStream( + resp, err := c.client.DescribeStreamWithContext(ctx, &kinesis.DescribeStreamInput{ StreamName: aws.String(streamName), }, @@ -304,7 +308,7 @@ func (c *Consumer) getShardIterator(ctx context.Context, streamName, shardID, la params.ShardIteratorType = aws.String("TRIM_HORIZON") } - resp, err := c.client.GetShardIterator(params) + resp, err := c.client.GetShardIteratorWithContext(ctx, params) if err != nil { span.LogKV("get shard error", err.Error()) ext.Error.Set(span, true) diff --git a/consumer_test.go b/consumer_test.go index 3f17373..5eb8a9e 100644 --- a/consumer_test.go +++ b/consumer_test.go @@ -7,6 +7,7 @@ import ( "testing" "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/service/kinesis" "github.com/aws/aws-sdk-go/service/kinesis/kinesisiface" ) @@ -19,6 +20,7 @@ func TestNew(t *testing.T) { } func TestConsumer_Scan(t *testing.T) { + ctx := context.TODO() records := []*kinesis.Record{ { Data: []byte("firstData"), @@ -30,18 +32,18 @@ func TestConsumer_Scan(t *testing.T) { }, } client := &kinesisClientMock{ - getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) { + getShardIteratorMock: func(a aws.Context, input *kinesis.GetShardIteratorInput, o ...request.Option) (*kinesis.GetShardIteratorOutput, error) { return &kinesis.GetShardIteratorOutput{ ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"), }, nil }, - getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) { + getRecordsMock: func(a aws.Context, input *kinesis.GetRecordsInput, o ...request.Option) (*kinesis.GetRecordsOutput, error) { return &kinesis.GetRecordsOutput{ NextShardIterator: nil, Records: records, }, nil }, - describeStreamMock: func(input *kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) { + describeStreamMock: func(a aws.Context, input *kinesis.DescribeStreamInput, o ...request.Option) (*kinesis.DescribeStreamOutput, error) { return &kinesis.DescribeStreamOutput{ StreamDescription: &kinesis.StreamDescription{ Shards: []*kinesis.Shard{ @@ -73,7 +75,7 @@ func TestConsumer_Scan(t *testing.T) { return ScanStatus{} } - if err := c.Scan(context.Background(), fn); err != nil { + if err := c.Scan(ctx, fn); err != nil { t.Errorf("scan shard error expected nil. 
got %v", err) } @@ -87,15 +89,16 @@ func TestConsumer_Scan(t *testing.T) { t.Errorf("counter error expected %d, got %d", 2, val) } - val, err := cp.Get("myStreamName", "myShard") + val, err := cp.Get(ctx, "myStreamName", "myShard") if err != nil && val != "lastSeqNum" { t.Errorf("checkout error expected %s, got %s", "lastSeqNum", val) } } func TestConsumer_Scan_NoShardsAvailable(t *testing.T) { + ctx := context.TODO() client := &kinesisClientMock{ - describeStreamMock: func(input *kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) { + describeStreamMock: func(a aws.Context, input *kinesis.DescribeStreamInput, o ...request.Option) (*kinesis.DescribeStreamOutput, error) { return &kinesis.DescribeStreamOutput{ StreamDescription: &kinesis.StreamDescription{ Shards: make([]*kinesis.Shard, 0), @@ -123,7 +126,7 @@ func TestConsumer_Scan_NoShardsAvailable(t *testing.T) { return ScanStatus{} } - if err := c.Scan(context.Background(), fn); err == nil { + if err := c.Scan(ctx, fn); err == nil { t.Errorf("scan shard error expected not nil. got %v", err) } @@ -133,13 +136,14 @@ func TestConsumer_Scan_NoShardsAvailable(t *testing.T) { if val := ctr.counter; val != 0 { t.Errorf("counter error expected %d, got %d", 0, val) } - val, err := cp.Get("myStreamName", "myShard") + val, err := cp.Get(ctx, "myStreamName", "myShard") if err != nil && val != "" { t.Errorf("checkout error expected %s, got %s", "", val) } } func TestScanShard(t *testing.T) { + ctx := context.TODO() var records = []*kinesis.Record{ { Data: []byte("firstData"), @@ -152,12 +156,12 @@ func TestScanShard(t *testing.T) { } var client = &kinesisClientMock{ - getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) { + getShardIteratorMock: func(a aws.Context, input *kinesis.GetShardIteratorInput, o ...request.Option) (*kinesis.GetShardIteratorOutput, error) { return &kinesis.GetShardIteratorOutput{ ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"), }, nil }, - getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) { + getRecordsMock: func(a aws.Context, input *kinesis.GetRecordsInput, o ...request.Option) (*kinesis.GetRecordsOutput, error) { return &kinesis.GetRecordsOutput{ NextShardIterator: nil, Records: records, @@ -187,7 +191,7 @@ func TestScanShard(t *testing.T) { } // scan shard - if err := c.ScanShard(context.Background(), "myShard", fn); err != nil { + if err := c.ScanShard(ctx, "myShard", fn); err != nil { t.Fatalf("scan shard error: %v", err) } @@ -202,13 +206,14 @@ func TestScanShard(t *testing.T) { } // sets checkpoint - val, err := cp.Get("myStreamName", "myShard") + val, err := cp.Get(ctx, "myStreamName", "myShard") if err != nil && val != "lastSeqNum" { t.Fatalf("checkout error expected %s, got %s", "lastSeqNum", val) } } func TestScanShard_StopScan(t *testing.T) { + ctx := context.TODO() var records = []*kinesis.Record{ { Data: []byte("firstData"), @@ -221,12 +226,12 @@ func TestScanShard_StopScan(t *testing.T) { } var client = &kinesisClientMock{ - getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) { + getShardIteratorMock: func(a aws.Context, input *kinesis.GetShardIteratorInput, o ...request.Option) (*kinesis.GetShardIteratorOutput, error) { return &kinesis.GetShardIteratorOutput{ ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"), }, nil }, - getRecordsMock: func(input *kinesis.GetRecordsInput) 
(*kinesis.GetRecordsOutput, error) { + getRecordsMock: func(a aws.Context, input *kinesis.GetRecordsInput, o ...request.Option) (*kinesis.GetRecordsOutput, error) { return &kinesis.GetRecordsOutput{ NextShardIterator: nil, Records: records, @@ -246,7 +251,7 @@ func TestScanShard_StopScan(t *testing.T) { return ScanStatus{StopScan: true} } - if err := c.ScanShard(context.Background(), "myShard", fn); err != nil { + if err := c.ScanShard(ctx, "myShard", fn); err != nil { t.Fatalf("scan shard error: %v", err) } @@ -256,13 +261,14 @@ func TestScanShard_StopScan(t *testing.T) { } func TestScanShard_ShardIsClosed(t *testing.T) { + ctx := context.TODO() var client = &kinesisClientMock{ - getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) { + getShardIteratorMock: func(a aws.Context, input *kinesis.GetShardIteratorInput, o ...request.Option) (*kinesis.GetShardIteratorOutput, error) { return &kinesis.GetShardIteratorOutput{ ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"), }, nil }, - getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) { + getRecordsMock: func(a aws.Context, input *kinesis.GetRecordsInput, o ...request.Option) (*kinesis.GetRecordsOutput, error) { return &kinesis.GetRecordsOutput{ NextShardIterator: nil, Records: make([]*Record, 0), @@ -279,28 +285,28 @@ func TestScanShard_ShardIsClosed(t *testing.T) { return ScanStatus{} } - if err := c.ScanShard(context.Background(), "myShard", fn); err != nil { + if err := c.ScanShard(ctx, "myShard", fn); err != nil { t.Fatalf("scan shard error: %v", err) } } type kinesisClientMock struct { kinesisiface.KinesisAPI - getShardIteratorMock func(*kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) - getRecordsMock func(*kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) - describeStreamMock func(*kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) + getShardIteratorMock func(aws.Context, *kinesis.GetShardIteratorInput, ...request.Option) (*kinesis.GetShardIteratorOutput, error) + getRecordsMock func(aws.Context, *kinesis.GetRecordsInput, ...request.Option) (*kinesis.GetRecordsOutput, error) + describeStreamMock func(aws.Context, *kinesis.DescribeStreamInput, ...request.Option) (*kinesis.DescribeStreamOutput, error) } -func (c *kinesisClientMock) GetRecords(in *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) { - return c.getRecordsMock(in) +func (c *kinesisClientMock) GetRecordsWithContext(a aws.Context, in *kinesis.GetRecordsInput, o ...request.Option) (*kinesis.GetRecordsOutput, error) { + return c.getRecordsMock(a, in, o...) } -func (c *kinesisClientMock) GetShardIterator(in *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) { - return c.getShardIteratorMock(in) +func (c *kinesisClientMock) GetShardIteratorWithContext(a aws.Context, in *kinesis.GetShardIteratorInput, o ...request.Option) (*kinesis.GetShardIteratorOutput, error) { + return c.getShardIteratorMock(a, in, o...) } -func (c *kinesisClientMock) DescribeStream(in *kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) { - return c.describeStreamMock(in) +func (c *kinesisClientMock) DescribeStreamWithContext(a aws.Context, in *kinesis.DescribeStreamInput, o ...request.Option) (*kinesis.DescribeStreamOutput, error) { + return c.describeStreamMock(a, in, o...) 
} // implementation of checkpoint @@ -309,7 +315,7 @@ type fakeCheckpoint struct { mu sync.Mutex } -func (fc *fakeCheckpoint) Set(streamName, shardID, sequenceNumber string) error { +func (fc *fakeCheckpoint) Set(ctx context.Context, streamName, shardID, sequenceNumber string) error { fc.mu.Lock() defer fc.mu.Unlock() @@ -318,7 +324,7 @@ func (fc *fakeCheckpoint) Set(streamName, shardID, sequenceNumber string) error return nil } -func (fc *fakeCheckpoint) Get(streamName, shardID string) (string, error) { +func (fc *fakeCheckpoint) Get(ctx context.Context, streamName, shardID string) (string, error) { fc.mu.Lock() defer fc.mu.Unlock() diff --git a/examples/consumer/cp-dynamo/main.go b/examples/consumer/cp-dynamo/main.go index 23a6959..4edd7ab 100644 --- a/examples/consumer/cp-dynamo/main.go +++ b/examples/consumer/cp-dynamo/main.go @@ -16,14 +16,19 @@ import ( "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/dynamodb" "github.com/aws/aws-sdk-go/service/kinesis" + "github.com/opentracing/opentracing-go" alog "github.com/apex/log" "github.com/apex/log/handlers/text" consumer "github.com/harlow/kinesis-consumer" checkpoint "github.com/harlow/kinesis-consumer/checkpoint/ddb" + + "github.com/harlow/kinesis-consumer/examples/distributed-tracing/utility" ) +const serviceName = "checkpoint.dynamodb" + // kick off a server for exposing scan metrics func init() { sock, err := net.Listen("tcp", "localhost:8080") @@ -62,6 +67,12 @@ func main() { ) flag.Parse() + tracer, closer := utility.NewTracer(serviceName) + defer closer.Close() + opentracing.InitGlobalTracer(tracer) + span := tracer.StartSpan("consumer.main") + ctx := opentracing.ContextWithSpan(context.Background(), span) + // The following will overwrite the default dynamodb client // Older versions of the aws sdk do not pick up the aws config properly. // You probably need to update the aws sdk version. Tested the following with 1.13.59 @@ -72,7 +83,7 @@ ) // ddb checkpoint - ck, err := checkpoint.New(*app, *table, checkpoint.WithDynamoClient(myDynamoDbClient), checkpoint.WithRetryer(&MyRetryer{})) + ck, err := checkpoint.New(ctx, *app, *table, checkpoint.WithDynamoClient(myDynamoDbClient), checkpoint.WithRetryer(&MyRetryer{})) if err != nil { log.Log("checkpoint error: %v", err) } @@ -121,7 +132,7 @@ func main() { log.Log("scan error: %v", err) } - if err := ck.Shutdown(); err != nil { + if err := ck.Shutdown(ctx); err != nil { log.Log("checkpoint shutdown error: %v", err) } } diff --git a/examples/distributed-tracing/consumer/consumer.go b/examples/distributed-tracing/consumer/consumer.go index 03c7cbe..691329e 100644 --- a/examples/distributed-tracing/consumer/consumer.go +++ b/examples/distributed-tracing/consumer/consumer.go @@ -47,17 +47,17 @@ func main() { span := tracer.StartSpan("consumer.main") defer span.Finish() - var ( - app = flag.String("app", "", "App name") - stream = flag.String("stream", "", "Stream name") - table = flag.String("table", "", "Checkpoint table name") - ) + app := flag.String("app", "", "App name") + stream := flag.String("stream", "", "Stream name") + table := flag.String("table", "", "Checkpoint table name") flag.Parse() span.SetTag("app.name", app) span.SetTag("stream.name", stream) span.SetTag("table.name", table) + fmt.Println("set tag....") + // The following will overwrite the default dynamodb client // Older versions of the aws sdk do not pick up the aws config properly. // You probably need to update the aws sdk version. 
Tested the following with 1.13.59 @@ -67,8 +67,9 @@ func main() { myDynamoDbClient := dynamodb.New(sess) // ddb checkpoint + ctx := opentracing.ContextWithSpan(context.Background(), span) retryer := utility.NewRetryer() - ck, err := checkpoint.New(*app, *table, checkpoint.WithDynamoClient(myDynamoDbClient), checkpoint.WithRetryer(retryer)) + ck, err := checkpoint.New(ctx, *app, *table, checkpoint.WithDynamoClient(myDynamoDbClient), checkpoint.WithRetryer(retryer)) if err != nil { span.LogKV("checkpoint error", err.Error()) span.SetTag("consumer.retry.count", retryer.Count()) @@ -97,7 +98,8 @@ func main() { } // use cancel func to signal shutdown - ctx, cancel := context.WithCancel(context.Background()) + ctx = opentracing.ContextWithSpan(ctx, span) + ctx, cancel := context.WithCancel(ctx) // trap SIGINT, wait to trigger shutdown signals := make(chan os.Signal, 1) @@ -105,23 +107,25 @@ func main() { go func() { <-signals + span.Finish() + closer.Close() cancel() }() // scan stream err = c.Scan(ctx, func(r *consumer.Record) consumer.ScanStatus { fmt.Println(string(r.Data)) - // continue scanning return consumer.ScanStatus{} }) if err != nil { - span.LogKV("consumer scan error", err.Error()) - ext.Error.Set(span, true) + + //span.LogKV("consumer scan error", err.Error()) + //ext.Error.Set(span, true) log.Log("consumer scan error", "error", err.Error()) } - if err := ck.Shutdown(); err != nil { + if err := ck.Shutdown(ctx); err != nil { span.LogKV("consumer shutdown error", err.Error()) ext.Error.Set(span, true) log.Log("checkpoint shutdown error", "error", err.Error()) diff --git a/examples/distributed-tracing/producer/producer.go b/examples/distributed-tracing/producer/producer.go index d8f1d99..5bf7ff8 100644 --- a/examples/distributed-tracing/producer/producer.go +++ b/examples/distributed-tracing/producer/producer.go @@ -56,7 +56,7 @@ func main() { // Need to end span here, since Fatalf calls os.Exit span.Finish() closer.Close() - log.Fatal(fmt.Sprintf("Cannot open %s file"), dataFile) + log.Fatal(fmt.Sprintf("Cannot open %s file", dataFile)) } defer f.Close() span.SetTag("producer.file.name", f.Name()) @@ -88,11 +88,11 @@ func main() { func putRecords(ctx context.Context, streamName *string, records []*kinesis.PutRecordsRequestEntry) { // I am assuming each new AWS call is a new Span - span, _ := opentracing.StartSpanFromContext(ctx, "producer.putRecords") + span, ctx := opentracing.StartSpanFromContext(ctx, "producer.putRecords") defer span.Finish() span.SetTag("producer.records.count", len(records)) ctx = opentracing.ContextWithSpan(ctx, span) - _, err := svc.PutRecordsWithContext(&kinesis.PutRecordsInput{ + _, err := svc.PutRecordsWithContext(ctx, &kinesis.PutRecordsInput{ StreamName: streamName, Records: records, })
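The examples above call `utility.NewTracer(serviceName)`, but the utility package itself is not shown in this diff. Assuming it wraps jaeger-client-go (which this change pins at v2.15.0 in Gopkg.lock, including its `config` and `log` packages), a minimal helper might look like the sketch below; the sampler and reporter settings are illustrative.

```go
package utility

import (
	"io"
	"log"

	"github.com/opentracing/opentracing-go"
	jaeger "github.com/uber/jaeger-client-go"
	jaegercfg "github.com/uber/jaeger-client-go/config"
	jaegerlog "github.com/uber/jaeger-client-go/log"
)

// NewTracer returns a tracer reporting to a local Jaeger agent and a closer
// that must be closed on shutdown to flush buffered spans.
func NewTracer(serviceName string) (opentracing.Tracer, io.Closer) {
	cfg := jaegercfg.Configuration{
		Sampler: &jaegercfg.SamplerConfig{
			Type:  jaeger.SamplerTypeConst, // sample everything
			Param: 1,
		},
		Reporter: &jaegercfg.ReporterConfig{LogSpans: true},
	}

	tracer, closer, err := cfg.New(serviceName, jaegercfg.Logger(jaegerlog.StdLogger))
	if err != nil {
		log.Fatalf("cannot initialize Jaeger tracer: %v", err)
	}
	return tracer, closer
}
```

The const sampler keeps the examples deterministic by tracing every span; a production deployment would more likely use a probabilistic or remotely controlled sampler.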