diff --git a/README.md b/README.md index 9a4f680..de3c6a7 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,8 @@ # Golang Kinesis Consumer +Note: This repo has been upgraded to use AWS SDK v2. If you are still using +AWS SDK V1: https://github.com/harlow/kinesis-consumer/releases/tag/v0.3.5 + ![technology Go](https://img.shields.io/badge/technology-go-blue.svg) [![Build Status](https://travis-ci.com/harlow/kinesis-consumer.svg?branch=master)](https://travis-ci.com/harlow/kinesis-consumer) [![GoDoc](https://godoc.org/github.com/harlow/kinesis-consumer?status.svg)](https://godoc.org/github.com/harlow/kinesis-consumer) [![GoReportCard](https://goreportcard.com/badge/github.com/harlow/kinesis-consumer)](https://goreportcard.com/report/harlow/kinesis-consumer) Kinesis consumer applications written in Go. This library is intended to be a lightweight wrapper around the Kinesis API to read records, save checkpoints (with swappable backends), and gracefully recover from service timeouts/errors. diff --git a/allgroup.go b/allgroup.go index 9b803db..749a380 100644 --- a/allgroup.go +++ b/allgroup.go @@ -5,16 +5,15 @@ import ( "sync" "time" - "github.com/aws/aws-sdk-go/service/kinesis" - "github.com/aws/aws-sdk-go/service/kinesis/kinesisiface" + "github.com/aws/aws-sdk-go-v2/service/kinesis/types" ) // NewAllGroup returns an intitialized AllGroup for consuming // all shards on a stream -func NewAllGroup(ksis kinesisiface.KinesisAPI, store Store, streamName string, logger Logger) *AllGroup { +func NewAllGroup(ksis kinesisClient, store Store, streamName string, logger Logger) *AllGroup { return &AllGroup{ ksis: ksis, - shards: make(map[string]*kinesis.Shard), + shards: make(map[string]types.Shard), streamName: streamName, logger: logger, Store: store, @@ -25,37 +24,37 @@ func NewAllGroup(ksis kinesisiface.KinesisAPI, store Store, streamName string, l // caches a local list of the shards we are already processing // and routinely polls the stream looking for new shards to process. type AllGroup struct { - ksis kinesisiface.KinesisAPI + ksis kinesisClient streamName string logger Logger Store shardMu sync.Mutex - shards map[string]*kinesis.Shard + shards map[string]types.Shard } // Start is a blocking operation which will loop and attempt to find new // shards on a regular cadence. -func (g *AllGroup) Start(ctx context.Context, shardc chan *kinesis.Shard) { - var ticker = time.NewTicker(30 * time.Second) - g.findNewShards(shardc) - +func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) { // Note: while ticker is a rather naive approach to this problem, - // it actually simplies a few things. i.e. If we miss a new shard while - // AWS is resharding we'll pick it up max 30 seconds later. + // it actually simplifies a few things. i.e. If we miss a new shard + // while AWS is resharding we'll pick it up max 30 seconds later. // It might be worth refactoring this flow to allow the consumer to // to notify the broker when a shard is closed. However, shards don't // necessarily close at the same time, so we could potentially get a // thundering heard of notifications from the consumer. + var ticker = time.NewTicker(30 * time.Second) + for { + g.findNewShards(ctx, shardc) + select { case <-ctx.Done(): ticker.Stop() return case <-ticker.C: - g.findNewShards(shardc) } } } @@ -63,13 +62,13 @@ func (g *AllGroup) Start(ctx context.Context, shardc chan *kinesis.Shard) { // findNewShards pulls the list of shards from the Kinesis API // and uses a local cache to determine if we are already processing // a particular shard. -func (g *AllGroup) findNewShards(shardc chan *kinesis.Shard) { +func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) { g.shardMu.Lock() defer g.shardMu.Unlock() g.logger.Log("[GROUP]", "fetching shards") - shards, err := listShards(g.ksis, g.streamName) + shards, err := listShards(ctx, g.ksis, g.streamName) if err != nil { g.logger.Log("[GROUP] error:", err) return diff --git a/client.go b/client.go new file mode 100644 index 0000000..af604ba --- /dev/null +++ b/client.go @@ -0,0 +1,14 @@ +package consumer + +import ( + "context" + + "github.com/aws/aws-sdk-go-v2/service/kinesis" +) + +// kinesisClient defines the interface of functions needed for the consumer +type kinesisClient interface { + GetRecords(ctx context.Context, params *kinesis.GetRecordsInput, optFns ...func(*kinesis.Options)) (*kinesis.GetRecordsOutput, error) + ListShards(ctx context.Context, params *kinesis.ListShardsInput, optFns ...func(*kinesis.Options)) (*kinesis.ListShardsOutput, error) + GetShardIterator(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) +} diff --git a/cmd/consumer-dynamo/main.go b/cmd/consumer-dynamo/main.go index 4449689..6eacf77 100644 --- a/cmd/consumer-dynamo/main.go +++ b/cmd/consumer-dynamo/main.go @@ -13,11 +13,12 @@ import ( alog "github.com/apex/log" "github.com/apex/log/handlers/text" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/dynamodb" - "github.com/aws/aws-sdk-go/service/kinesis" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/kinesis" + "github.com/aws/aws-sdk-go-v2/service/kinesis/types" consumer "github.com/harlow/kinesis-consumer" storage "github.com/harlow/kinesis-consumer/store/ddb" ) @@ -46,7 +47,7 @@ func (l *myLogger) Log(args ...interface{}) { func main() { // Wrap myLogger around apex logger - log := &myLogger{ + mylog := &myLogger{ logger: alog.Logger{ Handler: text.New(os.Stdout), Level: alog.DebugLevel, @@ -62,24 +63,34 @@ func main() { ) flag.Parse() - // New Kinesis and DynamoDB clients (if you need custom config) - sess, err := session.NewSession(aws.NewConfig()) - if err != nil { - log.Log("new session error: %v", err) - } - myDdbClient := dynamodb.New(sess) + resolver := aws.EndpointResolverFunc(func(service, region string) (aws.Endpoint, error) { + return aws.Endpoint{ + PartitionID: "aws", + URL: *kinesisEndpoint, + SigningRegion: *awsRegion, + }, nil + }) - var myKsis = kinesis.New(session.Must(session.NewSession( - aws.NewConfig(). - WithEndpoint(*kinesisEndpoint). - WithRegion(*awsRegion). - WithLogLevel(3), - ))) + // client + cfg, err := config.LoadDefaultConfig( + context.TODO(), + config.WithRegion(*awsRegion), + config.WithEndpointResolver(resolver), + config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider("user", "pass", "token")), + ) + if err != nil { + log.Fatalf("unable to load SDK config, %v", err) + } + + var ( + myDdbClient = dynamodb.NewFromConfig(cfg) + myKsis = kinesis.NewFromConfig(cfg) + ) // ddb persitance ddb, err := storage.New(*app, *table, storage.WithDynamoClient(myDdbClient), storage.WithRetryer(&MyRetryer{})) if err != nil { - log.Log("checkpoint error: %v", err) + log.Fatalf("checkpoint error: %v", err) } // expvar counter @@ -89,12 +100,12 @@ func main() { c, err := consumer.New( *stream, consumer.WithStore(ddb), - consumer.WithLogger(log), + consumer.WithLogger(mylog), consumer.WithCounter(counter), consumer.WithClient(myKsis), ) if err != nil { - log.Log("consumer error: %v", err) + log.Fatalf("consumer error: %v", err) } // use cancel func to signal shutdown @@ -115,11 +126,11 @@ func main() { return nil // continue scanning }) if err != nil { - log.Log("scan error: %v", err) + log.Fatalf("scan error: %v", err) } if err := ddb.Shutdown(); err != nil { - log.Log("storage shutdown error: %v", err) + log.Fatalf("storage shutdown error: %v", err) } } @@ -130,13 +141,9 @@ type MyRetryer struct { // ShouldRetry implements custom logic for when errors should retry func (r *MyRetryer) ShouldRetry(err error) bool { - if awsErr, ok := err.(awserr.Error); ok { - switch awsErr.Code() { - case dynamodb.ErrCodeProvisionedThroughputExceededException, dynamodb.ErrCodeLimitExceededException: - return true - default: - return false - } + switch err.(type) { + case *types.ProvisionedThroughputExceededException, *types.LimitExceededException: + return true } return false } diff --git a/cmd/consumer-mysql/main.go b/cmd/consumer-mysql/main.go index 2b67e04..d698989 100644 --- a/cmd/consumer-mysql/main.go +++ b/cmd/consumer-mysql/main.go @@ -9,9 +9,10 @@ import ( "os" "os/signal" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/kinesis" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/kinesis" consumer "github.com/harlow/kinesis-consumer" store "github.com/harlow/kinesis-consumer/store/mysql" ) @@ -35,13 +36,25 @@ func main() { var counter = expvar.NewMap("counters") - // client - cfg := aws.NewConfig(). - WithEndpoint(*kinesisEndpoint). - WithRegion(*awsRegion). - WithLogLevel(3) + resolver := aws.EndpointResolverFunc(func(service, region string) (aws.Endpoint, error) { + return aws.Endpoint{ + PartitionID: "aws", + URL: *kinesisEndpoint, + SigningRegion: *awsRegion, + }, nil + }) - var client = kinesis.New(session.Must(session.NewSession(cfg))) + // client + cfg, err := config.LoadDefaultConfig( + context.TODO(), + config.WithRegion(*awsRegion), + config.WithEndpointResolver(resolver), + config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider("user", "pass", "token")), + ) + if err != nil { + log.Fatalf("unable to load SDK config, %v", err) + } + var client = kinesis.NewFromConfig(cfg) // consumer c, err := consumer.New( diff --git a/cmd/consumer-postgres/main.go b/cmd/consumer-postgres/main.go index 525cc42..b51a8d5 100644 --- a/cmd/consumer-postgres/main.go +++ b/cmd/consumer-postgres/main.go @@ -9,9 +9,10 @@ import ( "os" "os/signal" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/kinesis" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/kinesis" consumer "github.com/harlow/kinesis-consumer" store "github.com/harlow/kinesis-consumer/store/postgres" ) @@ -35,13 +36,25 @@ func main() { var counter = expvar.NewMap("counters") - // client - cfg := aws.NewConfig(). - WithEndpoint(*kinesisEndpoint). - WithRegion(*awsRegion). - WithLogLevel(3) + resolver := aws.EndpointResolverFunc(func(service, region string) (aws.Endpoint, error) { + return aws.Endpoint{ + PartitionID: "aws", + URL: *kinesisEndpoint, + SigningRegion: *awsRegion, + }, nil + }) - var client = kinesis.New(session.Must(session.NewSession(cfg))) + // client + cfg, err := config.LoadDefaultConfig( + context.TODO(), + config.WithRegion(*awsRegion), + config.WithEndpointResolver(resolver), + config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider("user", "pass", "token")), + ) + if err != nil { + log.Fatalf("unable to load SDK config, %v", err) + } + var client = kinesis.NewFromConfig(cfg) // consumer c, err := consumer.New( diff --git a/cmd/consumer-redis/main.go b/cmd/consumer-redis/main.go index 5e6fbb8..e997247 100644 --- a/cmd/consumer-redis/main.go +++ b/cmd/consumer-redis/main.go @@ -8,9 +8,10 @@ import ( "os" "os/signal" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/kinesis" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/kinesis" consumer "github.com/harlow/kinesis-consumer" store "github.com/harlow/kinesis-consumer/store/redis" ) @@ -45,13 +46,25 @@ func main() { logger: log.New(os.Stdout, "consumer-example: ", log.LstdFlags), } - // client - cfg := aws.NewConfig(). - WithEndpoint(*kinesisEndpoint). - WithRegion(*awsRegion). - WithLogLevel(3) + resolver := aws.EndpointResolverFunc(func(service, region string) (aws.Endpoint, error) { + return aws.Endpoint{ + PartitionID: "aws", + URL: *kinesisEndpoint, + SigningRegion: *awsRegion, + }, nil + }) - var client = kinesis.New(session.Must(session.NewSession(cfg))) + // client + cfg, err := config.LoadDefaultConfig( + context.TODO(), + config.WithRegion(*awsRegion), + config.WithEndpointResolver(resolver), + config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider("user", "pass", "token")), + ) + if err != nil { + log.Fatalf("unable to load SDK config, %v", err) + } + var client = kinesis.NewFromConfig(cfg) // consumer c, err := consumer.New( diff --git a/cmd/consumer/main.go b/cmd/consumer/main.go index ba5d197..d447bd1 100644 --- a/cmd/consumer/main.go +++ b/cmd/consumer/main.go @@ -9,9 +9,10 @@ import ( "os/signal" "syscall" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/kinesis" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/kinesis" consumer "github.com/harlow/kinesis-consumer" ) @@ -33,12 +34,25 @@ func main() { ) flag.Parse() + resolver := aws.EndpointResolverFunc(func(service, region string) (aws.Endpoint, error) { + return aws.Endpoint{ + PartitionID: "aws", + URL: *kinesisEndpoint, + SigningRegion: *awsRegion, + }, nil + }) + // client - var client = kinesis.New(session.Must(session.NewSession( - aws.NewConfig(). - WithEndpoint(*kinesisEndpoint). - WithRegion(*awsRegion), - ))) + cfg, err := config.LoadDefaultConfig( + context.TODO(), + config.WithRegion(*awsRegion), + config.WithEndpointResolver(resolver), + config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider("user", "pass", "token")), + ) + if err != nil { + log.Fatalf("unable to load SDK config, %v", err) + } + var client = kinesis.NewFromConfig(cfg) // consumer c, err := consumer.New( diff --git a/cmd/producer/main.go b/cmd/producer/main.go index 798acfd..e714973 100644 --- a/cmd/producer/main.go +++ b/cmd/producer/main.go @@ -2,15 +2,18 @@ package main import ( "bufio" + "context" "flag" "fmt" "log" "os" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/kinesis" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/kinesis" + "github.com/aws/aws-sdk-go-v2/service/kinesis/types" ) func main() { @@ -21,17 +24,29 @@ func main() { ) flag.Parse() - var records []*kinesis.PutRecordsRequestEntry + var records []types.PutRecordsRequestEntry - var client = kinesis.New(session.Must(session.NewSession( - aws.NewConfig(). - WithEndpoint(*kinesisEndpoint). - WithRegion(*awsRegion). - WithLogLevel(3), - ))) + resolver := aws.EndpointResolverFunc(func(service, region string) (aws.Endpoint, error) { + return aws.Endpoint{ + PartitionID: "aws", + URL: *kinesisEndpoint, + SigningRegion: *awsRegion, + }, nil + }) + + cfg, err := config.LoadDefaultConfig( + context.TODO(), + config.WithRegion(*awsRegion), + config.WithEndpointResolver(resolver), + config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider("user", "pass", "token")), + ) + if err != nil { + log.Fatalf("unable to load SDK config, %v", err) + } + var client = kinesis.NewFromConfig(cfg) // create stream if doesn't exist - if err := createStream(client, streamName); err != nil { + if err := createStream(client, *streamName); err != nil { log.Fatalf("create stream error: %v", err) } @@ -39,7 +54,7 @@ func main() { b := bufio.NewScanner(os.Stdin) for b.Scan() { - records = append(records, &kinesis.PutRecordsRequestEntry{ + records = append(records, types.PutRecordsRequestEntry{ Data: b.Bytes(), PartitionKey: aws.String(time.Now().Format(time.RFC3339Nano)), }) @@ -55,37 +70,41 @@ func main() { } } -func createStream(client *kinesis.Kinesis, streamName *string) error { - resp, err := client.ListStreams(&kinesis.ListStreamsInput{}) +func createStream(client *kinesis.Client, streamName string) error { + resp, err := client.ListStreams(context.Background(), &kinesis.ListStreamsInput{}) if err != nil { return fmt.Errorf("list streams error: %v", err) } for _, val := range resp.StreamNames { - if *streamName == *val { + if streamName == val { return nil } } _, err = client.CreateStream( + context.Background(), &kinesis.CreateStreamInput{ - StreamName: streamName, - ShardCount: aws.Int64(2), + StreamName: aws.String(streamName), + ShardCount: aws.Int32(2), }, ) if err != nil { return err } - return client.WaitUntilStreamExists( + waiter := kinesis.NewStreamExistsWaiter(client) + return waiter.Wait( + context.Background(), &kinesis.DescribeStreamInput{ - StreamName: streamName, + StreamName: aws.String(streamName), }, + 30*time.Second, ) } -func putRecords(client *kinesis.Kinesis, streamName *string, records []*kinesis.PutRecordsRequestEntry) { - _, err := client.PutRecords(&kinesis.PutRecordsInput{ +func putRecords(client *kinesis.Client, streamName *string, records []types.PutRecordsRequestEntry) { + _, err := client.PutRecords(context.Background(), &kinesis.PutRecordsInput{ StreamName: streamName, Records: records, }) diff --git a/consumer.go b/consumer.go index 7d155b6..bff0b21 100644 --- a/consumer.go +++ b/consumer.go @@ -9,18 +9,17 @@ import ( "sync" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/kinesis" - "github.com/aws/aws-sdk-go/service/kinesis/kinesisiface" - "github.com/awslabs/kinesis-aggregation/go/deaggregator" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/kinesis" + "github.com/aws/aws-sdk-go-v2/service/kinesis/types" + "github.com/harlow/kinesis-consumer/internal/deaggregator" ) // Record wraps the record returned from the Kinesis library and // extends to include the shard id. type Record struct { - *kinesis.Record + types.Record ShardID string MillisBehindLatest *int64 } @@ -35,7 +34,7 @@ func New(streamName string, opts ...Option) (*Consumer, error) { // new consumer with noop storage, counter, and logger c := &Consumer{ streamName: streamName, - initialShardIteratorType: kinesis.ShardIteratorTypeLatest, + initialShardIteratorType: types.ShardIteratorTypeLatest, store: &noopStore{}, counter: &noopCounter{}, logger: &noopLogger{ @@ -52,11 +51,11 @@ func New(streamName string, opts ...Option) (*Consumer, error) { // default client if c.client == nil { - newSession, err := session.NewSession(aws.NewConfig()) + cfg, err := config.LoadDefaultConfig(context.TODO()) if err != nil { - return nil, err + log.Fatalf("unable to load SDK config, %v", err) } - c.client = kinesis.New(newSession) + c.client = kinesis.NewFromConfig(cfg) } // default group consumes all shards @@ -70,9 +69,9 @@ func New(streamName string, opts ...Option) (*Consumer, error) { // Consumer wraps the interaction with the Kinesis stream type Consumer struct { streamName string - initialShardIteratorType string + initialShardIteratorType types.ShardIteratorType initialTimestamp *time.Time - client kinesisiface.KinesisAPI + client kinesisClient counter Counter group Group logger Logger @@ -104,7 +103,7 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error { var ( errc = make(chan error, 1) - shardc = make(chan *kinesis.Shard, 1) + shardc = make(chan types.Shard, 1) ) go func() { @@ -128,7 +127,7 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error { // error has already occurred } } - }(aws.StringValue(shard.ShardId)) + }(aws.ToString(shard.ShardId)) } go func() { @@ -158,23 +157,22 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e defer func() { c.logger.Log("[CONSUMER] stop scan:", shardID) }() + scanTicker := time.NewTicker(c.scanInterval) defer scanTicker.Stop() for { - resp, err := c.client.GetRecords(&kinesis.GetRecordsInput{ - Limit: aws.Int64(c.maxRecords), + resp, err := c.client.GetRecords(ctx, &kinesis.GetRecordsInput{ + Limit: aws.Int32(int32(c.maxRecords)), ShardIterator: shardIterator, }) - // attempt to recover from GetRecords error when expired iterator + // attempt to recover from GetRecords error if err != nil { c.logger.Log("[CONSUMER] get records error:", err.Error()) - if awserr, ok := err.(awserr.Error); ok { - if _, ok := retriableErrors[awserr.Code()]; !ok { - return fmt.Errorf("get records error: %v", awserr.Message()) - } + if !isRetriableError(err) { + return fmt.Errorf("get records error: %v", err.Error()) } shardIterator, err = c.getShardIterator(ctx, c.streamName, shardID, lastSeqNum) @@ -183,16 +181,18 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e } } else { // loop over records, call callback func - var records []*kinesis.Record - var err error + var records []types.Record + + // deaggregate records if c.isAggregated { - records, err = deaggregator.DeaggregateRecords(resp.Records) + records, err = deaggregateRecords(resp.Records) if err != nil { return err } } else { records = resp.Records } + for _, r := range records { select { case <-ctx.Done(): @@ -216,8 +216,10 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e if isShardClosed(resp.NextShardIterator, shardIterator) { c.logger.Log("[CONSUMER] shard closed:", shardID) + if c.shardClosedHandler != nil { err := c.shardClosedHandler(c.streamName, shardID) + if err != nil { return fmt.Errorf("shard closed handler error: %w", err) } @@ -238,14 +240,23 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e } } -var retriableErrors = map[string]struct{}{ - kinesis.ErrCodeExpiredIteratorException: struct{}{}, - kinesis.ErrCodeProvisionedThroughputExceededException: struct{}{}, - kinesis.ErrCodeInternalFailureException: struct{}{}, -} +// temporary conversion func of []types.Record -> DeaggregateRecords([]*types.Record) -> []types.Record +func deaggregateRecords(in []types.Record) ([]types.Record, error) { + var recs []*types.Record + for _, rec := range in { + recs = append(recs, &rec) + } -func isShardClosed(nextShardIterator, currentShardIterator *string) bool { - return nextShardIterator == nil || currentShardIterator == nextShardIterator + deagg, err := deaggregator.DeaggregateRecords(recs) + if err != nil { + return nil, err + } + + var out []types.Record + for _, rec := range deagg { + out = append(out, *rec) + } + return out, nil } func (c *Consumer) getShardIterator(ctx context.Context, streamName, shardID, seqNum string) (*string, error) { @@ -255,15 +266,29 @@ func (c *Consumer) getShardIterator(ctx context.Context, streamName, shardID, se } if seqNum != "" { - params.ShardIteratorType = aws.String(kinesis.ShardIteratorTypeAfterSequenceNumber) + params.ShardIteratorType = types.ShardIteratorTypeAfterSequenceNumber params.StartingSequenceNumber = aws.String(seqNum) } else if c.initialTimestamp != nil { - params.ShardIteratorType = aws.String(kinesis.ShardIteratorTypeAtTimestamp) + params.ShardIteratorType = types.ShardIteratorTypeAtTimestamp params.Timestamp = c.initialTimestamp } else { - params.ShardIteratorType = aws.String(c.initialShardIteratorType) + params.ShardIteratorType = types.ShardIteratorType(c.initialShardIteratorType) } - res, err := c.client.GetShardIteratorWithContext(aws.Context(ctx), params) + res, err := c.client.GetShardIterator(ctx, params) return res.ShardIterator, err } + +func isRetriableError(err error) bool { + switch err.(type) { + case *types.ExpiredIteratorException: + return true + case *types.ProvisionedThroughputExceededException: + return true + } + return false +} + +func isShardClosed(nextShardIterator, currentShardIterator *string) bool { + return nextShardIterator == nil || currentShardIterator == nextShardIterator +} diff --git a/consumer_test.go b/consumer_test.go index b7e5c1d..3330d32 100644 --- a/consumer_test.go +++ b/consumer_test.go @@ -6,16 +6,14 @@ import ( "sync" "testing" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/aws/request" - "github.com/aws/aws-sdk-go/service/kinesis" - "github.com/aws/aws-sdk-go/service/kinesis/kinesisiface" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/kinesis" + "github.com/aws/aws-sdk-go-v2/service/kinesis/types" - "github.com/harlow/kinesis-consumer/store/memory" + store "github.com/harlow/kinesis-consumer/store/memory" ) -var records = []*kinesis.Record{ +var records = []types.Record{ { Data: []byte("firstData"), SequenceNumber: aws.String("firstSeqNum"), @@ -34,20 +32,20 @@ func TestNew(t *testing.T) { func TestScan(t *testing.T) { client := &kinesisClientMock{ - getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) { + getShardIteratorMock: func(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) { return &kinesis.GetShardIteratorOutput{ ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"), }, nil }, - getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) { + getRecordsMock: func(ctx context.Context, params *kinesis.GetRecordsInput, optFns ...func(*kinesis.Options)) (*kinesis.GetRecordsOutput, error) { return &kinesis.GetRecordsOutput{ NextShardIterator: nil, Records: records, }, nil }, - listShardsMock: func(input *kinesis.ListShardsInput) (*kinesis.ListShardsOutput, error) { + listShardsMock: func(ctx context.Context, params *kinesis.ListShardsInput, optFns ...func(*kinesis.Options)) (*kinesis.ListShardsOutput, error) { return &kinesis.ListShardsOutput{ - Shards: []*kinesis.Shard{ + Shards: []types.Shard{ {ShardId: aws.String("myShard")}, }, }, nil @@ -102,12 +100,12 @@ func TestScan(t *testing.T) { func TestScanShard(t *testing.T) { var client = &kinesisClientMock{ - getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) { + getShardIteratorMock: func(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) { return &kinesis.GetShardIteratorOutput{ ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"), }, nil }, - getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) { + getRecordsMock: func(ctx context.Context, params *kinesis.GetRecordsInput, optFns ...func(*kinesis.Options)) (*kinesis.GetRecordsOutput, error) { return &kinesis.GetRecordsOutput{ NextShardIterator: nil, Records: records, @@ -168,12 +166,12 @@ func TestScanShard(t *testing.T) { func TestScanShard_Cancellation(t *testing.T) { var client = &kinesisClientMock{ - getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) { + getShardIteratorMock: func(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) { return &kinesis.GetShardIteratorOutput{ ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"), }, nil }, - getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) { + getRecordsMock: func(ctx context.Context, params *kinesis.GetRecordsInput, optFns ...func(*kinesis.Options)) (*kinesis.GetRecordsOutput, error) { return &kinesis.GetRecordsOutput{ NextShardIterator: nil, Records: records, @@ -208,12 +206,12 @@ func TestScanShard_Cancellation(t *testing.T) { func TestScanShard_SkipCheckpoint(t *testing.T) { var client = &kinesisClientMock{ - getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) { + getShardIteratorMock: func(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) { return &kinesis.GetShardIteratorOutput{ ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"), }, nil }, - getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) { + getRecordsMock: func(ctx context.Context, params *kinesis.GetRecordsInput, optFns ...func(*kinesis.Options)) (*kinesis.GetRecordsOutput, error) { return &kinesis.GetRecordsOutput{ NextShardIterator: nil, Records: records, @@ -231,7 +229,7 @@ func TestScanShard_SkipCheckpoint(t *testing.T) { var ctx, cancel = context.WithCancel(context.Background()) var fn = func(r *Record) error { - if aws.StringValue(r.SequenceNumber) == "lastSeqNum" { + if aws.ToString(r.SequenceNumber) == "lastSeqNum" { cancel() return ErrSkipCheckpoint } @@ -252,15 +250,15 @@ func TestScanShard_SkipCheckpoint(t *testing.T) { func TestScanShard_ShardIsClosed(t *testing.T) { var client = &kinesisClientMock{ - getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) { + getShardIteratorMock: func(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) { return &kinesis.GetShardIteratorOutput{ ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"), }, nil }, - getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) { + getRecordsMock: func(ctx context.Context, params *kinesis.GetRecordsInput, optFns ...func(*kinesis.Options)) (*kinesis.GetRecordsOutput, error) { return &kinesis.GetRecordsOutput{ NextShardIterator: nil, - Records: make([]*kinesis.Record, 0), + Records: make([]types.Record, 0), }, nil }, } @@ -282,15 +280,15 @@ func TestScanShard_ShardIsClosed(t *testing.T) { func TestScanShard_ShardIsClosed_WithShardClosedHandler(t *testing.T) { var client = &kinesisClientMock{ - getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) { + getShardIteratorMock: func(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) { return &kinesis.GetShardIteratorOutput{ ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"), }, nil }, - getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) { + getRecordsMock: func(ctx context.Context, params *kinesis.GetRecordsInput, optFns ...func(*kinesis.Options)) (*kinesis.GetRecordsOutput, error) { return &kinesis.GetRecordsOutput{ NextShardIterator: nil, - Records: make([]*kinesis.Record, 0), + Records: make([]types.Record, 0), }, nil }, } @@ -319,20 +317,17 @@ func TestScanShard_ShardIsClosed_WithShardClosedHandler(t *testing.T) { func TestScanShard_GetRecordsError(t *testing.T) { var client = &kinesisClientMock{ - getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) { + getShardIteratorMock: func(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) { return &kinesis.GetShardIteratorOutput{ ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"), }, nil }, - getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) { + getRecordsMock: func(ctx context.Context, params *kinesis.GetRecordsInput, optFns ...func(*kinesis.Options)) (*kinesis.GetRecordsOutput, error) { return &kinesis.GetRecordsOutput{ NextShardIterator: nil, Records: nil, - }, awserr.New( - kinesis.ErrCodeInvalidArgumentException, - "aws error message", - fmt.Errorf("error message"), - ) + }, + &types.InvalidArgumentException{Message: aws.String("aws error message")} }, } @@ -346,32 +341,28 @@ func TestScanShard_GetRecordsError(t *testing.T) { } err = c.ScanShard(context.Background(), "myShard", fn) - if err.Error() != "get records error: aws error message" { + if err.Error() != "get records error: InvalidArgumentException: aws error message" { t.Fatalf("unexpected error: %v", err) } } type kinesisClientMock struct { - kinesisiface.KinesisAPI - getShardIteratorMock func(*kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) - getRecordsMock func(*kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) - listShardsMock func(*kinesis.ListShardsInput) (*kinesis.ListShardsOutput, error) + kinesis.Client + getShardIteratorMock func(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) + getRecordsMock func(ctx context.Context, params *kinesis.GetRecordsInput, optFns ...func(*kinesis.Options)) (*kinesis.GetRecordsOutput, error) + listShardsMock func(ctx context.Context, params *kinesis.ListShardsInput, optFns ...func(*kinesis.Options)) (*kinesis.ListShardsOutput, error) } -func (c *kinesisClientMock) ListShards(in *kinesis.ListShardsInput) (*kinesis.ListShardsOutput, error) { - return c.listShardsMock(in) +func (c *kinesisClientMock) ListShards(ctx context.Context, params *kinesis.ListShardsInput, optFns ...func(*kinesis.Options)) (*kinesis.ListShardsOutput, error) { + return c.listShardsMock(ctx, params) } -func (c *kinesisClientMock) GetRecords(in *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) { - return c.getRecordsMock(in) +func (c *kinesisClientMock) GetRecords(ctx context.Context, params *kinesis.GetRecordsInput, optFns ...func(*kinesis.Options)) (*kinesis.GetRecordsOutput, error) { + return c.getRecordsMock(ctx, params) } -func (c *kinesisClientMock) GetShardIterator(in *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) { - return c.getShardIteratorMock(in) -} - -func (c *kinesisClientMock) GetShardIteratorWithContext(ctx aws.Context, in *kinesis.GetShardIteratorInput, options ...request.Option) (*kinesis.GetShardIteratorOutput, error) { - return c.getShardIteratorMock(in) +func (c *kinesisClientMock) GetShardIterator(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) { + return c.getShardIteratorMock(ctx, params) } // implementation of counter diff --git a/go.mod b/go.mod index 69c654b..bab7ee6 100644 --- a/go.mod +++ b/go.mod @@ -5,14 +5,24 @@ require ( github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a // indirect github.com/alicebob/miniredis v2.5.0+incompatible github.com/apex/log v1.6.0 - github.com/aws/aws-sdk-go v1.33.7 - github.com/awslabs/kinesis-aggregation/go v0.0.0-20200810181507-d352038274c0 + github.com/aws/aws-sdk-go-v2 v1.9.0 + github.com/aws/aws-sdk-go-v2/config v1.6.1 + github.com/aws/aws-sdk-go-v2/credentials v1.3.3 // indirect + github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue v1.2.0 // indirect + github.com/aws/aws-sdk-go-v2/service/dynamodb v1.5.0 // indirect + github.com/aws/aws-sdk-go-v2/service/kinesis v1.6.0 + github.com/awslabs/kinesis-aggregation/go v0.0.0-20210630091500-54e17340d32f github.com/go-redis/redis/v8 v8.0.0-beta.6 github.com/go-sql-driver/mysql v1.5.0 + github.com/golang/protobuf v1.5.2 // indirect github.com/gomodule/redigo v2.0.0+incompatible // indirect + github.com/google/go-cmp v0.5.6 // indirect + github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/lib/pq v1.7.0 github.com/pkg/errors v0.9.1 + github.com/stretchr/testify v1.7.0 // indirect github.com/yuin/gopher-lua v0.0.0-20200603152657-dc2b0ca8b37e // indirect + google.golang.org/protobuf v1.27.1 // indirect ) go 1.13 diff --git a/go.sum b/go.sum index 1a74f9d..06e70fb 100644 --- a/go.sum +++ b/go.sum @@ -18,11 +18,52 @@ github.com/apex/logs v1.0.0/go.mod h1:XzxuLZ5myVHDy9SAmYpamKKRNApGj54PfYLcFrXqDw github.com/aphistic/golf v0.0.0-20180712155816-02c07f170c5a/go.mod h1:3NqKYiepwy8kCu4PNA+aP7WUV72eXWJeP9/r3/K9aLE= github.com/aphistic/sweet v0.2.0/go.mod h1:fWDlIh/isSE9n6EPsRmC0det+whmX6dJid3stzu0Xys= github.com/aws/aws-sdk-go v1.19.48/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= +github.com/aws/aws-sdk-go v1.20.6 h1:kmy4Gvdlyez1fV4kw5RYxZzWKVyuHZHgPWeU/YvRsV4= github.com/aws/aws-sdk-go v1.20.6/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= -github.com/aws/aws-sdk-go v1.33.7 h1:vOozL5hmWHHriRviVTQnUwz8l05RS0rehmEFymI+/x8= -github.com/aws/aws-sdk-go v1.33.7/go.mod h1:5zCpMtNQVjRREroY7sYe8lOMRSxkhG6MZveU8YkpAk0= +github.com/aws/aws-sdk-go v1.40.27 h1:8fWW0CpmBZ8WWduNwl4vE9t07nMYFrhAsUHjPj81qUM= +github.com/aws/aws-sdk-go v1.40.27/go.mod h1:585smgzpB/KqRA+K3y/NL/oYRqQvpNJYvLm+LY1U59Q= +github.com/aws/aws-sdk-go-v2 v1.8.1 h1:GcFgQl7MsBygmeeqXyV1ivrTEmsVz/rdFJaTcltG9ag= +github.com/aws/aws-sdk-go-v2 v1.8.1/go.mod h1:xEFuWz+3TYdlPRuo+CqATbeDWIWyaT5uAPwPaWtgse0= +github.com/aws/aws-sdk-go-v2 v1.9.0 h1:+S+dSqQCN3MSU5vJRu1HqHrq00cJn6heIMU7X9hcsoo= +github.com/aws/aws-sdk-go-v2 v1.9.0/go.mod h1:cK/D0BBs0b/oWPIcX/Z/obahJK1TT7IPVjy53i/mX/4= +github.com/aws/aws-sdk-go-v2/config v1.6.1 h1:qrZINaORyr78syO1zfD4l7r4tZjy0Z1l0sy4jiysyOM= +github.com/aws/aws-sdk-go-v2/config v1.6.1/go.mod h1:t/y3UPu0XEDy0cEw6mvygaBQaPzWiYAxfP2SzgtvclA= +github.com/aws/aws-sdk-go-v2/credentials v1.3.3 h1:A13QPatmUl41SqUfnuT3V0E3XiNGL6qNTOINbE8cZL4= +github.com/aws/aws-sdk-go-v2/credentials v1.3.3/go.mod h1:oVieKMT3m9BSfqhOfuQ+E0j/yN84ZAJ7Qv8Sfume/ak= +github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue v1.2.0 h1:8kvinmbIDObqsWegKP0JjeanYPiA4GUVpAtciNWE+jw= +github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue v1.2.0/go.mod h1:UVFtSYSWCHj2+brBLDHUdlJXmz8LxUpZhA+Ewypc+xQ= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.4.1 h1:rc+fRGvlKbeSd9IFhFS1KWBs0XjTkq0CfK5xqyLgIp0= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.4.1/go.mod h1:+GTydg3uHmVlQdkRoetz6VHKbOMEYof70m19IpMLifc= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.0.4 h1:IM9b6hlCcVFJFydPoyphs/t7YrHfqKy7T4/7AG5Eprs= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.0.4/go.mod h1:W5gGbtNXFpF9/ssYZTaItzG/B+j0bjTnwStiCP2AtWU= +github.com/aws/aws-sdk-go-v2/internal/ini v1.2.1 h1:IkqRRUZTKaS16P2vpX+FNc2jq3JWa3c478gykQp4ow4= +github.com/aws/aws-sdk-go-v2/internal/ini v1.2.1/go.mod h1:Pv3WenDjI0v2Jl7UaMFIIbPOBbhn33RmmAmGgkXDoqY= +github.com/aws/aws-sdk-go-v2/service/dynamodb v1.5.0 h1:SGwKUQaJudQQZE72dDQlL2FGuHNAEK1CyqKLTjh6mqE= +github.com/aws/aws-sdk-go-v2/service/dynamodb v1.5.0/go.mod h1:XY5YhCS9SLul3JSQ08XG/nfxXxrkh6RR21XPq/J//NY= +github.com/aws/aws-sdk-go-v2/service/dynamodbstreams v1.4.0 h1:QbFWJr2SAyVYvyoOHvJU6sCGLnqNT94ZbWElJMEI1JY= +github.com/aws/aws-sdk-go-v2/service/dynamodbstreams v1.4.0/go.mod h1:bYsEP8w5YnbYyrx/Zi5hy4hTwRRQISSJS3RWrsGRijg= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.3.0 h1:gceOysEWNNwLd6cki65IMBZ4WAM0MwgBQq2n7kejoT8= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.3.0/go.mod h1:v8ygadNyATSm6elwJ/4gzJwcFhri9RqS8skgHKiwXPU= +github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.1.0 h1:QCPbsMPMcM4iGbui5SH6O4uxvZffPoBJ4CIGX7dU0l4= +github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.1.0/go.mod h1:enkU5tq2HoXY+ZMiQprgF3Q83T3PbO77E83yXXzRZWE= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.2.3 h1:VxFCgxsqWe7OThOwJ5IpFX3xrObtuIH9Hg/NW7oot1Y= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.2.3/go.mod h1:7gcsONBmFoCcKrAqrm95trrMd2+C/ReYKP7Vfu8yHHA= +github.com/aws/aws-sdk-go-v2/service/kinesis v1.5.3 h1:ScnzrRDDZsARgwEXFBuE8cQ4rkm2yuaA72Gad4NNVK8= +github.com/aws/aws-sdk-go-v2/service/kinesis v1.5.3/go.mod h1:1C+FM8Sk3+UoI/svy0J9CS+e4PbJ/qsXxlE9k4/nQKI= +github.com/aws/aws-sdk-go-v2/service/kinesis v1.6.0 h1:hb+NupVMUzINGUCfDs2+YqMkWKu47dBIQHpulM0XWh4= +github.com/aws/aws-sdk-go-v2/service/kinesis v1.6.0/go.mod h1:9O7UG2pELnP0hq35+Gd7XDjOLBkg7tmgRQ0y14ZjoJI= +github.com/aws/aws-sdk-go-v2/service/sso v1.3.3 h1:K2gCnGvAASpz+jqP9iyr+F/KNjmTYf8aWOtTQzhmZ5w= +github.com/aws/aws-sdk-go-v2/service/sso v1.3.3/go.mod h1:Jgw5O+SK7MZ2Yi9Yvzb4PggAPYaFSliiQuWR0hNjexk= +github.com/aws/aws-sdk-go-v2/service/sts v1.6.2 h1:l504GWCoQi1Pk68vSUFGLmDIEMzRfVGNgLakDK+Uj58= +github.com/aws/aws-sdk-go-v2/service/sts v1.6.2/go.mod h1:RBhoMJB8yFToaCnbe0jNq5Dcdy0jp6LhHqg55rjClkM= +github.com/aws/smithy-go v1.7.0 h1:+cLHMRrDZvQ4wk+KuQ9yH6eEg6KZEJ9RI2IkDqnygCg= +github.com/aws/smithy-go v1.7.0/go.mod h1:SObp3lf9smib00L/v3U2eAKG8FyQ7iLrJnQiAmR5n+E= +github.com/aws/smithy-go v1.8.0 h1:AEwwwXQZtUwP5Mz506FeXXrKBe0jA8gVM+1gEcSRooc= +github.com/aws/smithy-go v1.8.0/go.mod h1:SObp3lf9smib00L/v3U2eAKG8FyQ7iLrJnQiAmR5n+E= github.com/awslabs/kinesis-aggregation/go v0.0.0-20200810181507-d352038274c0 h1:D97PNkeea5i2Sbq844BdbULqI5pv7yQw4thPwqEX504= github.com/awslabs/kinesis-aggregation/go v0.0.0-20200810181507-d352038274c0/go.mod h1:SghidfnxvX7ribW6nHI7T+IBbc9puZ9kk5Tx/88h8P4= +github.com/awslabs/kinesis-aggregation/go v0.0.0-20210630091500-54e17340d32f h1:Pf0BjJDga7C98f0vhw+Ip5EaiE07S3lTKpIYPNS0nMo= +github.com/awslabs/kinesis-aggregation/go v0.0.0-20210630091500-54e17340d32f/go.mod h1:SghidfnxvX7ribW6nHI7T+IBbc9puZ9kk5Tx/88h8P4= github.com/aybabtme/rgbterm v0.0.0-20170906152045-cc83f3b3ce59/go.mod h1:q/89r3U2H7sSsE2t6Kca0lfwTK8JdoNGS/yzM/4iH5I= github.com/benbjohnson/clock v1.0.3 h1:vkLuvpK4fmtSCuo60+yC63p7y0BmQ8gm5ZXGuBCJyXg= github.com/benbjohnson/clock v1.0.3/go.mod h1:bGMdMPoPVvcYyt1gHDf4J2KE153Yf9BuiUKYMaxlTDM= @@ -34,7 +75,6 @@ github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5P github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= -github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -65,22 +105,30 @@ github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:W github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= github.com/golang/protobuf v1.4.2 h1:+Z5KGCizgyZCbGh1KZqA0fcLLkwbsjIzS4aV2v7wJX0= github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= +github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/gomodule/redigo v2.0.0+incompatible h1:K/R+8tc58AaqLkqG2Ol3Qk+DR/TlNuhuh457pBFPtt0= github.com/gomodule/redigo v2.0.0+incompatible/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.0 h1:/QaMHBdZ26BB3SSst0Iwl10Epc+xhTquomWX0oZEB6w= github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ= +github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/gofuzz v1.0.0 h1:A8PeW59pxE9IoFRqBp37U+mSNaQoZ46F1f0f863XSXw= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= -github.com/jmespath/go-jmespath v0.3.0 h1:OS12ieG61fsCg5+qLJ+SsW9NicxNkg3b25OyT2yCeUc= -github.com/jmespath/go-jmespath v0.3.0/go.mod h1:9QtRXoHjLGCJ5IBSaohpXITPlowMeeYCZ7fLUTSywik= +github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= +github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= +github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= +github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= github.com/jpillora/backoff v0.0.0-20180909062703-3050d21c67d7/go.mod h1:2iMrUgbbvHEiQClaW2NsSzMyGHqN+rDFqY705q49KG0= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= @@ -99,7 +147,6 @@ github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b/go.mod h1:01TrycV0kFyex github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.10.1 h1:q/mM8GF/n0shIN8SaAZ0V+jnLPzen6WIVZdiwrRlMlo= github.com/onsi/ginkgo v1.10.1/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= -github.com/onsi/gomega v1.5.0 h1:izbySO9zDPmjJ8rDjLvkA2zJHIo+HkYXHnf7eN7SSyo= github.com/onsi/gomega v1.5.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= github.com/onsi/gomega v1.7.0 h1:XPnZz8VVBHjVsy1vzJmRwIcSwiUO+JFfrv/xGiigmME= github.com/onsi/gomega v1.7.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= @@ -118,11 +165,11 @@ github.com/smartystreets/gunit v1.0.0/go.mod h1:qwPWnhz6pn0NnRBP++URONOVyNkPyr4S github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72 h1:qLC7fQah7D6K1B0ujays3HV9gkFtllcxhzImRR7ArPQ= github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/tj/assert v0.0.0-20171129193455-018094318fb0/go.mod h1:mZ9/Rh9oLWpLLDRpvE+3b7gP/C2YyLFYxNmcLnPTMe0= github.com/tj/assert v0.0.3 h1:Df/BlaZ20mq6kuai7f5z2TvPFiwC3xaWJSDQNiIS3Rk= github.com/tj/assert v0.0.3/go.mod h1:Ne6X72Q+TB1AteidzQncjw9PabbMp4PBMZ1k+vd1Pvk= @@ -156,9 +203,9 @@ golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73r golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20190923162816-aa69164e4478 h1:l5EDrHhldLYb3ZRHDUhXF7Om7MvYXnkV9/iQNo1lX6g= golang.org/x/net v0.0.0-20190923162816-aa69164e4478/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200202094626-16171245cfb2 h1:CCH4IOTTfewWjGOlSp+zGcjutRKlBEZQ6wTn8ozI/nI= -golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210614182718-04defd469f4e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -173,9 +220,13 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20191001151750-bb3f8db39f24/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191010194322-b09406accb47 h1:/XfQ9z7ib8eEJX2hdgFTZJ/ntt0swNk5oYBziWeTCvY= golang.org/x/sys v0.0.0-20191010194322-b09406accb47/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= @@ -203,6 +254,10 @@ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miE google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= google.golang.org/protobuf v1.23.0 h1:4MY060fB1DLGMB/7MBTLnwQUY6+F09GEiz6SsrNqyzM= google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= +google.golang.org/protobuf v1.27.1 h1:SnqbnDw1V7RiZcXPx5MEeqPv2s79L9i7BJUlG/+RurQ= +google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= @@ -210,11 +265,11 @@ gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4= gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= -gopkg.in/yaml.v2 v2.2.1 h1:mUhvW9EsL+naU5Q3cakzfE91YhliOondGd6ZrsDBHQE= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.7 h1:VUgggvou5XRW9mHwD/yXxIYSMtY0zoKQf/v226p2nyo= gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20200605160147-a5ece683394c h1:grhR+C34yXImVGp7EzNk+DTIk+323eIUWOmEevy6bDo= gopkg.in/yaml.v3 v3.0.0-20200605160147-a5ece683394c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/group.go b/group.go index aa08438..a092dc3 100644 --- a/group.go +++ b/group.go @@ -3,12 +3,12 @@ package consumer import ( "context" - "github.com/aws/aws-sdk-go/service/kinesis" + "github.com/aws/aws-sdk-go-v2/service/kinesis/types" ) // Group interface used to manage which shard to process type Group interface { - Start(ctx context.Context, shardc chan *kinesis.Shard) + Start(ctx context.Context, shardc chan types.Shard) GetCheckpoint(streamName, shardID string) (string, error) SetCheckpoint(streamName, shardID, sequenceNumber string) error } diff --git a/internal/deaggregator/README.md b/internal/deaggregator/README.md new file mode 100644 index 0000000..ce474ad --- /dev/null +++ b/internal/deaggregator/README.md @@ -0,0 +1,6 @@ +# Temporary Deaggregator + +Upgrading to aws-sdk-go-v2 was blocked on a PR to introduce a new Deaggregator: +https://github.com/awslabs/kinesis-aggregation/pull/143/files + +Once that PR is merged I'll remove this code and pull in the `awslabs/kinesis-aggregation` repo. \ No newline at end of file diff --git a/internal/deaggregator/deaggregator.go b/internal/deaggregator/deaggregator.go new file mode 100644 index 0000000..94782f1 --- /dev/null +++ b/internal/deaggregator/deaggregator.go @@ -0,0 +1,94 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +package deaggregator + +import ( + "crypto/md5" + "fmt" + + "github.com/aws/aws-sdk-go-v2/service/kinesis/types" + "github.com/golang/protobuf/proto" + + rec "github.com/awslabs/kinesis-aggregation/go/records" +) + +// Magic File Header for a KPL Aggregated Record +var KplMagicHeader = fmt.Sprintf("%q", []byte("\xf3\x89\x9a\xc2")) + +const ( + KplMagicLen = 4 // Length of magic header for KPL Aggregate Record checking. + DigestSize = 16 // MD5 Message size for protobuf. +) + +// DeaggregateRecords takes an array of Kinesis records and expands any Protobuf +// records within that array, returning an array of all records +func DeaggregateRecords(records []*types.Record) ([]*types.Record, error) { + var isAggregated bool + allRecords := make([]*types.Record, 0) + for _, record := range records { + isAggregated = true + + var dataMagic string + var decodedDataNoMagic []byte + // Check if record is long enough to have magic file header + if len(record.Data) >= KplMagicLen { + dataMagic = fmt.Sprintf("%q", record.Data[:KplMagicLen]) + decodedDataNoMagic = record.Data[KplMagicLen:] + } else { + isAggregated = false + } + + // Check if record has KPL Aggregate Record Magic Header and data length + // is correct size + if KplMagicHeader != dataMagic || len(decodedDataNoMagic) <= DigestSize { + isAggregated = false + } + + if isAggregated { + messageDigest := fmt.Sprintf("%x", decodedDataNoMagic[len(decodedDataNoMagic)-DigestSize:]) + messageData := decodedDataNoMagic[:len(decodedDataNoMagic)-DigestSize] + + calculatedDigest := fmt.Sprintf("%x", md5.Sum(messageData)) + + // Check protobuf MD5 hash matches MD5 sum of record + if messageDigest != calculatedDigest { + isAggregated = false + } else { + aggRecord := &rec.AggregatedRecord{} + err := proto.Unmarshal(messageData, aggRecord) + + if err != nil { + return nil, err + } + + partitionKeys := aggRecord.PartitionKeyTable + + for _, aggrec := range aggRecord.Records { + newRecord := createUserRecord(partitionKeys, aggrec, record) + allRecords = append(allRecords, newRecord) + } + } + } + + if !isAggregated { + allRecords = append(allRecords, record) + } + } + + return allRecords, nil +} + +// createUserRecord takes in the partitionKeys of the aggregated record, the individual +// deaggregated record, and the original aggregated record builds a kinesis.Record and +// returns it +func createUserRecord(partitionKeys []string, aggRec *rec.Record, record *types.Record) *types.Record { + partitionKey := partitionKeys[*aggRec.PartitionKeyIndex] + + return &types.Record{ + ApproximateArrivalTimestamp: record.ApproximateArrivalTimestamp, + Data: aggRec.Data, + EncryptionType: record.EncryptionType, + PartitionKey: &partitionKey, + SequenceNumber: record.SequenceNumber, + } +} diff --git a/internal/deaggregator/deaggregator_test.go b/internal/deaggregator/deaggregator_test.go new file mode 100644 index 0000000..d1c33f9 --- /dev/null +++ b/internal/deaggregator/deaggregator_test.go @@ -0,0 +1,202 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +package deaggregator_test + +import ( + "crypto/md5" + "fmt" + "math/rand" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/service/kinesis/types" + "github.com/golang/protobuf/proto" + "github.com/stretchr/testify/assert" + + rec "github.com/awslabs/kinesis-aggregation/go/records" + deagg "github.com/harlow/kinesis-consumer/internal/deaggregator" +) + +// Generate an aggregate record in the correct AWS-specified format +// https://github.com/awslabs/amazon-kinesis-producer/blob/master/aggregation-format.md +func generateAggregateRecord(numRecords int) []byte { + + aggr := &rec.AggregatedRecord{} + // Start with the magic header + aggRecord := []byte("\xf3\x89\x9a\xc2") + partKeyTable := make([]string, 0) + + // Create proto record with numRecords length + for i := 0; i < numRecords; i++ { + var partKey uint64 + var hashKey uint64 + partKey = uint64(i) + hashKey = uint64(i) * uint64(10) + r := &rec.Record{ + PartitionKeyIndex: &partKey, + ExplicitHashKeyIndex: &hashKey, + Data: []byte("Some test data string"), + Tags: make([]*rec.Tag, 0), + } + + aggr.Records = append(aggr.Records, r) + partKeyVal := "test" + fmt.Sprint(i) + partKeyTable = append(partKeyTable, partKeyVal) + } + + aggr.PartitionKeyTable = partKeyTable + // Marshal to protobuf record, create md5 sum from proto record + // and append both to aggRecord with magic header + data, _ := proto.Marshal(aggr) + md5Hash := md5.Sum(data) + aggRecord = append(aggRecord, data...) + aggRecord = append(aggRecord, md5Hash[:]...) + return aggRecord +} + +// Generate a generic kinesis.Record using whatever []byte +// is passed in as the data (can be normal []byte or proto record) +func generateKinesisRecord(data []byte) *types.Record { + currentTime := time.Now() + encryptionType := types.EncryptionTypeNone + partitionKey := "1234" + sequenceNumber := "21269319989900637946712965403778482371" + return &types.Record{ + ApproximateArrivalTimestamp: ¤tTime, + Data: data, + EncryptionType: encryptionType, + PartitionKey: &partitionKey, + SequenceNumber: &sequenceNumber, + } +} + +// This tests to make sure that the data is at least larger than the length +// of the magic header to do some array slicing with index out of bounds +func TestSmallLengthReturnsCorrectNumberOfDeaggregatedRecords(t *testing.T) { + var err error + var kr *types.Record + + krs := make([]*types.Record, 0, 1) + + smallByte := []byte("No") + kr = generateKinesisRecord(smallByte) + krs = append(krs, kr) + dars, err := deagg.DeaggregateRecords(krs) + if err != nil { + panic(err) + } + + // Small byte test, since this is not a deaggregated record, should return 1 + // record in the array. + assert.Equal(t, 1, len(dars), "Small Byte test should return length of 1.") +} + +// This function tests to make sure that the data starts with the correct magic header +// according to KPL aggregate documentation. +func TestNonMatchingMagicHeaderReturnsSingleRecord(t *testing.T) { + var err error + var kr *types.Record + + krs := make([]*types.Record, 0, 1) + + min := 1 + max := 10 + n := rand.Intn(max-min) + min + aggData := generateAggregateRecord(n) + mismatchAggData := aggData[1:] + kr = generateKinesisRecord(mismatchAggData) + + krs = append(krs, kr) + + dars, err := deagg.DeaggregateRecords(krs) + if err != nil { + panic(err) + } + + // A byte record with a magic header that does not match 0xF3 0x89 0x9A 0xC2 + // should return a single record. + assert.Equal(t, 1, len(dars), "Mismatch magic header test should return length of 1.") +} + +// This function tests that the DeaggregateRecords function returns the correct number of +// deaggregated records from a single aggregated record. +func TestVariableLengthRecordsReturnsCorrectNumberOfDeaggregatedRecords(t *testing.T) { + var err error + var kr *types.Record + + krs := make([]*types.Record, 0, 1) + + min := 1 + max := 10 + n := rand.Intn(max-min) + min + aggData := generateAggregateRecord(n) + kr = generateKinesisRecord(aggData) + krs = append(krs, kr) + + dars, err := deagg.DeaggregateRecords(krs) + if err != nil { + panic(err) + } + + // Variable Length Aggregate Record test has aggregaterd records and should return + // n length. + assertMsg := fmt.Sprintf("Variable Length Aggregate Record should return length %v.", len(dars)) + assert.Equal(t, n, len(dars), assertMsg) +} + +// This function tests the length of the message after magic file header. If length is less than +// the digest size (16 bytes), it is not an aggregated record. +func TestRecordAfterMagicHeaderWithLengthLessThanDigestSizeReturnsSingleRecord(t *testing.T) { + var err error + var kr *types.Record + + krs := make([]*types.Record, 0, 1) + + min := 1 + max := 10 + n := rand.Intn(max-min) + min + aggData := generateAggregateRecord(n) + // Change size of proto message to 15 + reducedAggData := aggData[:19] + kr = generateKinesisRecord(reducedAggData) + + krs = append(krs, kr) + + dars, err := deagg.DeaggregateRecords(krs) + if err != nil { + panic(err) + } + + // A byte record with length less than 16 after the magic header should return + // a single record from DeaggregateRecords + assert.Equal(t, 1, len(dars), "Digest size test should return length of 1.") +} + +// This function tests the MD5 Sum at the end of the record by comparing MD5 sum +// at end of proto record with MD5 Sum of Proto message. If they do not match, +// it is not an aggregated record. +func TestRecordWithMismatchMd5SumReturnsSingleRecord(t *testing.T) { + var err error + var kr *types.Record + + krs := make([]*types.Record, 0, 1) + + min := 1 + max := 10 + n := rand.Intn(max-min) + min + aggData := generateAggregateRecord(n) + // Remove last byte from array to mismatch the MD5 sums + mismatchAggData := aggData[:len(aggData)-1] + kr = generateKinesisRecord(mismatchAggData) + + krs = append(krs, kr) + + dars, err := deagg.DeaggregateRecords(krs) + if err != nil { + panic(err) + } + + // A byte record with an MD5 sum that does not match with the md5.Sum(record) + // will be marked as a non-aggregate record and return a single record + assert.Equal(t, 1, len(dars), "Mismatch md5 sum test should return length of 1.") +} diff --git a/kinesis.go b/kinesis.go index af02060..fcf1c46 100644 --- a/kinesis.go +++ b/kinesis.go @@ -1,22 +1,23 @@ package consumer import ( + "context" "fmt" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/kinesis" - "github.com/aws/aws-sdk-go/service/kinesis/kinesisiface" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/kinesis" + "github.com/aws/aws-sdk-go-v2/service/kinesis/types" ) -// listShards pulls a list of shard IDs from the kinesis api -func listShards(ksis kinesisiface.KinesisAPI, streamName string) ([]*kinesis.Shard, error) { - var ss []*kinesis.Shard +// listShards pulls a list of Shard IDs from the kinesis api +func listShards(ctx context.Context, ksis kinesisClient, streamName string) ([]types.Shard, error) { + var ss []types.Shard var listShardsInput = &kinesis.ListShardsInput{ StreamName: aws.String(streamName), } for { - resp, err := ksis.ListShards(listShardsInput) + resp, err := ksis.ListShards(ctx, listShardsInput) if err != nil { return nil, fmt.Errorf("ListShards error: %w", err) } diff --git a/options.go b/options.go index c1080bf..355ad4c 100644 --- a/options.go +++ b/options.go @@ -3,7 +3,7 @@ package consumer import ( "time" - "github.com/aws/aws-sdk-go/service/kinesis/kinesisiface" + "github.com/aws/aws-sdk-go-v2/service/kinesis/types" ) // Option is used to override defaults when creating a new Consumer @@ -38,7 +38,7 @@ func WithCounter(counter Counter) Option { } // WithClient overrides the default client -func WithClient(client kinesisiface.KinesisAPI) Option { +func WithClient(client kinesisClient) Option { return func(c *Consumer) { c.client = client } @@ -47,7 +47,7 @@ func WithClient(client kinesisiface.KinesisAPI) Option { // WithShardIteratorType overrides the starting point for the consumer func WithShardIteratorType(t string) Option { return func(c *Consumer) { - c.initialShardIteratorType = t + c.initialShardIteratorType = types.ShardIteratorType(t) } } diff --git a/store/ddb/ddb.go b/store/ddb/ddb.go index 26ab3b0..5914702 100644 --- a/store/ddb/ddb.go +++ b/store/ddb/ddb.go @@ -1,16 +1,17 @@ package ddb import ( + "context" "fmt" "log" "sync" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/dynamodb" - "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute" - "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue" + "github.com/aws/aws-sdk-go-v2/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" ) // Option is used to override defaults when creating a new Checkpoint @@ -24,7 +25,7 @@ func WithMaxInterval(maxInterval time.Duration) Option { } // WithDynamoClient sets the dynamoDb client -func WithDynamoClient(svc dynamodbiface.DynamoDBAPI) Option { +func WithDynamoClient(svc *dynamodb.Client) Option { return func(c *Checkpoint) { c.client = svc } @@ -39,7 +40,6 @@ func WithRetryer(r Retryer) Option { // New returns a checkpoint that uses DynamoDB for underlying storage func New(appName, tableName string, opts ...Option) (*Checkpoint, error) { - ck := &Checkpoint{ tableName: tableName, appName: appName, @@ -56,11 +56,11 @@ func New(appName, tableName string, opts ...Option) (*Checkpoint, error) { // default client if ck.client == nil { - newSession, err := session.NewSession(aws.NewConfig()) + cfg, err := config.LoadDefaultConfig(context.TODO()) if err != nil { - return nil, err + log.Fatalf("unable to load SDK config, %v", err) } - ck.client = dynamodb.New(newSession) + ck.client = dynamodb.NewFromConfig(cfg) } go ck.loop() @@ -72,7 +72,7 @@ func New(appName, tableName string, opts ...Option) (*Checkpoint, error) { type Checkpoint struct { tableName string appName string - client dynamodbiface.DynamoDBAPI + client *dynamodb.Client maxInterval time.Duration mu *sync.Mutex // protects the checkpoints checkpoints map[key]string @@ -81,8 +81,8 @@ type Checkpoint struct { } type key struct { - streamName string - shardID string + streamName string `json:"stream_name"` + shardID string `json:"shard_id"` } type item struct { @@ -100,17 +100,17 @@ func (c *Checkpoint) GetCheckpoint(streamName, shardID string) (string, error) { params := &dynamodb.GetItemInput{ TableName: aws.String(c.tableName), ConsistentRead: aws.Bool(true), - Key: map[string]*dynamodb.AttributeValue{ - "namespace": &dynamodb.AttributeValue{ - S: aws.String(namespace), + Key: map[string]types.AttributeValue{ + "namespace": &types.AttributeValueMemberS{ + Value: namespace, }, - "shard_id": &dynamodb.AttributeValue{ - S: aws.String(shardID), + "shard_id": &types.AttributeValueMemberS{ + Value: shardID, }, }, } - resp, err := c.client.GetItem(params) + resp, err := c.client.GetItem(context.Background(), params) if err != nil { if c.retryer.ShouldRetry(err) { return c.GetCheckpoint(streamName, shardID) @@ -119,7 +119,7 @@ func (c *Checkpoint) GetCheckpoint(streamName, shardID string) (string, error) { } var i item - dynamodbattribute.UnmarshalMap(resp.Item, &i) + attributevalue.UnmarshalMap(resp.Item, &i) return i.SequenceNumber, nil } @@ -168,7 +168,7 @@ func (c *Checkpoint) save() error { defer c.mu.Unlock() for key, sequenceNumber := range c.checkpoints { - item, err := dynamodbattribute.MarshalMap(item{ + item, err := attributevalue.MarshalMap(item{ Namespace: fmt.Sprintf("%s-%s", c.appName, key.streamName), ShardID: key.shardID, SequenceNumber: sequenceNumber, @@ -178,10 +178,12 @@ func (c *Checkpoint) save() error { return nil } - _, err = c.client.PutItem(&dynamodb.PutItemInput{ - TableName: aws.String(c.tableName), - Item: item, - }) + _, err = c.client.PutItem( + context.TODO(), + &dynamodb.PutItemInput{ + TableName: aws.String(c.tableName), + Item: item, + }) if err != nil { if !c.retryer.ShouldRetry(err) { return err diff --git a/store/ddb/ddb_test.go b/store/ddb/ddb_test.go index c921afc..8f8e720 100644 --- a/store/ddb/ddb_test.go +++ b/store/ddb/ddb_test.go @@ -1,12 +1,13 @@ package ddb import ( + "context" + "log" "testing" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/dynamodb" ) type fakeRetryer struct { @@ -42,11 +43,12 @@ func TestCheckpointSetting(t *testing.T) { setRetryer(ckPtr) // Test WithDyanmoDBClient - var fakeDbClient = dynamodb.New( - session.New(aws.NewConfig()), &aws.Config{ - Region: aws.String("us-west-2"), - }, - ) + cfg, err := config.LoadDefaultConfig(context.TODO()) + if err != nil { + log.Fatalf("unable to load SDK config, %v", err) + } + var fakeDbClient = dynamodb.NewFromConfig(cfg) + setDDBClient := WithDynamoClient(fakeDbClient) setDDBClient(ckPtr) @@ -70,11 +72,12 @@ func TestNewCheckpointWithOptions(t *testing.T) { setRetryer := WithRetryer(&r) // Test WithDyanmoDBClient - var fakeDbClient = dynamodb.New( - session.New(aws.NewConfig()), &aws.Config{ - Region: aws.String("us-west-2"), - }, - ) + cfg, err := config.LoadDefaultConfig(context.TODO()) + if err != nil { + log.Fatalf("unable to load SDK config, %v", err) + } + var fakeDbClient = dynamodb.NewFromConfig(cfg) + setDDBClient := WithDynamoClient(fakeDbClient) ckPtr, err := New("testapp", "testtable", setInterval, setRetryer, setDDBClient) diff --git a/store/ddb/retryer.go b/store/ddb/retryer.go index 646bb6c..41da790 100644 --- a/store/ddb/retryer.go +++ b/store/ddb/retryer.go @@ -1,8 +1,7 @@ package ddb import ( - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" ) // Retryer interface contains one method that decides whether to retry based on error @@ -17,10 +16,9 @@ type DefaultRetryer struct { // ShouldRetry when error occured func (r *DefaultRetryer) ShouldRetry(err error) bool { - if awsErr, ok := err.(awserr.Error); ok { - if awsErr.Code() == dynamodb.ErrCodeProvisionedThroughputExceededException { - return true - } + switch err.(type) { + case *types.ProvisionedThroughputExceededException: + return true } return false } diff --git a/store/ddb/retryer_test.go b/store/ddb/retryer_test.go index 26d42a5..b7ac8d7 100644 --- a/store/ddb/retryer_test.go +++ b/store/ddb/retryer_test.go @@ -1,22 +1,21 @@ package ddb import ( - "errors" "testing" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" ) func TestDefaultRetyer(t *testing.T) { - retryableError := awserr.New(dynamodb.ErrCodeProvisionedThroughputExceededException, "error is retryable", errors.New("don't care what is here")) + retryableError := &types.ProvisionedThroughputExceededException{Message: aws.String("error not retryable")} // retryer is not nil and should returns according to what error is passed in. q := &DefaultRetryer{} if q.ShouldRetry(retryableError) != true { t.Errorf("expected ShouldRetry returns %v. got %v", false, q.ShouldRetry(retryableError)) } - nonRetryableError := awserr.New(dynamodb.ErrCodeBackupInUseException, "error is not retryable", errors.New("don't care what is here")) + nonRetryableError := &types.BackupInUseException{Message: aws.String("error not retryable")} shouldRetry := q.ShouldRetry(nonRetryableError) if shouldRetry != false { t.Errorf("expected ShouldRetry returns %v. got %v", true, shouldRetry)