Introduce Group interface and AllGroup

As we move toward consumer groups we'll need to support the current
"consume all shards" strategy and set up the codebase for the
anticipated "consume balanced shards" strategy.
Harlow Ward 2019-05-24 19:43:36 -07:00
parent 9cd2e57ba4
commit bd42663013
9 changed files with 170 additions and 139 deletions

97
allgroup.go Normal file

@@ -0,0 +1,97 @@
package consumer
import (
"context"
"sync"
"time"
"github.com/aws/aws-sdk-go/service/kinesis"
"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
)
func NewAllGroup(ksis kinesisiface.KinesisAPI, ck Checkpoint, streamName string, logger Logger) *AllGroup {
return &AllGroup{
ksis: ksis,
shards: make(map[string]*kinesis.Shard),
streamName: streamName,
logger: logger,
checkpoint: ck,
}
}
// AllGroup caches a local list of the shards we are already processing
// and routinely polls the stream looking for new shards to process
type AllGroup struct {
ksis kinesisiface.KinesisAPI
streamName string
logger Logger
checkpoint Checkpoint
shardMu sync.Mutex
shards map[string]*kinesis.Shard
}
// Start returns a channel of shards to process; it polls the stream on a
// regular cadence and pushes newly discovered shards onto the channel.
func (g *AllGroup) Start(ctx context.Context) chan *kinesis.Shard {
var (
shardc = make(chan *kinesis.Shard, 1)
ticker = time.NewTicker(30 * time.Second)
)
g.findNewShards(shardc)
// Note: while a ticker is a rather naive approach to this problem,
// it actually simplifies a few things, e.g. if we miss a new shard while
// AWS is resharding we'll pick it up at most 30 seconds later.
// It might be worth refactoring this flow to allow the consumer
// to notify the group when a shard is closed. However, shards don't
// necessarily close at the same time, so we could potentially get a
// thundering herd of notifications from the consumer.
go func() {
for {
select {
case <-ctx.Done():
ticker.Stop()
return
case <-ticker.C:
g.findNewShards(shardc)
}
}
}()
return shardc
}
func (g *AllGroup) GetCheckpoint(streamName, shardID string) (string, error) {
return g.checkpoint.Get(streamName, shardID)
}
func (g *AllGroup) SetCheckpoint(streamName, shardID, sequenceNumber string) error {
return g.checkpoint.Set(streamName, shardID, sequenceNumber)
}
// findNewShards pulls the list of shards from the Kinesis API
// and uses a local cache to determine if we are already processing
// a particular shard.
func (g *AllGroup) findNewShards(shardc chan *kinesis.Shard) {
g.shardMu.Lock()
defer g.shardMu.Unlock()
g.logger.Log("[GROUP]", "fetching shards")
shards, err := listShards(g.ksis, g.streamName)
if err != nil {
g.logger.Log("[GROUP]", err)
return
}
for _, shard := range shards {
if _, ok := g.shards[*shard.ShardId]; ok {
continue
}
g.shards[*shard.ShardId] = shard
shardc <- shard
}
}
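
For orientation, here is a minimal sketch of driving the new AllGroup directly, outside of Consumer. It assumes the package's Checkpoint interface exposes the Get/Set methods AllGroup calls above and that Logger exposes a variadic Log method; memCheckpoint and stdLogger are throwaway stand-ins for this sketch, not part of the commit.

package main

import (
    "context"
    "log"

    "github.com/aws/aws-sdk-go/aws/session"
    "github.com/aws/aws-sdk-go/service/kinesis"
    consumer "github.com/harlow/kinesis-consumer"
)

// memCheckpoint is a throwaway in-memory checkpoint store for this sketch.
type memCheckpoint struct{ seqs map[string]string }

func (c *memCheckpoint) Get(streamName, shardID string) (string, error) {
    return c.seqs[streamName+":"+shardID], nil
}

func (c *memCheckpoint) Set(streamName, shardID, sequenceNumber string) error {
    c.seqs[streamName+":"+shardID] = sequenceNumber
    return nil
}

// stdLogger adapts the standard library logger to the package's Logger interface.
type stdLogger struct{}

func (stdLogger) Log(args ...interface{}) { log.Println(args...) }

func main() {
    ksis := kinesis.New(session.Must(session.NewSession()))
    group := consumer.NewAllGroup(ksis, &memCheckpoint{seqs: map[string]string{}}, "my-stream", stdLogger{})

    ctx, cancel := context.WithCancel(context.Background())
    defer cancel()

    // Start returns immediately; the channel receives every shard the group
    // discovers, including shards created later by resharding.
    for shard := range group.Start(ctx) {
        log.Println("discovered shard:", *shard.ShardId)
    }
}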

114
broker.go

@@ -1,114 +0,0 @@
package consumer
import (
"context"
"fmt"
"sync"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/kinesis"
"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
)
func newBroker(
client kinesisiface.KinesisAPI,
streamName string,
shardc chan *kinesis.Shard,
logger Logger,
) *broker {
return &broker{
client: client,
shards: make(map[string]*kinesis.Shard),
streamName: streamName,
shardc: shardc,
logger: logger,
}
}
// broker caches a local list of the shards we are already processing
// and routinely polls the stream looking for new shards to process
type broker struct {
client kinesisiface.KinesisAPI
streamName string
shardc chan *kinesis.Shard
logger Logger
shardMu sync.Mutex
shards map[string]*kinesis.Shard
}
// start is a blocking operation which will loop and attempt to find new
// shards on a regular cadence.
func (b *broker) start(ctx context.Context) {
b.findNewShards()
ticker := time.NewTicker(30 * time.Second)
// Note: while a ticker is a rather naive approach to this problem,
// it actually simplifies a few things, e.g. if we miss a new shard while
// AWS is resharding we'll pick it up at most 30 seconds later.
// It might be worth refactoring this flow to allow the consumer
// to notify the broker when a shard is closed. However, shards don't
// necessarily close at the same time, so we could potentially get a
// thundering herd of notifications from the consumer.
for {
select {
case <-ctx.Done():
ticker.Stop()
return
case <-ticker.C:
b.findNewShards()
}
}
}
// findNewShards pulls the list of shards from the Kinesis API
// and uses a local cache to determine if we are already processing
// a particular shard.
func (b *broker) findNewShards() {
b.shardMu.Lock()
defer b.shardMu.Unlock()
b.logger.Log("[BROKER]", "fetching shards")
shards, err := b.listShards()
if err != nil {
b.logger.Log("[BROKER]", err)
return
}
for _, shard := range shards {
if _, ok := b.shards[*shard.ShardId]; ok {
continue
}
b.shards[*shard.ShardId] = shard
b.shardc <- shard
}
}
// listShards pulls a list of shard IDs from the kinesis api
func (b *broker) listShards() ([]*kinesis.Shard, error) {
var ss []*kinesis.Shard
var listShardsInput = &kinesis.ListShardsInput{
StreamName: aws.String(b.streamName),
}
for {
resp, err := b.client.ListShards(listShardsInput)
if err != nil {
return nil, fmt.Errorf("ListShards error: %v", err)
}
ss = append(ss, resp.Shards...)
if resp.NextToken == nil {
return ss, nil
}
listShardsInput = &kinesis.ListShardsInput{
NextToken: resp.NextToken,
StreamName: aws.String(b.streamName),
}
}
}

View file

@@ -23,11 +23,10 @@ func New(streamName string, opts ...Option) (*Consumer, error) {
 		return nil, fmt.Errorf("must provide stream name")
 	}
-	// new consumer with no-op checkpoint, counter, and logger
+	// new consumer with noop group, counter, and logger
 	c := &Consumer{
 		streamName:               streamName,
 		initialShardIteratorType: kinesis.ShardIteratorTypeLatest,
-		checkpoint:               &noopCheckpoint{},
 		counter:                  &noopCounter{},
 		logger: &noopLogger{
 			logger: log.New(ioutil.Discard, "", log.LstdFlags),
@@ -48,6 +47,11 @@ func New(streamName string, opts ...Option) (*Consumer, error) {
 		c.client = kinesis.New(newSession)
 	}
+	// default group if none provided
+	if c.group == nil {
+		c.group = NewAllGroup(c.client, c.checkpoint, c.streamName, c.logger)
+	}
 	return c, nil
 }
@@ -57,6 +61,7 @@ type Consumer struct {
 	initialShardIteratorType string
 	client                   kinesisiface.KinesisAPI
 	logger                   Logger
+	group                    Group
 	checkpoint               Checkpoint
 	counter                  Counter
 }
@@ -64,7 +69,6 @@ type Consumer struct {
 // ScanFunc is the type of the function called for each message read
 // from the stream. The record argument contains the original record
 // returned from the AWS Kinesis library.
-//
 // If an error is returned, scanning stops. The sole exception is when the
 // function returns the special value SkipCheckpoint.
 type ScanFunc func(*Record) error
@@ -78,16 +82,13 @@ var SkipCheckpoint = errors.New("skip checkpoint")
 // is passed through to each of the goroutines and called with each message pulled from
 // the stream.
 func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
-	var (
-		errc   = make(chan error, 1)
-		shardc = make(chan *kinesis.Shard, 1)
-		broker = newBroker(c.client, c.streamName, shardc, c.logger)
-	)
 	ctx, cancel := context.WithCancel(ctx)
 	defer cancel()
-	go broker.start(ctx)
+	var (
+		errc   = make(chan error, 1)
+		shardc = c.group.Start(ctx)
+	)
 	go func() {
 		<-ctx.Done()
@@ -110,7 +111,6 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
 	}
 	close(errc)
 	return <-errc
 }
@@ -118,7 +118,7 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
 // for each record and checkpoints the progress of scan.
 func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error {
 	// get last seq number from checkpoint
-	lastSeqNum, err := c.checkpoint.Get(c.streamName, shardID)
+	lastSeqNum, err := c.group.GetCheckpoint(c.streamName, shardID)
 	if err != nil {
 		return fmt.Errorf("get checkpoint error: %v", err)
 	}
@@ -164,7 +164,7 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error {
 		}
 		if err != SkipCheckpoint {
-			if err := c.checkpoint.Set(c.streamName, shardID, *r.SequenceNumber); err != nil {
+			if err := c.group.SetCheckpoint(c.streamName, shardID, *r.SequenceNumber); err != nil {
 				return err
 			}
 		}
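
From the caller's point of view the public Scan API is unchanged by this refactor: with no group configured, New falls back to the AllGroup introduced above. A rough sketch of that unchanged usage, assuming Record exposes the underlying Kinesis record's Data field:

package main

import (
    "context"
    "fmt"
    "log"

    consumer "github.com/harlow/kinesis-consumer"
)

func main() {
    // No explicit group: New wires up the default AllGroup ("consume all shards").
    c, err := consumer.New("my-stream")
    if err != nil {
        log.Fatalf("consumer error: %v", err)
    }

    err = c.Scan(context.Background(), func(r *consumer.Record) error {
        fmt.Println(string(r.Data))
        return nil // returning consumer.SkipCheckpoint would continue scanning without checkpointing this record
    })
    if err != nil {
        log.Fatalf("scan error: %v", err)
    }
}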

View file

@@ -19,7 +19,7 @@ import (
 	"github.com/aws/aws-sdk-go/service/dynamodb"
 	"github.com/aws/aws-sdk-go/service/kinesis"
 	consumer "github.com/harlow/kinesis-consumer"
-	checkpoint "github.com/harlow/kinesis-consumer/checkpoint/ddb"
+	storage "github.com/harlow/kinesis-consumer/checkpoint/ddb"
 )
 // kick off a server for exposing scan metrics
@@ -69,8 +69,8 @@ func main() {
 	myKsis := kinesis.New(sess)
 	myDdbClient := dynamodb.New(sess)
-	// ddb checkpoint
-	ck, err := checkpoint.New(*app, *table, checkpoint.WithDynamoClient(myDdbClient), checkpoint.WithRetryer(&MyRetryer{}))
+	// ddb persistence
+	ddb, err := storage.New(*app, *table, storage.WithDynamoClient(myDdbClient), storage.WithRetryer(&MyRetryer{}))
 	if err != nil {
 		log.Log("checkpoint error: %v", err)
 	}
@@ -81,7 +81,7 @@ func main() {
 	// consumer
 	c, err := consumer.New(
 		*stream,
-		consumer.WithCheckpoint(ck),
+		consumer.WithStorage(ddb),
 		consumer.WithLogger(log),
 		consumer.WithCounter(counter),
 		consumer.WithClient(myKsis),
@@ -111,17 +111,17 @@ func main() {
 		log.Log("scan error: %v", err)
 	}
-	if err := ck.Shutdown(); err != nil {
-		log.Log("checkpoint shutdown error: %v", err)
+	if err := ddb.Shutdown(); err != nil {
+		log.Log("storage shutdown error: %v", err)
 	}
 }
-// MyRetryer used for checkpointing
+// MyRetryer used for storage
 type MyRetryer struct {
-	checkpoint.Retryer
+	storage.Retryer
 }
-// ShouldRetry implements custom logic for when a checkpont should retry
+// ShouldRetry implements custom logic for when errors should retry
 func (r *MyRetryer) ShouldRetry(err error) bool {
 	if awsErr, ok := err.(awserr.Error); ok {
 		switch awsErr.Code() {

View file

@@ -33,7 +33,7 @@ func main() {
 	// consumer
 	c, err := consumer.New(
 		*stream,
-		consumer.WithCheckpoint(ck),
+		consumer.WithStorage(ck),
 		consumer.WithCounter(counter),
 	)
 	if err != nil {

View file

@@ -33,7 +33,7 @@ func main() {
 	// consumer
 	c, err := consumer.New(
 		*stream,
-		consumer.WithCheckpoint(ck),
+		consumer.WithStorage(ck),
 		consumer.WithCounter(counter),
 	)
 	if err != nil {

View file

@@ -43,7 +43,7 @@ func main() {
 	// consumer
 	c, err := consumer.New(
 		*stream,
-		consumer.WithCheckpoint(ck),
+		consumer.WithStorage(ck),
 		consumer.WithLogger(logger),
 	)
 	if err != nil {

14
group.go Normal file

@@ -0,0 +1,14 @@
package consumer
import (
"context"
"github.com/aws/aws-sdk-go/service/kinesis"
)
// Group is the interface used to manage which shards to process
type Group interface {
Start(ctx context.Context) chan *kinesis.Shard
GetCheckpoint(streamName, shardID string) (string, error)
SetCheckpoint(streamName, shardID, sequenceNumber string) error
}
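
To illustrate the extension point this interface opens up (e.g. for the anticipated balanced-shard strategy), here is a sketch of a hypothetical alternative Group that serves only a fixed, externally assigned set of shard IDs and delegates checkpointing the same way AllGroup does; staticGroup is illustrative and not part of this commit.

package consumer

import (
    "context"

    "github.com/aws/aws-sdk-go/aws"
    "github.com/aws/aws-sdk-go/service/kinesis"
)

// staticGroup processes a fixed set of shard IDs, e.g. assigned out of band,
// and delegates checkpointing to whatever Checkpoint the caller provides.
type staticGroup struct {
    shardIDs   []string
    checkpoint Checkpoint
}

func (g *staticGroup) Start(ctx context.Context) chan *kinesis.Shard {
    shardc := make(chan *kinesis.Shard, len(g.shardIDs))
    for _, id := range g.shardIDs {
        shardc <- &kinesis.Shard{ShardId: aws.String(id)}
    }
    return shardc
}

func (g *staticGroup) GetCheckpoint(streamName, shardID string) (string, error) {
    return g.checkpoint.Get(streamName, shardID)
}

func (g *staticGroup) SetCheckpoint(streamName, shardID, sequenceNumber string) error {
    return g.checkpoint.Set(streamName, shardID, sequenceNumber)
}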

34
kinesis.go Normal file

@@ -0,0 +1,34 @@
package consumer
import (
"fmt"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/kinesis"
"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
)
// listShards pulls the list of shards for a stream from the Kinesis API,
// following NextToken pagination until all shards have been returned
func listShards(ksis kinesisiface.KinesisAPI, streamName string) ([]*kinesis.Shard, error) {
var ss []*kinesis.Shard
var listShardsInput = &kinesis.ListShardsInput{
StreamName: aws.String(streamName),
}
for {
resp, err := ksis.ListShards(listShardsInput)
if err != nil {
return nil, fmt.Errorf("ListShards error: %v", err)
}
ss = append(ss, resp.Shards...)
if resp.NextToken == nil {
return ss, nil
}
listShardsInput = &kinesis.ListShardsInput{
NextToken: resp.NextToken,
StreamName: aws.String(streamName),
}
}
}
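
Because listShards now takes the Kinesis client as an argument instead of reading it off the broker, its NextToken paging is straightforward to exercise in isolation. A sketch of such a test, where fakeKinesis is an illustrative stub and not part of this commit:

package consumer

import (
    "testing"

    "github.com/aws/aws-sdk-go/aws"
    "github.com/aws/aws-sdk-go/service/kinesis"
    "github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
)

// fakeKinesis returns canned ListShards pages; all other KinesisAPI methods
// come from the embedded interface and would panic if called.
type fakeKinesis struct {
    kinesisiface.KinesisAPI
    pages []*kinesis.ListShardsOutput
    calls int
}

func (f *fakeKinesis) ListShards(in *kinesis.ListShardsInput) (*kinesis.ListShardsOutput, error) {
    out := f.pages[f.calls]
    f.calls++
    return out, nil
}

func TestListShardsFollowsNextToken(t *testing.T) {
    fake := &fakeKinesis{
        pages: []*kinesis.ListShardsOutput{
            {
                Shards:    []*kinesis.Shard{{ShardId: aws.String("shardId-000000000000")}},
                NextToken: aws.String("token-1"),
            },
            {
                Shards: []*kinesis.Shard{{ShardId: aws.String("shardId-000000000001")}},
            },
        },
    }

    shards, err := listShards(fake, "my-stream")
    if err != nil {
        t.Fatalf("listShards error: %v", err)
    }
    if len(shards) != 2 {
        t.Fatalf("expected 2 shards, got %d", len(shards))
    }
}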