From 7018c0c47e2ec7c46e1c5c090071f5c40376cac1 Mon Sep 17 00:00:00 2001
From: Harlow Ward
Date: Sun, 9 Jun 2019 13:42:25 -0700
Subject: [PATCH] Introduce Group interface and AllGroup (#91)

* Introduce Group interface and AllGroup

As we move towards consumer groups we'll need to support the current
"consume all shards" strategy, and set up the codebase for the
anticipated "consume balanced shards" strategy.
---
 allgroup.go                 |  90 ++++++++++++++++++++++++++++
 broker.go                   | 114 ------------------------------------
 consumer.go                 |  32 +++++-----
 examples/producer/README.md |   2 +-
 group.go                    |  14 +++++
 kinesis.go                  |  34 +++++++++++
 6 files changed, 156 insertions(+), 130 deletions(-)
 create mode 100644 allgroup.go
 delete mode 100644 broker.go
 create mode 100644 group.go
 create mode 100644 kinesis.go

diff --git a/allgroup.go b/allgroup.go
new file mode 100644
index 0000000..a8330e9
--- /dev/null
+++ b/allgroup.go
@@ -0,0 +1,90 @@
+package consumer
+
+import (
+	"context"
+	"sync"
+	"time"
+
+	"github.com/aws/aws-sdk-go/service/kinesis"
+	"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
+)
+
+func NewAllGroup(ksis kinesisiface.KinesisAPI, ck Checkpoint, streamName string, logger Logger) *AllGroup {
+	return &AllGroup{
+		ksis:       ksis,
+		shards:     make(map[string]*kinesis.Shard),
+		streamName: streamName,
+		logger:     logger,
+		checkpoint: ck,
+	}
+}
+
+// AllGroup caches a local list of the shards we are already processing
+// and routinely polls the stream looking for new shards to process.
+type AllGroup struct {
+	ksis       kinesisiface.KinesisAPI
+	streamName string
+	logger     Logger
+	checkpoint Checkpoint
+
+	shardMu sync.Mutex
+	shards  map[string]*kinesis.Shard
+}
+
+// Start is a blocking operation which will loop and attempt to find new
+// shards on a regular cadence.
+func (g *AllGroup) Start(ctx context.Context, shardc chan *kinesis.Shard) {
+	var ticker = time.NewTicker(30 * time.Second)
+	g.findNewShards(shardc)
+
+	// Note: while a ticker is a rather naive approach to this problem,
+	// it actually simplifies a few things, e.g. if we miss a new shard while
+	// AWS is resharding we'll pick it up at most 30 seconds later.
+
+	// It might be worth refactoring this flow to allow the consumer to
+	// notify the group when a shard is closed. However, shards don't
+	// necessarily close at the same time, so we could potentially get a
+	// thundering herd of notifications from the consumer.
+
+	for {
+		select {
+		case <-ctx.Done():
+			ticker.Stop()
+			return
+		case <-ticker.C:
+			g.findNewShards(shardc)
+		}
+	}
+}
+
+func (g *AllGroup) GetCheckpoint(streamName, shardID string) (string, error) {
+	return g.checkpoint.Get(streamName, shardID)
+}
+
+func (g *AllGroup) SetCheckpoint(streamName, shardID, sequenceNumber string) error {
+	return g.checkpoint.Set(streamName, shardID, sequenceNumber)
+}
+
+// findNewShards pulls the list of shards from the Kinesis API
+// and uses a local cache to determine if we are already processing
+// a particular shard.
+func (g *AllGroup) findNewShards(shardc chan *kinesis.Shard) { + g.shardMu.Lock() + defer g.shardMu.Unlock() + + g.logger.Log("[GROUP]", "fetching shards") + + shards, err := listShards(g.ksis, g.streamName) + if err != nil { + g.logger.Log("[GROUP] error:", err) + return + } + + for _, shard := range shards { + if _, ok := g.shards[*shard.ShardId]; ok { + continue + } + g.shards[*shard.ShardId] = shard + shardc <- shard + } +} diff --git a/broker.go b/broker.go deleted file mode 100644 index ecf25a1..0000000 --- a/broker.go +++ /dev/null @@ -1,114 +0,0 @@ -package consumer - -import ( - "context" - "fmt" - "sync" - "time" - - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/kinesis" - "github.com/aws/aws-sdk-go/service/kinesis/kinesisiface" -) - -func newBroker( - client kinesisiface.KinesisAPI, - streamName string, - shardc chan *kinesis.Shard, - logger Logger, -) *broker { - return &broker{ - client: client, - shards: make(map[string]*kinesis.Shard), - streamName: streamName, - shardc: shardc, - logger: logger, - } -} - -// broker caches a local list of the shards we are already processing -// and routinely polls the stream looking for new shards to process -type broker struct { - client kinesisiface.KinesisAPI - streamName string - shardc chan *kinesis.Shard - logger Logger - - shardMu sync.Mutex - shards map[string]*kinesis.Shard -} - -// start is a blocking operation which will loop and attempt to find new -// shards on a regular cadence. -func (b *broker) start(ctx context.Context) { - b.findNewShards() - ticker := time.NewTicker(30 * time.Second) - - // Note: while ticker is a rather naive approach to this problem, - // it actually simplies a few things. i.e. If we miss a new shard while - // AWS is resharding we'll pick it up max 30 seconds later. - - // It might be worth refactoring this flow to allow the consumer to - // to notify the broker when a shard is closed. However, shards don't - // necessarily close at the same time, so we could potentially get a - // thundering heard of notifications from the consumer. - - for { - select { - case <-ctx.Done(): - ticker.Stop() - return - case <-ticker.C: - b.findNewShards() - } - } -} - -// findNewShards pulls the list of shards from the Kinesis API -// and uses a local cache to determine if we are already processing -// a particular shard. -func (b *broker) findNewShards() { - b.shardMu.Lock() - defer b.shardMu.Unlock() - - b.logger.Log("[BROKER]", "fetching shards") - - shards, err := b.listShards() - if err != nil { - b.logger.Log("[BROKER]", err) - return - } - - for _, shard := range shards { - if _, ok := b.shards[*shard.ShardId]; ok { - continue - } - b.shards[*shard.ShardId] = shard - b.shardc <- shard - } -} - -// listShards pulls a list of shard IDs from the kinesis api -func (b *broker) listShards() ([]*kinesis.Shard, error) { - var ss []*kinesis.Shard - var listShardsInput = &kinesis.ListShardsInput{ - StreamName: aws.String(b.streamName), - } - - for { - resp, err := b.client.ListShards(listShardsInput) - if err != nil { - return nil, fmt.Errorf("ListShards error: %v", err) - } - ss = append(ss, resp.Shards...) 
- - if resp.NextToken == nil { - return ss, nil - } - - listShardsInput = &kinesis.ListShardsInput{ - NextToken: resp.NextToken, - StreamName: aws.String(b.streamName), - } - } -} diff --git a/consumer.go b/consumer.go index e9c583a..1ea5c5c 100644 --- a/consumer.go +++ b/consumer.go @@ -27,8 +27,8 @@ func New(streamName string, opts ...Option) (*Consumer, error) { c := &Consumer{ streamName: streamName, initialShardIteratorType: kinesis.ShardIteratorTypeLatest, - checkpoint: &noopCheckpoint{}, counter: &noopCounter{}, + checkpoint: &noopCheckpoint{}, logger: &noopLogger{ logger: log.New(ioutil.Discard, "", log.LstdFlags), }, @@ -48,6 +48,11 @@ func New(streamName string, opts ...Option) (*Consumer, error) { c.client = kinesis.New(newSession) } + // default group if none provided + if c.group == nil { + c.group = NewAllGroup(c.client, c.checkpoint, c.streamName, c.logger) + } + return c, nil } @@ -57,6 +62,7 @@ type Consumer struct { initialShardIteratorType string client kinesisiface.KinesisAPI logger Logger + group Group checkpoint Checkpoint counter Counter } @@ -64,7 +70,6 @@ type Consumer struct { // ScanFunc is the type of the function called for each message read // from the stream. The record argument contains the original record // returned from the AWS Kinesis library. -// // If an error is returned, scanning stops. The sole exception is when the // function returns the special value SkipCheckpoint. type ScanFunc func(*Record) error @@ -78,18 +83,16 @@ var SkipCheckpoint = errors.New("skip checkpoint") // is passed through to each of the goroutines and called with each message pulled from // the stream. func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error { - var ( - errc = make(chan error, 1) - shardc = make(chan *kinesis.Shard, 1) - broker = newBroker(c.client, c.streamName, shardc, c.logger) - ) - ctx, cancel := context.WithCancel(ctx) defer cancel() - go broker.start(ctx) + var ( + errc = make(chan error, 1) + shardc = make(chan *kinesis.Shard, 1) + ) go func() { + c.group.Start(ctx, shardc) <-ctx.Done() close(shardc) }() @@ -110,7 +113,6 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error { } close(errc) - return <-errc } @@ -118,7 +120,7 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error { // for each record and checkpoints the progress of scan. 
 func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error {
 	// get last seq number from checkpoint
-	lastSeqNum, err := c.checkpoint.Get(c.streamName, shardID)
+	lastSeqNum, err := c.group.GetCheckpoint(c.streamName, shardID)
 	if err != nil {
 		return fmt.Errorf("get checkpoint error: %v", err)
 	}
@@ -129,9 +131,9 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
 		return fmt.Errorf("get shard iterator error: %v", err)
 	}
 
-	c.logger.Log("[START]\t", shardID, lastSeqNum)
+	c.logger.Log("[CONSUMER] start scan:", shardID, lastSeqNum)
 	defer func() {
-		c.logger.Log("[STOP]\t", shardID)
+		c.logger.Log("[CONSUMER] stop scan:", shardID)
 	}()
 
 	for {
@@ -164,7 +166,7 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
 			}
 
 			if err != SkipCheckpoint {
-				if err := c.checkpoint.Set(c.streamName, shardID, *r.SequenceNumber); err != nil {
+				if err := c.group.SetCheckpoint(c.streamName, shardID, *r.SequenceNumber); err != nil {
 					return err
 				}
 			}
@@ -175,7 +177,7 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
 		}
 
 		if isShardClosed(resp.NextShardIterator, shardIterator) {
-			c.logger.Log("[CLOSED]\t", shardID)
+			c.logger.Log("[CONSUMER] shard closed:", shardID)
 			return nil
 		}
 
diff --git a/examples/producer/README.md b/examples/producer/README.md
index a620e95..da7c13b 100644
--- a/examples/producer/README.md
+++ b/examples/producer/README.md
@@ -8,7 +8,7 @@ Export the required environment vars for connecting to the Kinesis stream:
 
 ```
 export AWS_PROFILE=
-export AWS_REGION_NAME=
+export AWS_REGION=
 ```
 
 ### Running the code
diff --git a/group.go b/group.go
new file mode 100644
index 0000000..aa08438
--- /dev/null
+++ b/group.go
@@ -0,0 +1,14 @@
+package consumer
+
+import (
+	"context"
+
+	"github.com/aws/aws-sdk-go/service/kinesis"
+)
+
+// Group is the interface used to manage which shards to process.
+type Group interface {
+	Start(ctx context.Context, shardc chan *kinesis.Shard)
+	GetCheckpoint(streamName, shardID string) (string, error)
+	SetCheckpoint(streamName, shardID, sequenceNumber string) error
+}
diff --git a/kinesis.go b/kinesis.go
new file mode 100644
index 0000000..490b73b
--- /dev/null
+++ b/kinesis.go
@@ -0,0 +1,34 @@
+package consumer
+
+import (
+	"fmt"
+
+	"github.com/aws/aws-sdk-go/aws"
+	"github.com/aws/aws-sdk-go/service/kinesis"
+	"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
+)
+
+// listShards pulls a list of shards from the Kinesis API
+func listShards(ksis kinesisiface.KinesisAPI, streamName string) ([]*kinesis.Shard, error) {
+	var ss []*kinesis.Shard
+	var listShardsInput = &kinesis.ListShardsInput{
+		StreamName: aws.String(streamName),
+	}
+
+	for {
+		resp, err := ksis.ListShards(listShardsInput)
+		if err != nil {
+			return nil, fmt.Errorf("ListShards error: %v", err)
+		}
+		ss = append(ss, resp.Shards...)
+
+		if resp.NextToken == nil {
+			return ss, nil
+		}
+
+		// Note: ListShards rejects requests that set both NextToken and StreamName.
+		listShardsInput = &kinesis.ListShardsInput{
+			NextToken: resp.NextToken,
+		}
+	}
+}
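Reviewer note: the new `Group` interface makes the shard-selection strategy pluggable, which is the setup for the "consume balanced shards" work mentioned above. As a rough sketch of what an alternative implementation might look like, here is a minimal in-memory group that feeds a fixed list of shard IDs to the consumer and keeps checkpoints in a map. `StaticGroup`, `NewStaticGroup`, and the `streamName + ":" + shardID` key scheme are hypothetical illustrations, not part of this patch.

```go
package consumer

import (
	"context"
	"sync"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/service/kinesis"
)

// StaticGroup is a hypothetical Group implementation that consumes a
// fixed set of shard IDs and keeps checkpoints in memory (e.g. for tests).
type StaticGroup struct {
	shardIDs []string

	mu          sync.Mutex
	checkpoints map[string]string
}

// NewStaticGroup creates a StaticGroup for the given shard IDs.
func NewStaticGroup(shardIDs []string) *StaticGroup {
	return &StaticGroup{
		shardIDs:    shardIDs,
		checkpoints: make(map[string]string),
	}
}

// Start pushes the fixed shard list onto the channel once, then blocks
// until the context is canceled; it never polls for new shards.
func (g *StaticGroup) Start(ctx context.Context, shardc chan *kinesis.Shard) {
	for _, id := range g.shardIDs {
		shardc <- &kinesis.Shard{ShardId: aws.String(id)}
	}
	<-ctx.Done()
}

// GetCheckpoint returns the last checkpointed sequence number ("" if none).
func (g *StaticGroup) GetCheckpoint(streamName, shardID string) (string, error) {
	g.mu.Lock()
	defer g.mu.Unlock()
	return g.checkpoints[streamName+":"+shardID], nil
}

// SetCheckpoint records the sequence number for a stream/shard pair.
func (g *StaticGroup) SetCheckpoint(streamName, shardID, sequenceNumber string) error {
	g.mu.Lock()
	defer g.mu.Unlock()
	g.checkpoints[streamName+":"+shardID] = sequenceNumber
	return nil
}
```

Note that this patch only wires in `AllGroup` as the default when `Consumer.group` is nil; plugging in a custom group like the sketch above would presumably still need something like a functional option, which is not added here.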