Add support for lease stealing (#78)

Fixes #4

Signed-off-by: Connor McKelvey <connormckelvey@gmail.com>
Signed-off-by: Ali Hobbs <alisuehobbs@gmail.com>
Co-authored-by: Ali Hobbs <alisuehobbs@gmail.com>

Co-authored-by: Ali Hobbs <alisuehobbs@gmail.com>
This commit is contained in:
Connor McKelvey 2021-06-01 17:18:26 -06:00 committed by Tao Jiang
parent 4a642bfa2f
commit 7de4607b71
18 changed files with 1233 additions and 50 deletions

View file

@ -8,8 +8,8 @@ targets:
rebuild-toolchain:
description: build toolchain image
watches:
- support/docker/toolchain
build: support/docker/toolchain
- support/toolchain/docker
build: support/toolchain/docker
toolchain:
description: placeholder for additional toolchain dependencies

View file

@ -40,9 +40,13 @@ const (
LeaseTimeoutKey = "LeaseTimeout"
SequenceNumberKey = "Checkpoint"
ParentShardIdKey = "ParentShardId"
ClaimRequestKey = "ClaimRequest"
// We've completely processed all records in this shard.
ShardEnd = "SHARD_END"
// ErrShardClaimed is returned when shard is claimed
ErrShardClaimed = "Shard is already claimed by another node"
)
type ErrLeaseNotAcquired struct {
@ -72,7 +76,17 @@ type Checkpointer interface {
// RemoveLeaseOwner to remove lease owner for the shard entry to make the shard available for reassignment
RemoveLeaseOwner(string) error
// New Lease Stealing Methods
// ListActiveWorkers returns active workers and their shards
ListActiveWorkers(map[string]*par.ShardStatus) (map[string][]*par.ShardStatus, error)
// ClaimShard claims a shard for stealing
ClaimShard(*par.ShardStatus, string) error
}
// ErrSequenceIDNotFound is returned by FetchCheckpoint when no SequenceID is found
var ErrSequenceIDNotFound = errors.New("SequenceIDNotFoundForShard")
// ErrShardNotAssigned is returned by ListActiveWorkers when no AssignedTo is found
var ErrShardNotAssigned = errors.New("AssignedToNotFoundForShard")

View file

@ -28,6 +28,8 @@
package checkpoint
import (
"errors"
"fmt"
"time"
"github.com/aws/aws-sdk-go/aws"
@ -61,6 +63,7 @@ type DynamoCheckpoint struct {
svc dynamodbiface.DynamoDBAPI
kclConfig *config.KinesisClientLibConfiguration
Retries int
lastLeaseSync time.Time
}
func NewDynamoCheckpoint(kclConfig *config.KinesisClientLibConfiguration) *DynamoCheckpoint {
@ -124,8 +127,22 @@ func (checkpointer *DynamoCheckpoint) GetLease(shard *par.ShardStatus, newAssign
return err
}
isClaimRequestExpired := shard.IsClaimRequestExpired(checkpointer.kclConfig)
var claimRequest string
if checkpointer.kclConfig.EnableLeaseStealing {
if currentCheckpointClaimRequest, ok := currentCheckpoint[ClaimRequestKey]; ok && currentCheckpointClaimRequest.S != nil {
claimRequest = *currentCheckpointClaimRequest.S
if newAssignTo != claimRequest && !isClaimRequestExpired {
checkpointer.log.Debugf("another worker: %s has a claim on this shard. Not going to renew the lease", claimRequest)
return errors.New(ErrShardClaimed)
}
}
}
assignedVar, assignedToOk := currentCheckpoint[LeaseOwnerKey]
leaseVar, leaseTimeoutOk := currentCheckpoint[LeaseTimeoutKey]
var conditionalExpression string
var expressionAttributeValues map[string]*dynamodb.AttributeValue
@ -140,8 +157,14 @@ func (checkpointer *DynamoCheckpoint) GetLease(shard *par.ShardStatus, newAssign
return err
}
if time.Now().UTC().Before(currentLeaseTimeout) && assignedTo != newAssignTo {
return ErrLeaseNotAcquired{"current lease timeout not yet expired"}
if checkpointer.kclConfig.EnableLeaseStealing {
if time.Now().UTC().Before(currentLeaseTimeout) && assignedTo != newAssignTo && !isClaimRequestExpired {
return ErrLeaseNotAcquired{"current lease timeout not yet expired"}
}
} else {
if time.Now().UTC().Before(currentLeaseTimeout) && assignedTo != newAssignTo {
return ErrLeaseNotAcquired{"current lease timeout not yet expired"}
}
}
checkpointer.log.Debugf("Attempting to get a lock for shard: %s, leaseTimeout: %s, assignedTo: %s, newAssignedTo: %s", shard.ID, currentLeaseTimeout, assignedTo, newAssignTo)
@ -175,9 +198,21 @@ func (checkpointer *DynamoCheckpoint) GetLease(shard *par.ShardStatus, newAssign
marshalledCheckpoint[ParentShardIdKey] = &dynamodb.AttributeValue{S: aws.String(shard.ParentShardId)}
}
if shard.GetCheckpoint() != "" {
if checkpoint := shard.GetCheckpoint(); checkpoint != "" {
marshalledCheckpoint[SequenceNumberKey] = &dynamodb.AttributeValue{
S: aws.String(shard.GetCheckpoint()),
S: aws.String(checkpoint),
}
}
if checkpointer.kclConfig.EnableLeaseStealing {
if claimRequest != "" && claimRequest == newAssignTo && !isClaimRequestExpired {
if expressionAttributeValues == nil {
expressionAttributeValues = make(map[string]*dynamodb.AttributeValue)
}
conditionalExpression = conditionalExpression + " AND ClaimRequest = :claim_request"
expressionAttributeValues[":claim_request"] = &dynamodb.AttributeValue{
S: &claimRequest,
}
}
}
@ -199,7 +234,7 @@ func (checkpointer *DynamoCheckpoint) GetLease(shard *par.ShardStatus, newAssign
// CheckpointSequence writes a checkpoint at the designated sequence ID
func (checkpointer *DynamoCheckpoint) CheckpointSequence(shard *par.ShardStatus) error {
leaseTimeout := shard.LeaseTimeout.UTC().Format(time.RFC3339)
leaseTimeout := shard.GetLeaseTimeout().UTC().Format(time.RFC3339)
marshalledCheckpoint := map[string]*dynamodb.AttributeValue{
LeaseKeyKey: {
S: aws.String(shard.ID),
@ -208,7 +243,7 @@ func (checkpointer *DynamoCheckpoint) CheckpointSequence(shard *par.ShardStatus)
S: aws.String(shard.GetCheckpoint()),
},
LeaseOwnerKey: {
S: aws.String(shard.AssignedTo),
S: aws.String(shard.GetLeaseOwner()),
},
LeaseTimeoutKey: {
S: aws.String(leaseTimeout),
@ -239,6 +274,16 @@ func (checkpointer *DynamoCheckpoint) FetchCheckpoint(shard *par.ShardStatus) er
if assignedTo, ok := checkpoint[LeaseOwnerKey]; ok {
shard.SetLeaseOwner(aws.StringValue(assignedTo.S))
}
// Use up-to-date leaseTimeout to avoid ConditionalCheckFailedException when claiming
if leaseTimeout, ok := checkpoint[LeaseTimeoutKey]; ok && leaseTimeout.S != nil {
currentLeaseTimeout, err := time.Parse(time.RFC3339, aws.StringValue(leaseTimeout.S))
if err != nil {
return err
}
shard.LeaseTimeout = currentLeaseTimeout
}
return nil
}
@ -265,6 +310,12 @@ func (checkpointer *DynamoCheckpoint) RemoveLeaseOwner(shardID string) error {
},
},
UpdateExpression: aws.String("remove " + LeaseOwnerKey),
ExpressionAttributeValues: map[string]*dynamodb.AttributeValue{
":assigned_to": {
S: aws.String(checkpointer.kclConfig.WorkerID),
},
},
ConditionExpression: aws.String("AssignedTo = :assigned_to"),
}
_, err := checkpointer.svc.UpdateItem(input)
@ -272,6 +323,135 @@ func (checkpointer *DynamoCheckpoint) RemoveLeaseOwner(shardID string) error {
return err
}
// ListActiveWorkers returns a map of workers and their shards
func (checkpointer *DynamoCheckpoint) ListActiveWorkers(shardStatus map[string]*par.ShardStatus) (map[string][]*par.ShardStatus, error) {
	// Refresh ownership info from the lease table before grouping.
	if err := checkpointer.syncLeases(shardStatus); err != nil {
		return nil, err
	}
	workers := make(map[string][]*par.ShardStatus)
	for _, shard := range shardStatus {
		// Fully consumed shards no longer participate in rebalancing.
		if shard.GetCheckpoint() == ShardEnd {
			continue
		}
		owner := shard.GetLeaseOwner()
		if owner == "" {
			checkpointer.log.Debugf("Shard Not Assigned Error. ShardID: %s, WorkerID: %s", shard.ID, checkpointer.kclConfig.WorkerID)
			return nil, ErrShardNotAssigned
		}
		// append on a nil slice allocates, so no presence check is needed.
		workers[owner] = append(workers[owner], shard)
	}
	return workers, nil
}
// ClaimShard places a claim request on a shard to signal a steal attempt.
// The write is guarded by a conditional expression built from the caller's
// last-known view of the lease item (timeout, owner, checkpoint, parent), so
// the claim only succeeds if the record has not changed underneath us and no
// other claim request is already pending.
func (checkpointer *DynamoCheckpoint) ClaimShard(shard *par.ShardStatus, claimID string) error {
	// Refresh our view of the lease first. ErrSequenceIDNotFound just means
	// the shard has never been checkpointed, which is fine for claiming.
	err := checkpointer.FetchCheckpoint(shard)
	if err != nil && err != ErrSequenceIDNotFound {
		return err
	}
	leaseTimeoutString := shard.GetLeaseTimeout().Format(time.RFC3339)
	conditionalExpression := `ShardID = :id AND LeaseTimeout = :lease_timeout AND attribute_not_exists(ClaimRequest)`
	expressionAttributeValues := map[string]*dynamodb.AttributeValue{
		":id": {
			S: aws.String(shard.ID),
		},
		":lease_timeout": {
			S: aws.String(leaseTimeoutString),
		},
	}
	marshalledCheckpoint := map[string]*dynamodb.AttributeValue{
		LeaseKeyKey: {
			S: &shard.ID,
		},
		LeaseTimeoutKey: {
			S: &leaseTimeoutString,
		},
		SequenceNumberKey: {
			S: &shard.Checkpoint,
		},
		ClaimRequestKey: {
			S: &claimID,
		},
	}
	if leaseOwner := shard.GetLeaseOwner(); leaseOwner == "" {
		conditionalExpression += " AND attribute_not_exists(AssignedTo)"
	} else {
		marshalledCheckpoint[LeaseOwnerKey] = &dynamodb.AttributeValue{S: &leaseOwner}
		// Fix: the original concatenation was missing the leading space,
		// yielding an invalid expression "...(ClaimRequest)AND AssignedTo...".
		conditionalExpression += " AND AssignedTo = :assigned_to"
		expressionAttributeValues[":assigned_to"] = &dynamodb.AttributeValue{S: &leaseOwner}
	}
	if checkpoint := shard.GetCheckpoint(); checkpoint == "" {
		conditionalExpression += " AND attribute_not_exists(Checkpoint)"
	} else if checkpoint == ShardEnd {
		// Never claim a shard that has already been fully consumed.
		conditionalExpression += " AND Checkpoint <> :checkpoint"
		expressionAttributeValues[":checkpoint"] = &dynamodb.AttributeValue{S: aws.String(ShardEnd)}
	} else {
		conditionalExpression += " AND Checkpoint = :checkpoint"
		expressionAttributeValues[":checkpoint"] = &dynamodb.AttributeValue{S: &checkpoint}
	}
	if shard.ParentShardId == "" {
		conditionalExpression += " AND attribute_not_exists(ParentShardId)"
	} else {
		marshalledCheckpoint[ParentShardIdKey] = &dynamodb.AttributeValue{S: aws.String(shard.ParentShardId)}
		conditionalExpression += " AND ParentShardId = :parent_shard"
		expressionAttributeValues[":parent_shard"] = &dynamodb.AttributeValue{S: &shard.ParentShardId}
	}
	return checkpointer.conditionalUpdate(conditionalExpression, expressionAttributeValues, marshalledCheckpoint)
}
// syncLeases refreshes the in-memory shard ownership and checkpoint data from
// the DynamoDB lease table. The scan is rate limited: if the previous sync ran
// within LeaseSyncingTimeIntervalMillis, the call returns immediately.
func (checkpointer *DynamoCheckpoint) syncLeases(shardStatus map[string]*par.ShardStatus) error {
	log := checkpointer.kclConfig.Logger
	// Skip the scan entirely while the previous sync is still fresh.
	if (checkpointer.lastLeaseSync.Add(time.Duration(checkpointer.kclConfig.LeaseSyncingTimeIntervalMillis) * time.Millisecond)).After(time.Now()) {
		return nil
	}
	checkpointer.lastLeaseSync = time.Now()
	// Project only the three attributes we need: shard id, owner, checkpoint.
	input := &dynamodb.ScanInput{
		ProjectionExpression: aws.String(fmt.Sprintf("%s,%s,%s", LeaseKeyKey, LeaseOwnerKey, SequenceNumberKey)),
		Select:               aws.String("SPECIFIC_ATTRIBUTES"),
		TableName:            aws.String(checkpointer.kclConfig.TableName),
	}
	err := checkpointer.svc.ScanPages(input,
		func(pages *dynamodb.ScanOutput, lastPage bool) bool {
			results := pages.Items
			for _, result := range results {
				shardId, foundShardId := result[LeaseKeyKey]
				assignedTo, foundAssignedTo := result[LeaseOwnerKey]
				checkpoint, foundCheckpoint := result[SequenceNumberKey]
				// Rows missing any projected attribute are skipped.
				if !foundShardId || !foundAssignedTo || !foundCheckpoint {
					continue
				}
				// Only update shards this worker already tracks locally.
				if shard, ok := shardStatus[aws.StringValue(shardId.S)]; ok {
					shard.SetLeaseOwner(aws.StringValue(assignedTo.S))
					shard.SetCheckpoint(aws.StringValue(checkpoint.S))
				}
			}
			// Returning true continues pagination until the last page.
			return !lastPage
		})
	if err != nil {
		log.Debugf("Error performing SyncLeases. Error: %+v ", err)
		return err
	}
	log.Debugf("Lease sync completed. Next lease sync will occur in %s", time.Duration(checkpointer.kclConfig.LeaseSyncingTimeIntervalMillis)*time.Millisecond)
	return nil
}
func (checkpointer *DynamoCheckpoint) createTable() error {
input := &dynamodb.CreateTableInput{
AttributeDefinitions: []*dynamodb.AttributeDefinition{

View file

@ -85,6 +85,7 @@ func TestGetLeaseNotAquired(t *testing.T) {
Checkpoint: "",
Mux: &sync.RWMutex{},
}, "ijkl-mnop")
if err == nil || !errors.As(err, &ErrLeaseNotAcquired{}) {
t.Errorf("Got a lease when it was already held by abcd-efgh: %s", err)
}
@ -102,16 +103,16 @@ func TestGetLeaseAquired(t *testing.T) {
checkpoint := NewDynamoCheckpoint(kclConfig).WithDynamoDB(svc)
checkpoint.Init()
marshalledCheckpoint := map[string]*dynamodb.AttributeValue{
"ShardID": {
LeaseKeyKey: {
S: aws.String("0001"),
},
"AssignedTo": {
LeaseOwnerKey: {
S: aws.String("abcd-efgh"),
},
"LeaseTimeout": {
LeaseTimeoutKey: {
S: aws.String(time.Now().AddDate(0, -1, 0).UTC().Format(time.RFC3339)),
},
"SequenceID": {
SequenceNumberKey: {
S: aws.String("deadbeef"),
},
}
@ -156,10 +157,221 @@ func TestGetLeaseAquired(t *testing.T) {
assert.Equal(t, "", status.GetLeaseOwner())
}
// TestGetLeaseShardClaimed verifies that GetLease rejects a worker other than
// the claimer while a claim request is outstanding, and that the claiming
// worker itself can still acquire the lease.
func TestGetLeaseShardClaimed(t *testing.T) {
	leaseTimeout := time.Now().Add(-100 * time.Second).UTC()
	svc := &mockDynamoDB{
		tableExist: true,
		item: map[string]*dynamodb.AttributeValue{
			ClaimRequestKey: {S: aws.String("ijkl-mnop")},
			LeaseTimeoutKey: {S: aws.String(leaseTimeout.Format(time.RFC3339))},
		},
	}
	kclConfig := cfg.NewKinesisClientLibConfig("appName", "test", "us-west-2", "abc").
		WithInitialPositionInStream(cfg.LATEST).
		WithMaxRecords(10).
		WithMaxLeasesForWorker(1).
		WithShardSyncIntervalMillis(5000).
		WithFailoverTimeMillis(300000).
		WithLeaseStealing(true)
	checkpoint := NewDynamoCheckpoint(kclConfig).WithDynamoDB(svc)
	checkpoint.Init()
	// A worker that is not the claimer must be rejected with ErrShardClaimed.
	err := checkpoint.GetLease(&par.ShardStatus{
		ID:           "0001",
		Checkpoint:   "",
		LeaseTimeout: leaseTimeout,
		Mux:          &sync.RWMutex{},
	}, "abcd-efgh")
	// Fix: failure message previously read "claimed by by".
	if err == nil || err.Error() != ErrShardClaimed {
		t.Errorf("Got a lease when it was already claimed by ijkl-mnop: %s", err)
	}
	// The claiming worker should acquire the lease successfully.
	err = checkpoint.GetLease(&par.ShardStatus{
		ID:           "0001",
		Checkpoint:   "",
		LeaseTimeout: leaseTimeout,
		Mux:          &sync.RWMutex{},
	}, "ijkl-mnop")
	if err != nil {
		t.Errorf("Error getting lease %s", err)
	}
}
// TestGetLeaseClaimRequestExpiredOwner checks that the current lease owner
// cannot renew its lease while another worker's claim request is still active.
func TestGetLeaseClaimRequestExpiredOwner(t *testing.T) {
	kclConfig := cfg.NewKinesisClientLibConfig("appName", "test", "us-west-2", "abc").
		WithInitialPositionInStream(cfg.LATEST).
		WithMaxRecords(10).
		WithMaxLeasesForWorker(1).
		WithShardSyncIntervalMillis(5000).
		WithFailoverTimeMillis(300000).
		WithLeaseStealing(true)
	// Not expired: the lease timeout sits one second inside the claim-timeout
	// window, so the claim request is still considered active.
	leaseTimeout := time.Now().
		Add(-time.Duration(kclConfig.LeaseStealingClaimTimeoutMillis) * time.Millisecond).
		Add(1 * time.Second).
		UTC()
	svc := &mockDynamoDB{
		tableExist: true,
		item: map[string]*dynamodb.AttributeValue{
			LeaseOwnerKey:   {S: aws.String("abcd-efgh")},
			ClaimRequestKey: {S: aws.String("ijkl-mnop")},
			LeaseTimeoutKey: {S: aws.String(leaseTimeout.Format(time.RFC3339))},
		},
	}
	checkpoint := NewDynamoCheckpoint(kclConfig).WithDynamoDB(svc)
	checkpoint.Init()
	// The owner ("abcd-efgh") attempts a renewal and must be rejected.
	err := checkpoint.GetLease(&par.ShardStatus{
		ID:           "0001",
		Checkpoint:   "",
		LeaseTimeout: leaseTimeout,
		Mux:          &sync.RWMutex{},
	}, "abcd-efgh")
	if err == nil || err.Error() != ErrShardClaimed {
		t.Errorf("Got a lease when it was already claimed by ijkl-mnop: %s", err)
	}
}
// TestGetLeaseClaimRequestExpiredClaimer checks that even the claiming worker
// cannot take the lease while the current lease timeout still lies in the
// future, so it receives ErrLeaseNotAcquired rather than the lease.
func TestGetLeaseClaimRequestExpiredClaimer(t *testing.T) {
	kclConfig := cfg.NewKinesisClientLibConfig("appName", "test", "us-west-2", "abc").
		WithInitialPositionInStream(cfg.LATEST).
		WithMaxRecords(10).
		WithMaxLeasesForWorker(1).
		WithShardSyncIntervalMillis(5000).
		WithFailoverTimeMillis(300000).
		WithLeaseStealing(true)
	// Not expired: with the default 120s claim timeout, the +121s offset puts
	// the lease timeout roughly one second in the future.
	leaseTimeout := time.Now().
		Add(-time.Duration(kclConfig.LeaseStealingClaimTimeoutMillis) * time.Millisecond).
		Add(121 * time.Second).
		UTC()
	svc := &mockDynamoDB{
		tableExist: true,
		item: map[string]*dynamodb.AttributeValue{
			LeaseOwnerKey:   {S: aws.String("abcd-efgh")},
			ClaimRequestKey: {S: aws.String("ijkl-mnop")},
			LeaseTimeoutKey: {S: aws.String(leaseTimeout.Format(time.RFC3339))},
		},
	}
	checkpoint := NewDynamoCheckpoint(kclConfig).WithDynamoDB(svc)
	checkpoint.Init()
	// The claimer ("ijkl-mnop") is rejected because the lease has not timed
	// out yet.
	err := checkpoint.GetLease(&par.ShardStatus{
		ID:           "0001",
		Checkpoint:   "",
		LeaseTimeout: leaseTimeout,
		Mux:          &sync.RWMutex{},
	}, "ijkl-mnop")
	if err == nil || !errors.As(err, &ErrLeaseNotAcquired{}) {
		t.Errorf("Got a lease when it was already claimed by ijkl-mnop: %s", err)
	}
}
// TestFetchCheckpointWithStealing verifies that, with lease stealing enabled,
// FetchCheckpoint copies the lease timeout stored in the table onto the local
// ShardStatus so later conditional updates use an up-to-date value.
func TestFetchCheckpointWithStealing(t *testing.T) {
	future := time.Now().AddDate(0, 1, 0)
	svc := &mockDynamoDB{
		tableExist: true,
		item: map[string]*dynamodb.AttributeValue{
			SequenceNumberKey: {S: aws.String("deadbeef")},
			LeaseOwnerKey:     {S: aws.String("abcd-efgh")},
			LeaseTimeoutKey: {
				S: aws.String(future.Format(time.RFC3339)),
			},
		},
	}
	kclConfig := cfg.NewKinesisClientLibConfig("appName", "test", "us-west-2", "abc").
		WithInitialPositionInStream(cfg.LATEST).
		WithMaxRecords(10).
		WithMaxLeasesForWorker(1).
		WithShardSyncIntervalMillis(5000).
		WithFailoverTimeMillis(300000).
		WithLeaseStealing(true)
	checkpoint := NewDynamoCheckpoint(kclConfig).WithDynamoDB(svc)
	checkpoint.Init()
	// Local status starts with a stale (current-time) lease timeout.
	status := &par.ShardStatus{
		ID:           "0001",
		Checkpoint:   "",
		LeaseTimeout: time.Now(),
		Mux:          &sync.RWMutex{},
	}
	checkpoint.FetchCheckpoint(status)
	// The timeout stored in the mock table must now be reflected locally.
	leaseTimeout, _ := time.Parse(time.RFC3339, *svc.item[LeaseTimeoutKey].S)
	assert.Equal(t, leaseTimeout, status.LeaseTimeout)
}
// TestGetLeaseConditional verifies that when the caller owns the pending claim
// request, GetLease adds a "ClaimRequest = :claim_request" guard to the
// conditional expression sent to DynamoDB.
func TestGetLeaseConditional(t *testing.T) {
	svc := &mockDynamoDB{tableExist: true, item: map[string]*dynamodb.AttributeValue{}}
	kclConfig := cfg.NewKinesisClientLibConfig("appName", "test", "us-west-2", "abc").
		WithInitialPositionInStream(cfg.LATEST).
		WithMaxRecords(10).
		WithMaxLeasesForWorker(1).
		WithShardSyncIntervalMillis(5000).
		WithFailoverTimeMillis(300000).
		WithLeaseStealing(true)
	checkpoint := NewDynamoCheckpoint(kclConfig).WithDynamoDB(svc)
	checkpoint.Init()
	// Seed the mock table with an expired lease that carries a claim request.
	marshalledCheckpoint := map[string]*dynamodb.AttributeValue{
		LeaseKeyKey: {
			S: aws.String("0001"),
		},
		LeaseOwnerKey: {
			S: aws.String("abcd-efgh"),
		},
		LeaseTimeoutKey: {
			S: aws.String(time.Now().Add(-1 * time.Second).UTC().Format(time.RFC3339)),
		},
		SequenceNumberKey: {
			S: aws.String("deadbeef"),
		},
		ClaimRequestKey: {
			S: aws.String("ijkl-mnop"),
		},
	}
	input := &dynamodb.PutItemInput{
		TableName: aws.String("TableName"),
		Item:      marshalledCheckpoint,
	}
	checkpoint.svc.PutItem(input)
	shard := &par.ShardStatus{
		ID:           "0001",
		Checkpoint:   "deadbeef",
		ClaimRequest: "ijkl-mnop",
		Mux:          &sync.RWMutex{},
	}
	err := checkpoint.FetchCheckpoint(shard)
	if err != nil {
		t.Errorf("Could not fetch checkpoint %s", err)
	}
	// The claimer takes the lease; the mock records the condition used.
	err = checkpoint.GetLease(shard, "ijkl-mnop")
	// Fix: failure message previously misspelled "aquired".
	if err != nil {
		t.Errorf("Lease not acquired after timeout %s", err)
	}
	assert.Equal(t, *svc.expressionAttributeValues[":claim_request"].S, "ijkl-mnop")
	assert.Contains(t, svc.conditionalExpression, " AND ClaimRequest = :claim_request")
}
// mockDynamoDB is a test double for the DynamoDB client. It stores a single
// lease item and records the last conditional expression and attribute values
// so tests can assert on them.
// Fix: the field list previously declared tableExist and item twice, which is
// invalid Go (duplicate struct fields).
type mockDynamoDB struct {
	dynamodbiface.DynamoDBAPI
	tableExist                bool
	item                      map[string]*dynamodb.AttributeValue
	conditionalExpression     string
	expressionAttributeValues map[string]*dynamodb.AttributeValue
}
// ScanPages is a no-op in the mock: syncLeases tolerates an empty table scan.
func (m *mockDynamoDB) ScanPages(*dynamodb.ScanInput, func(*dynamodb.ScanOutput, bool) bool) error {
	return nil
}
func (m *mockDynamoDB) DescribeTable(*dynamodb.DescribeTableInput) (*dynamodb.DescribeTableOutput, error) {
@ -192,6 +404,16 @@ func (m *mockDynamoDB) PutItem(input *dynamodb.PutItemInput) (*dynamodb.PutItemO
m.item[ParentShardIdKey] = parent
}
if claimRequest, ok := item[ClaimRequestKey]; ok {
m.item[ClaimRequestKey] = claimRequest
}
if input.ConditionExpression != nil {
m.conditionalExpression = *input.ConditionExpression
}
m.expressionAttributeValues = input.ExpressionAttributeValues
return nil, nil
}
@ -214,3 +436,124 @@ func (m *mockDynamoDB) UpdateItem(input *dynamodb.UpdateItemInput) (*dynamodb.Up
func (m *mockDynamoDB) CreateTable(input *dynamodb.CreateTableInput) (*dynamodb.CreateTableOutput, error) {
return &dynamodb.CreateTableOutput{}, nil
}
// TestListActiveWorkers verifies that shards are grouped by their lease owner
// and that shards at ShardEnd are excluded: 10 assigned shards across 5
// workers yields exactly 2 shards per worker.
func TestListActiveWorkers(t *testing.T) {
	svc := &mockDynamoDB{tableExist: true, item: map[string]*dynamodb.AttributeValue{}}
	kclConfig := cfg.NewKinesisClientLibConfig("appName", "test", "us-west-2", "abc").
		WithLeaseStealing(true)
	checkpoint := NewDynamoCheckpoint(kclConfig).WithDynamoDB(svc)
	err := checkpoint.Init()
	if err != nil {
		t.Errorf("Checkpoint initialization failed: %+v", err)
	}
	// Shard "0010" is at ShardEnd and must be filtered out of the result.
	shardStatus := map[string]*par.ShardStatus{
		"0000": {ID: "0000", AssignedTo: "worker_1", Checkpoint: "", Mux: &sync.RWMutex{}},
		"0001": {ID: "0001", AssignedTo: "worker_2", Checkpoint: "", Mux: &sync.RWMutex{}},
		"0002": {ID: "0002", AssignedTo: "worker_4", Checkpoint: "", Mux: &sync.RWMutex{}},
		"0003": {ID: "0003", AssignedTo: "worker_0", Checkpoint: "", Mux: &sync.RWMutex{}},
		"0004": {ID: "0004", AssignedTo: "worker_1", Checkpoint: "", Mux: &sync.RWMutex{}},
		"0005": {ID: "0005", AssignedTo: "worker_3", Checkpoint: "", Mux: &sync.RWMutex{}},
		"0006": {ID: "0006", AssignedTo: "worker_3", Checkpoint: "", Mux: &sync.RWMutex{}},
		"0007": {ID: "0007", AssignedTo: "worker_0", Checkpoint: "", Mux: &sync.RWMutex{}},
		"0008": {ID: "0008", AssignedTo: "worker_4", Checkpoint: "", Mux: &sync.RWMutex{}},
		"0009": {ID: "0009", AssignedTo: "worker_2", Checkpoint: "", Mux: &sync.RWMutex{}},
		"0010": {ID: "0010", AssignedTo: "worker_0", Checkpoint: ShardEnd, Mux: &sync.RWMutex{}},
	}
	workers, err := checkpoint.ListActiveWorkers(shardStatus)
	if err != nil {
		t.Error(err)
	}
	// Every worker owns two shards, and each shard maps back to its owner.
	for workerID, shards := range workers {
		assert.Equal(t, 2, len(shards))
		for _, shard := range shards {
			assert.Equal(t, workerID, shard.AssignedTo)
		}
	}
}
// TestListActiveWorkersErrShardNotAssigned verifies that a shard with no
// AssignedTo value causes ListActiveWorkers to fail with ErrShardNotAssigned.
func TestListActiveWorkersErrShardNotAssigned(t *testing.T) {
	svc := &mockDynamoDB{tableExist: true, item: map[string]*dynamodb.AttributeValue{}}
	kclConfig := cfg.NewKinesisClientLibConfig("appName", "test", "us-west-2", "abc").
		WithLeaseStealing(true)
	checkpoint := NewDynamoCheckpoint(kclConfig).WithDynamoDB(svc)
	err := checkpoint.Init()
	if err != nil {
		t.Errorf("Checkpoint initialization failed: %+v", err)
	}
	// One shard with an empty owner is enough to trigger the error.
	shardStatus := map[string]*par.ShardStatus{
		"0000": {ID: "0000", Mux: &sync.RWMutex{}},
	}
	_, err = checkpoint.ListActiveWorkers(shardStatus)
	if err != ErrShardNotAssigned {
		t.Error("Expected ErrShardNotAssigned when shard is missing AssignedTo value")
	}
}
// TestClaimShard verifies that ClaimShard writes the claim request to the
// lease item without disturbing the owner, checkpoint, or parent shard id.
func TestClaimShard(t *testing.T) {
	svc := &mockDynamoDB{tableExist: true, item: map[string]*dynamodb.AttributeValue{}}
	kclConfig := cfg.NewKinesisClientLibConfig("appName", "test", "us-west-2", "abc").
		WithInitialPositionInStream(cfg.LATEST).
		WithMaxRecords(10).
		WithMaxLeasesForWorker(1).
		WithShardSyncIntervalMillis(5000).
		WithFailoverTimeMillis(300000).
		WithLeaseStealing(true)
	checkpoint := NewDynamoCheckpoint(kclConfig).WithDynamoDB(svc)
	checkpoint.Init()
	// Seed the mock table with an expired lease owned by abcd-efgh.
	marshalledCheckpoint := map[string]*dynamodb.AttributeValue{
		"ShardID": {
			S: aws.String("0001"),
		},
		"AssignedTo": {
			S: aws.String("abcd-efgh"),
		},
		"LeaseTimeout": {
			S: aws.String(time.Now().AddDate(0, -1, 0).UTC().Format(time.RFC3339)),
		},
		"Checkpoint": {
			S: aws.String("deadbeef"),
		},
	}
	input := &dynamodb.PutItemInput{
		TableName: aws.String("TableName"),
		Item:      marshalledCheckpoint,
	}
	checkpoint.svc.PutItem(input)
	shard := &par.ShardStatus{
		ID:         "0001",
		Checkpoint: "deadbeef",
		Mux:        &sync.RWMutex{},
	}
	err := checkpoint.ClaimShard(shard, "ijkl-mnop")
	if err != nil {
		t.Errorf("Shard not claimed %s", err)
	}
	// The claim request must now be stored on the item.
	claimRequest, ok := svc.item[ClaimRequestKey]
	if !ok {
		t.Error("Expected claimRequest to be set by ClaimShard")
	} else if *claimRequest.S != "ijkl-mnop" {
		t.Errorf("Expected checkpoint to be ijkl-mnop. Got '%s'", *claimRequest.S)
	}
	status := &par.ShardStatus{
		ID:  shard.ID,
		Mux: &sync.RWMutex{},
	}
	checkpoint.FetchCheckpoint(status)
	// AssignedTo, Checkpoint, and ParentShardId should be unchanged by the claim.
	assert.Equal(t, shard.AssignedTo, status.AssignedTo)
	assert.Equal(t, shard.Checkpoint, status.Checkpoint)
	assert.Equal(t, shard.ParentShardId, status.ParentShardId)
}

View file

@ -122,6 +122,18 @@ const (
// The amount of milliseconds to wait before graceful shutdown forcefully terminates.
DefaultShutdownGraceMillis = 5000
// Lease stealing defaults to false for backwards compatibility.
DefaultEnableLeaseStealing = false
// Interval between rebalance tasks defaults to 5 seconds.
DefaultLeaseStealingIntervalMillis = 5000
// Number of milliseconds to wait before another worker can acquire a claimed shard
DefaultLeaseStealingClaimTimeoutMillis = 120000
// Number of milliseconds to wait before syncing with the lease table (DynamoDB)
DefaultLeaseSyncingIntervalMillis = 60000
)
type (
@ -257,6 +269,18 @@ type (
// MonitoringService publishes per worker-scoped metrics.
MonitoringService metrics.MonitoringService
// EnableLeaseStealing turns on lease stealing
EnableLeaseStealing bool
// LeaseStealingIntervalMillis The number of milliseconds between rebalance tasks
LeaseStealingIntervalMillis int
// LeaseStealingClaimTimeoutMillis The number of milliseconds to wait before another worker can acquire a claimed shard
LeaseStealingClaimTimeoutMillis int
// LeaseSyncingTimeInterval The number of milliseconds to wait before syncing with lease table (dynamoDB)
LeaseSyncingTimeIntervalMillis int
}
)

View file

@ -39,9 +39,35 @@ func TestConfig(t *testing.T) {
assert.Equal(t, "appName", kclConfig.ApplicationName)
assert.Equal(t, 500, kclConfig.FailoverTimeMillis)
assert.Equal(t, 10, kclConfig.TaskBackoffTimeMillis)
assert.True(t, kclConfig.EnableEnhancedFanOutConsumer)
assert.Equal(t, "fan-out-consumer", kclConfig.EnhancedFanOutConsumerName)
assert.Equal(t, false, kclConfig.EnableLeaseStealing)
assert.Equal(t, 5000, kclConfig.LeaseStealingIntervalMillis)
contextLogger := kclConfig.Logger.WithFields(logger.Fields{"key1": "value1"})
contextLogger.Debugf("Starting with default logger")
contextLogger.Infof("Default logger is awesome")
}
func TestConfigLeaseStealing(t *testing.T) {
kclConfig := NewKinesisClientLibConfig("appName", "StreamName", "us-west-2", "workerId").
WithFailoverTimeMillis(500).
WithMaxRecords(100).
WithInitialPositionInStream(TRIM_HORIZON).
WithIdleTimeBetweenReadsInMillis(20).
WithCallProcessRecordsEvenForEmptyRecordList(true).
WithTaskBackoffTimeMillis(10).
WithLeaseStealing(true).
WithLeaseStealingIntervalMillis(10000)
assert.Equal(t, "appName", kclConfig.ApplicationName)
assert.Equal(t, 500, kclConfig.FailoverTimeMillis)
assert.Equal(t, 10, kclConfig.TaskBackoffTimeMillis)
assert.Equal(t, true, kclConfig.EnableLeaseStealing)
assert.Equal(t, 10000, kclConfig.LeaseStealingIntervalMillis)
contextLogger := kclConfig.Logger.WithFields(logger.Fields{"key1": "value1"})
contextLogger.Debugf("Starting with default logger")
contextLogger.Infof("Default logger is awesome")

View file

@ -95,7 +95,11 @@ func NewKinesisClientLibConfigWithCredentials(applicationName, streamName, regio
InitialLeaseTableReadCapacity: DefaultInitialLeaseTableReadCapacity,
InitialLeaseTableWriteCapacity: DefaultInitialLeaseTableWriteCapacity,
SkipShardSyncAtWorkerInitializationIfLeasesExist: DefaultSkipShardSyncAtStartupIfLeasesExist,
Logger: logger.GetDefaultLogger(),
EnableLeaseStealing: DefaultEnableLeaseStealing,
LeaseStealingIntervalMillis: DefaultLeaseStealingIntervalMillis,
LeaseStealingClaimTimeoutMillis: DefaultLeaseStealingClaimTimeoutMillis,
LeaseSyncingTimeIntervalMillis: DefaultLeaseSyncingIntervalMillis,
Logger: logger.GetDefaultLogger(),
}
}
@ -241,3 +245,18 @@ func (c *KinesisClientLibConfiguration) WithEnhancedFanOutConsumerARN(consumerAR
c.EnableEnhancedFanOutConsumer = true
return c
}
// WithLeaseStealing enables or disables the lease stealing feature and
// returns the configuration for chaining.
func (c *KinesisClientLibConfiguration) WithLeaseStealing(enableLeaseStealing bool) *KinesisClientLibConfiguration {
	c.EnableLeaseStealing = enableLeaseStealing
	return c
}
// WithLeaseStealingIntervalMillis sets the number of milliseconds between
// rebalance tasks and returns the configuration for chaining.
func (c *KinesisClientLibConfiguration) WithLeaseStealingIntervalMillis(leaseStealingIntervalMillis int) *KinesisClientLibConfiguration {
	c.LeaseStealingIntervalMillis = leaseStealingIntervalMillis
	return c
}
// WithLeaseSyncingIntervalMillis sets the number of milliseconds to wait
// between syncs with the lease table and returns the configuration for
// chaining.
func (c *KinesisClientLibConfiguration) WithLeaseSyncingIntervalMillis(leaseSyncingIntervalMillis int) *KinesisClientLibConfiguration {
	c.LeaseSyncingTimeIntervalMillis = leaseSyncingIntervalMillis
	return c
}

View file

@ -30,6 +30,8 @@ package worker
import (
"sync"
"time"
"github.com/vmware/vmware-go-kcl/clientlibrary/config"
)
type ShardStatus struct {
@ -43,6 +45,7 @@ type ShardStatus struct {
StartingSequenceNumber string
// child shard doesn't have end sequence number
EndingSequenceNumber string
ClaimRequest string
}
func (ss *ShardStatus) GetLeaseOwner() string {
@ -68,3 +71,24 @@ func (ss *ShardStatus) SetCheckpoint(c string) {
defer ss.Mux.Unlock()
ss.Checkpoint = c
}
// GetLeaseTimeout returns the shard's lease timeout under the status lock.
func (ss *ShardStatus) GetLeaseTimeout() time.Time {
	ss.Mux.Lock()
	defer ss.Mux.Unlock()
	return ss.LeaseTimeout
}
// SetLeaseTimeout updates the shard's lease timeout under the status lock.
func (ss *ShardStatus) SetLeaseTimeout(timeout time.Time) {
	ss.Mux.Lock()
	defer ss.Mux.Unlock()
	ss.LeaseTimeout = timeout
}
// IsClaimRequestExpired reports whether any pending claim on this shard has
// outlived the configured claim timeout, judged from the lease timeout
// timestamp. A zero lease timeout means there is nothing to expire.
func (ss *ShardStatus) IsClaimRequestExpired(kclConfig *config.KinesisClientLibConfiguration) bool {
	leaseTimeout := ss.GetLeaseTimeout()
	if leaseTimeout.IsZero() {
		return false
	}
	cutoff := time.Now().UTC().Add(time.Duration(-kclConfig.LeaseStealingClaimTimeoutMillis) * time.Millisecond)
	return leaseTimeout.Before(cutoff)
}

View file

@ -103,7 +103,7 @@ func (sc *PollingShardConsumer) getRecords() error {
retriedErrors := 0
for {
if time.Now().UTC().After(sc.shard.LeaseTimeout.Add(-time.Duration(sc.kclConfig.LeaseRefreshPeriodMillis) * time.Millisecond)) {
if time.Now().UTC().After(sc.shard.GetLeaseTimeout().Add(-time.Duration(sc.kclConfig.LeaseRefreshPeriodMillis) * time.Millisecond)) {
log.Debugf("Refreshing lease on shard: %s for worker: %s", sc.shard.ID, sc.consumerID)
err = sc.checkpointer.GetLease(sc.shard, sc.consumerID)
if err != nil {

View file

@ -68,7 +68,8 @@ type Worker struct {
rng *rand.Rand
shardStatus map[string]*par.ShardStatus
shardStatus map[string]*par.ShardStatus
shardStealInProgress bool
}
// NewWorker constructs a Worker instance for processing Kinesis stream data.
@ -271,7 +272,7 @@ func (w *Worker) eventLoop() {
log.Infof("Found %d shards", foundShards)
}
// Count the number of leases hold by this worker excluding the processed shard
// Count the number of leases held by this worker excluding the processed shard
counter := 0
for _, shard := range w.shardStatus {
if shard.GetLeaseOwner() == w.workerID && shard.GetCheckpoint() != chk.ShardEnd {
@ -302,6 +303,20 @@ func (w *Worker) eventLoop() {
continue
}
var stealShard bool
if w.kclConfig.EnableLeaseStealing && shard.ClaimRequest != "" {
upcomingStealingInterval := time.Now().UTC().Add(time.Duration(w.kclConfig.LeaseStealingIntervalMillis) * time.Millisecond)
if shard.GetLeaseTimeout().Before(upcomingStealingInterval) && !shard.IsClaimRequestExpired(w.kclConfig) {
if shard.ClaimRequest == w.workerID {
stealShard = true
log.Debugf("Stealing shard: %s", shard.ID)
} else {
log.Debugf("Shard being stolen: %s", shard.ID)
continue
}
}
}
err = w.checkpointer.GetLease(shard, w.workerID)
if err != nil {
// cannot get lease on the shard
@ -311,6 +326,11 @@ func (w *Worker) eventLoop() {
continue
}
if stealShard {
log.Debugf("Successfully stole shard: %+v", shard.ID)
w.shardStealInProgress = false
}
// log metrics on got lease
w.mService.LeaseGained(shard.ID)
w.waitGroup.Add(1)
@ -325,6 +345,13 @@ func (w *Worker) eventLoop() {
}
}
if w.kclConfig.EnableLeaseStealing {
err = w.rebalance()
if err != nil {
log.Warnf("Error in rebalance: %+v", err)
}
}
select {
case <-*w.stop:
log.Infof("Shutting down...")
@ -335,6 +362,90 @@ func (w *Worker) eventLoop() {
}
}
// rebalance inspects the shard distribution across active workers and, when
// this worker holds fewer than its fair share, places a claim request on one
// shard belonging to the most-loaded worker. Returns any error from listing
// workers, refreshing shard IDs, or claiming.
func (w *Worker) rebalance() error {
	log := w.kclConfig.Logger
	workers, err := w.checkpointer.ListActiveWorkers(w.shardStatus)
	if err != nil {
		log.Debugf("Error listing workers. workerID: %s. Error: %+v ", w.workerID, err)
		return err
	}
	// Only attempt to steal one shard at a time, to allow for linear convergence
	if w.shardStealInProgress {
		shardInfo := make(map[string]bool)
		err := w.getShardIDs("", shardInfo)
		if err != nil {
			return err
		}
		for _, shard := range w.shardStatus {
			if shard.ClaimRequest != "" && shard.ClaimRequest == w.workerID {
				log.Debugf("Steal in progress. workerID: %s", w.workerID)
				return nil
			}
		}
		// No shard carries our claim request anymore: our shard steal was
		// stomped on by a Checkpoint. We could deal with that, but instead
		// just try again.
		// Fix: this reset previously ran inside the loop, which could clear
		// the flag even though a later shard (map order is random) still held
		// our claim.
		w.shardStealInProgress = false
	}
	var numShards int
	for _, shards := range workers {
		numShards += len(shards)
	}
	numWorkers := len(workers)
	// 1:1 shards to workers is optimal, so we cannot possibly rebalance
	if numWorkers >= numShards {
		log.Debugf("Optimal shard allocation, not stealing any shards. workerID: %s, %v > %v. ", w.workerID, numWorkers, numShards)
		return nil
	}
	currentShards, ok := workers[w.workerID]
	var numCurrentShards int
	if !ok {
		// We own no shards yet; count ourselves as an additional worker.
		numCurrentShards = 0
		numWorkers++
	} else {
		numCurrentShards = len(currentShards)
	}
	optimalShards := numShards / numWorkers
	// We have more than or equal optimal shards, so no rebalancing can take place
	if numCurrentShards >= optimalShards || numCurrentShards == w.kclConfig.MaxLeasesForWorker {
		log.Debugf("We have enough shards, not attempting to steal any. workerID: %s", w.workerID)
		return nil
	}
	// Find the other worker carrying the most shards above the optimal count.
	maxShards := optimalShards
	var workerSteal string
	for worker, shards := range workers {
		if worker != w.workerID && len(shards) > maxShards {
			workerSteal = worker
			maxShards = len(shards)
		}
	}
	// Not all shards are allocated so fallback to default shard allocation mechanisms
	if workerSteal == "" {
		log.Infof("Not all shards are allocated, not stealing any. workerID: %s", w.workerID)
		return nil
	}
	// Steal a random shard from the worker with the most shards
	w.shardStealInProgress = true
	randIndex := rand.Intn(len(workers[workerSteal]))
	shardToSteal := workers[workerSteal][randIndex]
	// Log the shard ID rather than dumping the whole ShardStatus struct.
	log.Debugf("Stealing shard %s from %s", shardToSteal.ID, workerSteal)
	err = w.checkpointer.ClaimShard(w.shardStatus[shardToSteal.ID], w.workerID)
	if err != nil {
		// Claim failed; allow the next rebalance pass to try again.
		w.shardStealInProgress = false
		return err
	}
	return nil
}
// List all shards and store them into shardStatus table
// If shard has been removed, need to exclude it from cached shard status.
func (w *Worker) getShardIDs(nextToken string, shardInfo map[string]bool) error {

View file

@ -1,4 +1,4 @@
FROM golang:1.12
FROM golang:1.13
ENV PATH /go/bin:/src/bin:/root/go/bin:/usr/local/go/bin:$PATH
ENV GOPATH /go:/src
RUN go get -v github.com/alecthomas/gometalinter && \

View file

@ -0,0 +1,230 @@
package test
import (
"fmt"
"sync"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/dynamodb"
"github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface"
"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
"github.com/stretchr/testify/assert"
chk "github.com/vmware/vmware-go-kcl/clientlibrary/checkpoint"
cfg "github.com/vmware/vmware-go-kcl/clientlibrary/config"
wk "github.com/vmware/vmware-go-kcl/clientlibrary/worker"
)
// LeaseStealingTest drives an end-to-end lease-stealing scenario against a
// TestCluster, polling the DynamoDB lease table directly to observe how
// shard leases are distributed across workers.
type LeaseStealingTest struct {
	t      *testing.T
	config *TestClusterConfig
	// cluster spawns and tracks the KCL workers under test.
	cluster *TestCluster
	// kc publishes records onto the stream; dc scans the lease table.
	kc dynamodbiface.DynamoDBAPI // NOTE(review): field types below are as declared in SOURCE
	dc dynamodbiface.DynamoDBAPI
	// backOffSeconds is the pause between lease-table polls; maxRetries bounds
	// how many polls happen before an assertion is made.
	backOffSeconds int
	maxRetries     int
}
// NewLeaseStealingTest assembles a test cluster plus raw Kinesis and DynamoDB
// clients so the test can publish records and inspect the lease table
// directly. Polling defaults: 5s back-off, 60 retries.
func NewLeaseStealingTest(t *testing.T, config *TestClusterConfig, workerFactory TestWorkerFactory) *LeaseStealingTest {
	cluster := NewTestCluster(t, config, workerFactory)
	clientConfig := cluster.workerFactory.CreateKCLConfig("test-client", config)

	lst := &LeaseStealingTest{
		t:              t,
		config:         config,
		cluster:        cluster,
		backOffSeconds: 5,
		maxRetries:     60,
	}
	lst.kc = NewKinesisClient(t, config.regionName, clientConfig.KinesisEndpoint, clientConfig.KinesisCredentials)
	// NOTE(review): the DynamoDB client is built with the Kinesis credentials,
	// matching the original code — confirm this is intentional.
	lst.dc = NewDynamoDBClient(t, config.regionName, clientConfig.DynamoDBEndpoint, clientConfig.KinesisCredentials)
	return lst
}
// WithBackoffSeconds sets the pause, in seconds, between lease-table polls
// and returns the receiver for fluent configuration.
func (lst *LeaseStealingTest) WithBackoffSeconds(seconds int) *LeaseStealingTest {
	lst.backOffSeconds = seconds
	return lst
}
// WithMaxRetries bounds how many polling attempts are made before asserting,
// returning the receiver for fluent configuration.
func (lst *LeaseStealingTest) WithMaxRetries(attempts int) *LeaseStealingTest {
	lst.maxRetries = attempts
	return lst
}
// publishSomeData launches a background goroutine that publishes a batch of
// records to the stream every 500ms. It returns a stop function that halts
// publishing and blocks until the goroutine has exited.
func (lst *LeaseStealingTest) publishSomeData() (stop func()) {
	done := make(chan int)
	wg := &sync.WaitGroup{}
	wg.Add(1)
	go func() {
		ticker := time.NewTicker(500 * time.Millisecond)
		defer wg.Done()
		defer ticker.Stop()
		for {
			select {
			case <-done:
				return
			case <-ticker.C:
				// Fixed typo in the log message ("Coninuously").
				lst.t.Log("Continuously publishing records")
				publishSomeData(lst.t, lst.kc)
			}
		}
	}()
	return func() {
		close(done)
		wg.Wait()
	}
}
// getShardCountByWorker scans the DynamoDB lease table and reports how many
// distinct shards each worker currently owns.
func (lst *LeaseStealingTest) getShardCountByWorker() map[string]int {
	input := &dynamodb.ScanInput{
		TableName: aws.String(lst.config.appName),
	}

	// worker -> set of shard IDs, so a shard is never double-counted.
	ownership := map[string]map[string]bool{}
	err := lst.dc.ScanPages(input, func(out *dynamodb.ScanOutput, lastPage bool) bool {
		for _, item := range out.Items {
			shardID, hasShard := item[chk.LeaseKeyKey]
			if !hasShard {
				continue
			}
			assignedTo, hasOwner := item[chk.LeaseOwnerKey]
			if !hasOwner {
				continue
			}
			owner := *assignedTo.S
			if ownership[owner] == nil {
				ownership[owner] = map[string]bool{}
			}
			ownership[owner][*shardID.S] = true
		}
		return !lastPage
	})
	assert.Nil(lst.t, err)

	counts := map[string]int{}
	for worker, shards := range ownership {
		counts[worker] = len(shards)
	}
	return counts
}
// LeaseStealingAssertions captures the expected lease distribution at the two
// checkpoints of a lease-stealing run: before and after rebalancing.
type LeaseStealingAssertions struct {
	// expectedLeasesForIntialWorker is the lease count the first (solo) worker
	// should reach before additional workers join.
	// NOTE(review): "Intial" is a typo for "Initial"; renaming the field would
	// break existing callers, so the spelling is preserved.
	expectedLeasesForIntialWorker int
	// expectedLeasesPerWorker is the per-worker lease count expected once the
	// cluster has rebalanced.
	expectedLeasesPerWorker int
}
// Run executes the lease-stealing scenario: start one worker and wait for it
// to hold the expected number of leases, spawn the remaining workers, wait
// for leases to rebalance, assert the final distribution, then shut down.
func (lst *LeaseStealingTest) Run(assertions LeaseStealingAssertions) {
	// Publish records onto the stream throughout the entire duration of the test.
	stop := lst.publishSomeData()
	defer stop()

	// Start worker 1.
	worker1, _ := lst.cluster.SpawnWorker()

	// Poll until the first worker has acquired all expected leases.
	var worker1ShardCount int
	for attempt := 0; attempt < lst.maxRetries; attempt++ {
		time.Sleep(time.Duration(lst.backOffSeconds) * time.Second)
		counts := lst.getShardCountByWorker()
		if count, ok := counts[worker1]; ok && count == assertions.expectedLeasesForIntialWorker {
			worker1ShardCount = count
			break
		}
	}
	// Assert correct number of leases for the initial worker.
	assert.Equal(lst.t, assertions.expectedLeasesForIntialWorker, worker1ShardCount)

	// Spawn the remaining workers.
	for i := 1; i < lst.config.numWorkers; i++ {
		lst.cluster.SpawnWorker()
	}

	// Poll until every worker holds the expected post-rebalance lease count.
	var shardCountByWorker map[string]int
	for attempt := 0; attempt < lst.maxRetries; attempt++ {
		time.Sleep(time.Duration(lst.backOffSeconds) * time.Second)
		shardCountByWorker = lst.getShardCountByWorker()

		balanced := true
		for _, count := range shardCountByWorker {
			if count != assertions.expectedLeasesPerWorker {
				balanced = false
			}
		}
		if balanced {
			break
		}
	}

	// Assert the cluster rebalanced.
	assert.Greater(lst.t, len(shardCountByWorker), 0)
	for _, count := range shardCountByWorker {
		assert.Equal(lst.t, assertions.expectedLeasesPerWorker, count)
	}

	// Allow in-flight processing to settle, then shut down all workers.
	time.Sleep(10 * time.Second)
	lst.cluster.Shutdown()
}
// TestWorkerFactory abstracts construction of workers and their KCL
// configuration so lease-stealing tests can run against different
// worker/checkpointer wirings.
type TestWorkerFactory interface {
	// CreateWorker builds a worker for the given ID from the supplied config.
	CreateWorker(workerID string, kclConfig *cfg.KinesisClientLibConfiguration) *wk.Worker
	// CreateKCLConfig derives a per-worker KCL configuration from the cluster config.
	CreateKCLConfig(workerID string, config *TestClusterConfig) *cfg.KinesisClientLibConfiguration
}
// TestClusterConfig describes the stream and worker topology for a
// lease-stealing test cluster.
type TestClusterConfig struct {
	// numShards is the number of shards in the Kinesis stream under test.
	numShards int
	// numWorkers is how many workers the cluster spawns in total.
	numWorkers int
	// appName is the KCL application name; it is also used as the DynamoDB
	// lease-table name when scanning lease ownership.
	appName string
	// streamName is the Kinesis stream the workers consume.
	streamName string
	// regionName is the AWS region used by all clients.
	regionName string
	// workerIDTemplate is an fmt template (e.g. "worker-%v") used to mint
	// worker IDs from a spawn index.
	workerIDTemplate string
}
// TestCluster owns a set of KCL workers spawned from a shared configuration
// and worker factory.
type TestCluster struct {
	t             *testing.T
	config        *TestClusterConfig
	workerFactory TestWorkerFactory
	// workerIDs records spawn order; workers maps worker ID -> running worker.
	workerIDs []string
	workers   map[string]*wk.Worker
}
// NewTestCluster creates an empty cluster; workers are added later via
// SpawnWorker.
func NewTestCluster(t *testing.T, config *TestClusterConfig, workerFactory TestWorkerFactory) *TestCluster {
	cluster := &TestCluster{
		t:             t,
		config:        config,
		workerFactory: workerFactory,
	}
	cluster.workerIDs = []string{}
	cluster.workers = map[string]*wk.Worker{}
	return cluster
}
// addWorker builds a worker via the factory, registers it under workerID, and
// returns it.
func (tc *TestCluster) addWorker(workerID string, config *cfg.KinesisClientLibConfiguration) *wk.Worker {
	w := tc.workerFactory.CreateWorker(workerID, config)
	tc.workerIDs = append(tc.workerIDs, workerID)
	tc.workers[workerID] = w
	return w
}
// SpawnWorker builds, registers, and starts a new worker, returning its ID
// and the worker itself. A startup failure fails the test.
func (tc *TestCluster) SpawnWorker() (string, *wk.Worker) {
	// Worker IDs are minted from the template and the current cluster size.
	workerID := fmt.Sprintf(tc.config.workerIDTemplate, len(tc.workers))
	config := tc.workerFactory.CreateKCLConfig(workerID, tc.config)
	worker := tc.addWorker(workerID, config)
	assert.Nil(tc.t, worker.Start())
	return workerID, worker
}
// Shutdown gracefully stops every worker in the cluster.
func (tc *TestCluster) Shutdown() {
	for id, w := range tc.workers {
		tc.t.Logf("Shutting down worker: %v", id)
		w.Shutdown()
	}
}

View file

@ -23,9 +23,10 @@ package test
import (
"github.com/stretchr/testify/assert"
"testing"
"github.com/sirupsen/logrus"
"go.uber.org/zap"
"testing"
"github.com/vmware/vmware-go-kcl/logger"
zaplogger "github.com/vmware/vmware-go-kcl/logger/zap"

View file

@ -19,10 +19,11 @@
package test
import (
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/stretchr/testify/assert"
kc "github.com/vmware/vmware-go-kcl/clientlibrary/interfaces"
"testing"
)
// Record processor factory is used to create RecordProcessor

View file

@ -21,9 +21,13 @@ package test
import (
"crypto/md5"
"fmt"
"sync"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/dynamodb"
"github.com/aws/aws-sdk-go/service/kinesis"
"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
rec "github.com/awslabs/kinesis-aggregation/go/records"
@ -50,12 +54,79 @@ func NewKinesisClient(t *testing.T, regionName, endpoint string, credentials *cr
return kinesis.New(s)
}
// NewDynamoDBClient creates a DynamoDB client for the given region, endpoint,
// and credentials, failing the test immediately if a session cannot be built.
// (Previous comment incorrectly said "Kinesis Client".)
func NewDynamoDBClient(t *testing.T, regionName, endpoint string, credentials *credentials.Credentials) *dynamodb.DynamoDB {
	s, err := session.NewSession(&aws.Config{
		Region:      aws.String(regionName),
		Endpoint:    aws.String(endpoint),
		Credentials: credentials,
	})
	if err != nil {
		// no need to move forward
		t.Fatalf("Failed in getting DynamoDB session for creating Worker: %+v", err)
	}
	return dynamodb.New(s)
}
// continuouslyPublishSomeData starts a background goroutine that, every
// 500ms, publishes records to every shard (via explicit hash keys) plus a
// batch of randomly partitioned records. It returns a stop function that
// halts publishing and waits for the goroutine to exit.
func continuouslyPublishSomeData(t *testing.T, kc kinesisiface.KinesisAPI) func() {
	// Enumerate all shards once up front so each tick can target every shard.
	shards := []*kinesis.Shard{}
	var nextToken *string
	for {
		out, err := kc.ListShards(&kinesis.ListShardsInput{
			StreamName: aws.String(streamName),
			NextToken:  nextToken,
		})
		if err != nil {
			// Abort instead of continuing with a nil response, which would
			// panic on out.Shards below (previously t.Errorf fell through).
			t.Fatalf("Error in ListShards. %+v", err)
		}
		shards = append(shards, out.Shards...)
		if out.NextToken == nil {
			break
		}
		nextToken = out.NextToken
	}

	done := make(chan int)
	wg := &sync.WaitGroup{}
	wg.Add(1)
	go func() {
		defer wg.Done()
		ticker := time.NewTicker(500 * time.Millisecond)
		// Release the ticker when publishing stops (was never stopped before).
		defer ticker.Stop()
		for {
			select {
			case <-done:
				return
			case <-ticker.C:
				publishToAllShards(t, kc, shards)
				publishSomeData(t, kc)
			}
		}
	}()
	return func() {
		close(done)
		wg.Wait()
	}
}
// publishToAllShards puts ten records onto every shard by targeting each
// shard's starting hash key explicitly, guaranteeing all shards see traffic.
func publishToAllShards(t *testing.T, kc kinesisiface.KinesisAPI, shards []*kinesis.Shard) {
	for round := 0; round < 10; round++ {
		for _, shard := range shards {
			publishRecord(t, kc, shard.HashKeyRange.StartingHashKey)
		}
	}
}
// publishSomeData to put some records into Kinesis stream
func publishSomeData(t *testing.T, kc kinesisiface.KinesisAPI) {
// Put some data into stream.
t.Log("Putting data into stream using PutRecord API...")
for i := 0; i < 50; i++ {
publishRecord(t, kc)
publishRecord(t, kc, nil)
}
t.Log("Done putting data into stream using PutRecord API.")
@ -75,13 +146,17 @@ func publishSomeData(t *testing.T, kc kinesisiface.KinesisAPI) {
}
// publishRecord to put a record into Kinesis stream using PutRecord API.
func publishRecord(t *testing.T, kc kinesisiface.KinesisAPI) {
// Use random string as partition key to ensure even distribution across shards
_, err := kc.PutRecord(&kinesis.PutRecordInput{
func publishRecord(t *testing.T, kc kinesisiface.KinesisAPI, hashKey *string) {
input := &kinesis.PutRecordInput{
Data: []byte(specstr),
StreamName: aws.String(streamName),
PartitionKey: aws.String(utils.RandStringBytesMaskImpr(10)),
})
}
if hashKey != nil {
input.ExplicitHashKey = hashKey
}
// Use random string as partition key to ensure even distribution across shards
_, err := kc.PutRecord(input)
if err != nil {
t.Errorf("Error in PutRecord. %+v", err)
@ -94,10 +169,11 @@ func publishRecords(t *testing.T, kc kinesisiface.KinesisAPI) {
records := make([]*kinesis.PutRecordsRequestEntry, 5)
for i := 0; i < 5; i++ {
records[i] = &kinesis.PutRecordsRequestEntry{
record := &kinesis.PutRecordsRequestEntry{
Data: []byte(specstr),
PartitionKey: aws.String(utils.RandStringBytesMaskImpr(10)),
}
records[i] = record
}
_, err := kc.PutRecords(&kinesis.PutRecordsInput{

View file

@ -37,7 +37,7 @@ import (
)
func TestWorkerInjectCheckpointer(t *testing.T) {
kclConfig := cfg.NewKinesisClientLibConfig("appName", streamName, regionName, workerID).
kclConfig := cfg.NewKinesisClientLibConfig(appName, streamName, regionName, workerID).
WithInitialPositionInStream(cfg.LATEST).
WithMaxRecords(10).
WithMaxLeasesForWorker(1).
@ -52,6 +52,12 @@ func TestWorkerInjectCheckpointer(t *testing.T) {
// configure cloudwatch as metrics system
kclConfig.WithMonitoringService(getMetricsConfig(kclConfig, metricsSystem))
// Put some data into stream.
kc := NewKinesisClient(t, regionName, kclConfig.KinesisEndpoint, kclConfig.KinesisCredentials)
// publishSomeData(t, kc)
stop := continuouslyPublishSomeData(t, kc)
defer stop()
// custom checkpointer or a mock checkpointer.
checkpointer := chk.NewDynamoCheckpoint(kclConfig)
@ -62,12 +68,8 @@ func TestWorkerInjectCheckpointer(t *testing.T) {
err := worker.Start()
assert.Nil(t, err)
// Put some data into stream.
kc := NewKinesisClient(t, regionName, kclConfig.KinesisEndpoint, kclConfig.KinesisCredentials)
publishSomeData(t, kc)
// wait a few seconds before shutdown processing
time.Sleep(10 * time.Second)
time.Sleep(30 * time.Second)
worker.Shutdown()
// verify the checkpointer after graceful shutdown
@ -86,7 +88,7 @@ func TestWorkerInjectCheckpointer(t *testing.T) {
}
func TestWorkerInjectKinesis(t *testing.T) {
kclConfig := cfg.NewKinesisClientLibConfig("appName", streamName, regionName, workerID).
kclConfig := cfg.NewKinesisClientLibConfig(appName, streamName, regionName, workerID).
WithInitialPositionInStream(cfg.LATEST).
WithMaxRecords(10).
WithMaxLeasesForWorker(1).
@ -109,6 +111,11 @@ func TestWorkerInjectKinesis(t *testing.T) {
assert.Nil(t, err)
kc := kinesis.New(s)
// Put some data into stream.
// publishSomeData(t, kc)
stop := continuouslyPublishSomeData(t, kc)
defer stop()
// Inject a custom checkpointer into the worker.
worker := wk.NewWorker(recordProcessorFactory(t), kclConfig).
WithKinesis(kc)
@ -116,16 +123,13 @@ func TestWorkerInjectKinesis(t *testing.T) {
err = worker.Start()
assert.Nil(t, err)
// Put some data into stream.
publishSomeData(t, kc)
// wait a few seconds before shutdown processing
time.Sleep(10 * time.Second)
time.Sleep(30 * time.Second)
worker.Shutdown()
}
func TestWorkerInjectKinesisAndCheckpointer(t *testing.T) {
kclConfig := cfg.NewKinesisClientLibConfig("appName", streamName, regionName, workerID).
kclConfig := cfg.NewKinesisClientLibConfig(appName, streamName, regionName, workerID).
WithInitialPositionInStream(cfg.LATEST).
WithMaxRecords(10).
WithMaxLeasesForWorker(1).
@ -148,6 +152,11 @@ func TestWorkerInjectKinesisAndCheckpointer(t *testing.T) {
assert.Nil(t, err)
kc := kinesis.New(s)
// Put some data into stream.
// publishSomeData(t, kc)
stop := continuouslyPublishSomeData(t, kc)
defer stop()
// custom checkpointer or a mock checkpointer.
checkpointer := chk.NewDynamoCheckpoint(kclConfig)
@ -159,10 +168,7 @@ func TestWorkerInjectKinesisAndCheckpointer(t *testing.T) {
err = worker.Start()
assert.Nil(t, err)
// Put some data into stream.
publishSomeData(t, kc)
// wait a few seconds before shutdown processing
time.Sleep(10 * time.Second)
time.Sleep(30 * time.Second)
worker.Shutdown()
}

View file

@ -0,0 +1,127 @@
package test
import (
"testing"
chk "github.com/vmware/vmware-go-kcl/clientlibrary/checkpoint"
cfg "github.com/vmware/vmware-go-kcl/clientlibrary/config"
wk "github.com/vmware/vmware-go-kcl/clientlibrary/worker"
"github.com/vmware/vmware-go-kcl/logger"
)
// TestLeaseStealing runs the stock scenario: one worker acquires all shards,
// a second joins, and leases converge to an even split.
func TestLeaseStealing(t *testing.T) {
	config := &TestClusterConfig{
		numShards:        4,
		numWorkers:       2,
		appName:          appName,
		streamName:       streamName,
		regionName:       regionName,
		workerIDTemplate: workerID + "-%v",
	}
	expectations := LeaseStealingAssertions{
		expectedLeasesForIntialWorker: config.numShards,
		expectedLeasesPerWorker:       config.numShards / config.numWorkers,
	}
	NewLeaseStealingTest(t, config, newLeaseStealingWorkerFactory(t)).Run(expectations)
}
// leaseStealingWorkerFactory is the default TestWorkerFactory for the
// lease-stealing tests: workers use the library's built-in checkpointer.
type leaseStealingWorkerFactory struct {
	t *testing.T
}
// newLeaseStealingWorkerFactory returns the default factory bound to t.
func newLeaseStealingWorkerFactory(t *testing.T) *leaseStealingWorkerFactory {
	return &leaseStealingWorkerFactory{t: t}
}
// CreateKCLConfig builds a lease-stealing-enabled KCL configuration for the
// given worker, logging errors to console and full info to "log.log".
func (wf *leaseStealingWorkerFactory) CreateKCLConfig(workerID string, config *TestClusterConfig) *cfg.KinesisClientLibConfiguration {
	log := logger.NewLogrusLoggerWithConfig(logger.Configuration{
		EnableConsole:     true,
		ConsoleLevel:      logger.Error,
		ConsoleJSONFormat: false,
		EnableFile:        true,
		FileLevel:         logger.Info,
		FileJSONFormat:    true,
		Filename:          "log.log",
	})
	// WithFields returns a derived logger; capture the result so the worker
	// field is actually attached (the return value was previously discarded,
	// making the call a no-op).
	log = log.WithFields(logger.Fields{"worker": workerID})

	return cfg.NewKinesisClientLibConfig(config.appName, config.streamName, config.regionName, workerID).
		WithInitialPositionInStream(cfg.LATEST).
		WithMaxRecords(10).
		WithShardSyncIntervalMillis(5000).
		WithFailoverTimeMillis(10000).
		WithLeaseStealing(true).
		WithLogger(log)
}
// CreateWorker builds a worker with the package record-processor factory and
// the library default checkpointer. workerID is unused here; the ID is
// already baked into kclConfig.
func (wf *leaseStealingWorkerFactory) CreateWorker(workerID string, kclConfig *cfg.KinesisClientLibConfiguration) *wk.Worker {
	return wk.NewWorker(recordProcessorFactory(wf.t), kclConfig)
}
// TestLeaseStealingInjectCheckpointer runs the lease-stealing scenario with
// an explicitly injected DynamoDB checkpointer instead of the worker default.
func TestLeaseStealingInjectCheckpointer(t *testing.T) {
	config := &TestClusterConfig{
		numShards:        4,
		numWorkers:       2,
		appName:          appName,
		streamName:       streamName,
		regionName:       regionName,
		workerIDTemplate: workerID + "-%v",
	}
	expectations := LeaseStealingAssertions{
		expectedLeasesForIntialWorker: config.numShards,
		expectedLeasesPerWorker:       config.numShards / config.numWorkers,
	}
	NewLeaseStealingTest(t, config, newleaseStealingWorkerFactoryCustomChk(t)).Run(expectations)
}
// leaseStealingWorkerFactoryCustom extends the default factory so that
// CreateWorker injects an explicitly constructed DynamoDB checkpointer.
type leaseStealingWorkerFactoryCustom struct {
	*leaseStealingWorkerFactory
}
// newleaseStealingWorkerFactoryCustomChk wraps the default factory so that
// workers get a custom checkpointer in CreateWorker.
func newleaseStealingWorkerFactoryCustomChk(t *testing.T) *leaseStealingWorkerFactoryCustom {
	base := newLeaseStealingWorkerFactory(t)
	return &leaseStealingWorkerFactoryCustom{base}
}
// CreateWorker builds the base worker and attaches an explicitly constructed
// DynamoDB checkpointer to it.
func (wfc *leaseStealingWorkerFactoryCustom) CreateWorker(workerID string, kclConfig *cfg.KinesisClientLibConfiguration) *wk.Worker {
	base := wfc.leaseStealingWorkerFactory.CreateWorker(workerID, kclConfig)
	return base.WithCheckpointer(chk.NewDynamoCheckpoint(kclConfig))
}
// TestLeaseStealingWithMaxLeasesForWorker caps each worker at numShards-1
// leases, so the initial worker tops out below the full shard count and the
// post-rebalance split is 2 leases per worker.
func TestLeaseStealingWithMaxLeasesForWorker(t *testing.T) {
	config := &TestClusterConfig{
		numShards:        4,
		numWorkers:       2,
		appName:          appName,
		streamName:       streamName,
		regionName:       regionName,
		workerIDTemplate: workerID + "-%v",
	}
	maxLeases := config.numShards - 1
	expectations := LeaseStealingAssertions{
		expectedLeasesForIntialWorker: maxLeases,
		expectedLeasesPerWorker:       2,
	}
	NewLeaseStealingTest(t, config, newleaseStealingWorkerFactoryMaxLeases(t, maxLeases)).Run(expectations)
}
// leaseStealingWorkerFactoryMaxLeases extends the default factory by capping
// each generated configuration with MaxLeasesForWorker.
type leaseStealingWorkerFactoryMaxLeases struct {
	// maxLeases is the per-worker lease cap applied in CreateKCLConfig.
	maxLeases int
	*leaseStealingWorkerFactory
}
// newleaseStealingWorkerFactoryMaxLeases builds a factory whose configs are
// capped at maxLeases leases per worker.
func newleaseStealingWorkerFactoryMaxLeases(t *testing.T, maxLeases int) *leaseStealingWorkerFactoryMaxLeases {
	return &leaseStealingWorkerFactoryMaxLeases{
		maxLeases:                  maxLeases,
		leaseStealingWorkerFactory: newLeaseStealingWorkerFactory(t),
	}
}
// CreateKCLConfig derives the default configuration and applies the
// per-worker lease cap before returning it.
func (wfm *leaseStealingWorkerFactoryMaxLeases) CreateKCLConfig(workerID string, config *TestClusterConfig) *cfg.KinesisClientLibConfiguration {
	kclConfig := wfm.leaseStealingWorkerFactory.CreateKCLConfig(workerID, config)
	kclConfig.WithMaxLeasesForWorker(wfm.maxLeases)
	return kclConfig
}

View file

@ -60,7 +60,7 @@ func TestWorker(t *testing.T) {
// In order to have precise control over logging. Use logger with config
config := logger.Configuration{
EnableConsole: true,
ConsoleLevel: logger.Debug,
ConsoleLevel: logger.Error,
ConsoleJSONFormat: false,
EnableFile: true,
FileLevel: logger.Info,
@ -269,8 +269,13 @@ func runTest(kclConfig *cfg.KinesisClientLibConfiguration, triggersig bool, t *t
// configure cloudwatch as metrics system
kclConfig.WithMonitoringService(getMetricsConfig(kclConfig, metricsSystem))
worker := wk.NewWorker(recordProcessorFactory(t), kclConfig)
// Put some data into stream.
kc := NewKinesisClient(t, regionName, kclConfig.KinesisEndpoint, kclConfig.KinesisCredentials)
// publishSomeData(t, kc)
stop := continuouslyPublishSomeData(t, kc)
defer stop()
worker := wk.NewWorker(recordProcessorFactory(t), kclConfig)
err := worker.Start()
assert.Nil(t, err)
@ -286,10 +291,6 @@ func runTest(kclConfig *cfg.KinesisClientLibConfiguration, triggersig bool, t *t
//os.Exit(0)
}()
// Put some data into stream.
kc := NewKinesisClient(t, regionName, kclConfig.KinesisEndpoint, kclConfig.KinesisCredentials)
publishSomeData(t, kc)
if triggersig {
t.Log("Trigger signal SIGINT")
p, _ := os.FindProcess(os.Getpid())
@ -297,7 +298,7 @@ func runTest(kclConfig *cfg.KinesisClientLibConfiguration, triggersig bool, t *t
}
// wait a few seconds before shutdown processing
time.Sleep(10 * time.Second)
time.Sleep(30 * time.Second)
if metricsSystem == "prometheus" {
res, err := http.Get("http://localhost:8080/metrics")