Update ddb checkpoint item to use dynamodbav (#144)

Update the DynamoDB checkpoint item to use the `dynamodbav` value for marshaling. 

Fixes: https://github.com/harlow/kinesis-consumer/issues/142

Minor changes:
* Update the DDB example consumer to support a local version of DDB for streamlined testing
This commit is contained in:
Harlow Ward 2021-12-04 13:40:26 -08:00 committed by GitHub
parent 3b3b252fa5
commit 6cbda0f706
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 98 additions and 40 deletions

View file

@ -10,6 +10,7 @@ import (
"net/http"
"os"
"os/signal"
"time"
alog "github.com/apex/log"
"github.com/apex/log/handlers/text"
@ -17,6 +18,7 @@ import (
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/dynamodb"
ddbtypes "github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
"github.com/aws/aws-sdk-go-v2/service/kinesis"
"github.com/aws/aws-sdk-go-v2/service/kinesis/types"
consumer "github.com/harlow/kinesis-consumer"
@ -57,38 +59,33 @@ func main() {
var (
app = flag.String("app", "", "Consumer app name")
stream = flag.String("stream", "", "Stream name")
table = flag.String("table", "", "Checkpoint table name")
kinesisEndpoint = flag.String("endpoint", "http://localhost:4567", "Kinesis endpoint")
tableName = flag.String("table", "", "Checkpoint table name")
ddbEndpoint = flag.String("ddb-endpoint", "http://localhost:8000", "DynamoDB endpoint")
kinesisEndpoint = flag.String("ksis-endpoint", "http://localhost:4567", "Kinesis endpoint")
awsRegion = flag.String("region", "us-west-2", "AWS Region")
)
flag.Parse()
resolver := aws.EndpointResolverFunc(func(service, region string) (aws.Endpoint, error) {
return aws.Endpoint{
PartitionID: "aws",
URL: *kinesisEndpoint,
SigningRegion: *awsRegion,
}, nil
})
// client
cfg, err := config.LoadDefaultConfig(
context.TODO(),
config.WithRegion(*awsRegion),
config.WithEndpointResolver(resolver),
config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider("user", "pass", "token")),
)
// set up clients
kcfg, err := newConfig(*kinesisEndpoint, *awsRegion)
if err != nil {
log.Fatalf("unable to load SDK config, %v", err)
log.Fatalf("new kinesis config error: %v", err)
}
var myKsis = kinesis.NewFromConfig(kcfg)
dcfg, err := newConfig(*ddbEndpoint, *awsRegion)
if err != nil {
log.Fatalf("new ddb config error: %v", err)
}
var myDdbClient = dynamodb.NewFromConfig(dcfg)
// ddb checkpoint table
if err := createTable(myDdbClient, *tableName); err != nil {
log.Fatalf("create ddb table error: %v", err)
}
var (
myDdbClient = dynamodb.NewFromConfig(cfg)
myKsis = kinesis.NewFromConfig(cfg)
)
// ddb persitance
ddb, err := storage.New(*app, *table, storage.WithDynamoClient(myDdbClient), storage.WithRetryer(&MyRetryer{}))
ddb, err := storage.New(*app, *tableName, storage.WithDynamoClient(myDdbClient), storage.WithRetryer(&MyRetryer{}))
if err != nil {
log.Fatalf("checkpoint error: %v", err)
}
@ -134,6 +131,50 @@ func main() {
}
}
func createTable(client *dynamodb.Client, tableName string) error {
resp, err := client.ListTables(context.Background(), &dynamodb.ListTablesInput{})
if err != nil {
return fmt.Errorf("list streams error: %v", err)
}
for _, val := range resp.TableNames {
if tableName == val {
return nil
}
}
_, err = client.CreateTable(
context.Background(),
&dynamodb.CreateTableInput{
TableName: aws.String(tableName),
AttributeDefinitions: []ddbtypes.AttributeDefinition{
{AttributeName: aws.String("namespace"), AttributeType: "S"},
{AttributeName: aws.String("shard_id"), AttributeType: "S"},
},
KeySchema: []ddbtypes.KeySchemaElement{
{AttributeName: aws.String("namespace"), KeyType: ddbtypes.KeyTypeHash},
{AttributeName: aws.String("shard_id"), KeyType: ddbtypes.KeyTypeRange},
},
ProvisionedThroughput: &ddbtypes.ProvisionedThroughput{
ReadCapacityUnits: aws.Int64(1),
WriteCapacityUnits: aws.Int64(1),
},
},
)
if err != nil {
return err
}
waiter := dynamodb.NewTableExistsWaiter(client)
return waiter.Wait(
context.Background(),
&dynamodb.DescribeTableInput{
TableName: aws.String(tableName),
},
5*time.Second,
)
}
// MyRetryer used for storage
type MyRetryer struct {
storage.Retryer
@ -147,3 +188,20 @@ func (r *MyRetryer) ShouldRetry(err error) bool {
}
return false
}
func newConfig(url, region string) (aws.Config, error) {
resolver := aws.EndpointResolverFunc(func(service, region string) (aws.Endpoint, error) {
return aws.Endpoint{
PartitionID: "aws",
URL: url,
SigningRegion: region,
}, nil
})
return config.LoadDefaultConfig(
context.TODO(),
config.WithRegion(region),
config.WithEndpointResolver(resolver),
config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider("user", "pass", "token")),
)
}

2
go.mod
View file

@ -5,7 +5,7 @@ require (
github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a // indirect
github.com/alicebob/miniredis v2.5.0+incompatible
github.com/apex/log v1.6.0
github.com/aws/aws-sdk-go-v2 v1.9.0
github.com/aws/aws-sdk-go-v2 v1.11.2
github.com/aws/aws-sdk-go-v2/config v1.6.1
github.com/aws/aws-sdk-go-v2/credentials v1.3.3 // indirect
github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue v1.2.0 // indirect

4
go.sum
View file

@ -26,6 +26,8 @@ github.com/aws/aws-sdk-go-v2 v1.8.1 h1:GcFgQl7MsBygmeeqXyV1ivrTEmsVz/rdFJaTcltG9
github.com/aws/aws-sdk-go-v2 v1.8.1/go.mod h1:xEFuWz+3TYdlPRuo+CqATbeDWIWyaT5uAPwPaWtgse0=
github.com/aws/aws-sdk-go-v2 v1.9.0 h1:+S+dSqQCN3MSU5vJRu1HqHrq00cJn6heIMU7X9hcsoo=
github.com/aws/aws-sdk-go-v2 v1.9.0/go.mod h1:cK/D0BBs0b/oWPIcX/Z/obahJK1TT7IPVjy53i/mX/4=
github.com/aws/aws-sdk-go-v2 v1.11.2 h1:SDiCYqxdIYi6HgQfAWRhgdZrdnOuGyLDJVRSWLeHWvs=
github.com/aws/aws-sdk-go-v2 v1.11.2/go.mod h1:SQfA+m2ltnu1cA0soUkj4dRSsmITiVQUJvBIZjzfPyQ=
github.com/aws/aws-sdk-go-v2/config v1.6.1 h1:qrZINaORyr78syO1zfD4l7r4tZjy0Z1l0sy4jiysyOM=
github.com/aws/aws-sdk-go-v2/config v1.6.1/go.mod h1:t/y3UPu0XEDy0cEw6mvygaBQaPzWiYAxfP2SzgtvclA=
github.com/aws/aws-sdk-go-v2/credentials v1.3.3 h1:A13QPatmUl41SqUfnuT3V0E3XiNGL6qNTOINbE8cZL4=
@ -60,6 +62,8 @@ github.com/aws/smithy-go v1.7.0 h1:+cLHMRrDZvQ4wk+KuQ9yH6eEg6KZEJ9RI2IkDqnygCg=
github.com/aws/smithy-go v1.7.0/go.mod h1:SObp3lf9smib00L/v3U2eAKG8FyQ7iLrJnQiAmR5n+E=
github.com/aws/smithy-go v1.8.0 h1:AEwwwXQZtUwP5Mz506FeXXrKBe0jA8gVM+1gEcSRooc=
github.com/aws/smithy-go v1.8.0/go.mod h1:SObp3lf9smib00L/v3U2eAKG8FyQ7iLrJnQiAmR5n+E=
github.com/aws/smithy-go v1.9.0 h1:c7FUdEqrQA1/UVKKCNDFQPNKGp4FQg3YW4Ck5SLTG58=
github.com/aws/smithy-go v1.9.0/go.mod h1:SObp3lf9smib00L/v3U2eAKG8FyQ7iLrJnQiAmR5n+E=
github.com/awslabs/kinesis-aggregation/go v0.0.0-20200810181507-d352038274c0 h1:D97PNkeea5i2Sbq844BdbULqI5pv7yQw4thPwqEX504=
github.com/awslabs/kinesis-aggregation/go v0.0.0-20200810181507-d352038274c0/go.mod h1:SghidfnxvX7ribW6nHI7T+IBbc9puZ9kk5Tx/88h8P4=
github.com/awslabs/kinesis-aggregation/go v0.0.0-20210630091500-54e17340d32f h1:Pf0BjJDga7C98f0vhw+Ip5EaiE07S3lTKpIYPNS0nMo=

View file

@ -81,14 +81,14 @@ type Checkpoint struct {
}
type key struct {
streamName string `json:"stream_name"`
shardID string `json:"shard_id"`
StreamName string
ShardID string
}
type item struct {
Namespace string `json:"namespace"`
ShardID string `json:"shard_id"`
SequenceNumber string `json:"sequence_number"`
Namespace string `json:"namespace" dynamodbav:"namespace"`
ShardID string `json:"shard_id" dynamodbav:"shard_id"`
SequenceNumber string `json:"sequence_number" dynamodbav:"sequence_number"`
}
// GetCheckpoint determines if a checkpoint for a particular Shard exists.
@ -101,12 +101,8 @@ func (c *Checkpoint) GetCheckpoint(streamName, shardID string) (string, error) {
TableName: aws.String(c.tableName),
ConsistentRead: aws.Bool(true),
Key: map[string]types.AttributeValue{
"namespace": &types.AttributeValueMemberS{
Value: namespace,
},
"shard_id": &types.AttributeValueMemberS{
Value: shardID,
},
"namespace": &types.AttributeValueMemberS{Value: namespace},
"shard_id": &types.AttributeValueMemberS{Value: shardID},
},
}
@ -134,8 +130,8 @@ func (c *Checkpoint) SetCheckpoint(streamName, shardID, sequenceNumber string) e
}
key := key{
streamName: streamName,
shardID: shardID,
StreamName: streamName,
ShardID: shardID,
}
c.checkpoints[key] = sequenceNumber
@ -169,8 +165,8 @@ func (c *Checkpoint) save() error {
for key, sequenceNumber := range c.checkpoints {
item, err := attributevalue.MarshalMap(item{
Namespace: fmt.Sprintf("%s-%s", c.appName, key.streamName),
ShardID: key.shardID,
Namespace: fmt.Sprintf("%s-%s", c.appName, key.StreamName),
ShardID: key.ShardID,
SequenceNumber: sequenceNumber,
})
if err != nil {