Compare commits
4 commits
master
...
hw-group-i
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0328cba5c9 | ||
|
|
4fd29c54ff | ||
|
|
2ab5ec4031 | ||
|
|
bd42663013 |
6 changed files with 156 additions and 130 deletions
90
allgroup.go
Normal file
90
allgroup.go
Normal file
|
|
@ -0,0 +1,90 @@
|
||||||
|
package consumer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/aws/aws-sdk-go/service/kinesis"
|
||||||
|
"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewAllGroup(ksis kinesisiface.KinesisAPI, ck Checkpoint, streamName string, logger Logger) *AllGroup {
|
||||||
|
return &AllGroup{
|
||||||
|
ksis: ksis,
|
||||||
|
shards: make(map[string]*kinesis.Shard),
|
||||||
|
streamName: streamName,
|
||||||
|
logger: logger,
|
||||||
|
checkpoint: ck,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AllGroup caches a local list of the shards we are already processing
|
||||||
|
// and routinely polls the stream looking for new shards to process
|
||||||
|
type AllGroup struct {
|
||||||
|
ksis kinesisiface.KinesisAPI
|
||||||
|
streamName string
|
||||||
|
logger Logger
|
||||||
|
checkpoint Checkpoint
|
||||||
|
|
||||||
|
shardMu sync.Mutex
|
||||||
|
shards map[string]*kinesis.Shard
|
||||||
|
}
|
||||||
|
|
||||||
|
// start is a blocking operation which will loop and attempt to find new
|
||||||
|
// shards on a regular cadence.
|
||||||
|
func (g *AllGroup) Start(ctx context.Context, shardc chan *kinesis.Shard) {
|
||||||
|
var ticker = time.NewTicker(30 * time.Second)
|
||||||
|
g.findNewShards(shardc)
|
||||||
|
|
||||||
|
// Note: while ticker is a rather naive approach to this problem,
|
||||||
|
// it actually simplies a few things. i.e. If we miss a new shard while
|
||||||
|
// AWS is resharding we'll pick it up max 30 seconds later.
|
||||||
|
|
||||||
|
// It might be worth refactoring this flow to allow the consumer to
|
||||||
|
// to notify the broker when a shard is closed. However, shards don't
|
||||||
|
// necessarily close at the same time, so we could potentially get a
|
||||||
|
// thundering heard of notifications from the consumer.
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
ticker.Stop()
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
g.findNewShards(shardc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *AllGroup) GetCheckpoint(streamName, shardID string) (string, error) {
|
||||||
|
return g.checkpoint.Get(streamName, shardID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *AllGroup) SetCheckpoint(streamName, shardID, sequenceNumber string) error {
|
||||||
|
return g.checkpoint.Set(streamName, shardID, sequenceNumber)
|
||||||
|
}
|
||||||
|
|
||||||
|
// findNewShards pulls the list of shards from the Kinesis API
|
||||||
|
// and uses a local cache to determine if we are already processing
|
||||||
|
// a particular shard.
|
||||||
|
func (g *AllGroup) findNewShards(shardc chan *kinesis.Shard) {
|
||||||
|
g.shardMu.Lock()
|
||||||
|
defer g.shardMu.Unlock()
|
||||||
|
|
||||||
|
g.logger.Log("[GROUP]", "fetching shards")
|
||||||
|
|
||||||
|
shards, err := listShards(g.ksis, g.streamName)
|
||||||
|
if err != nil {
|
||||||
|
g.logger.Log("[GROUP] error:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, shard := range shards {
|
||||||
|
if _, ok := g.shards[*shard.ShardId]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
g.shards[*shard.ShardId] = shard
|
||||||
|
shardc <- shard
|
||||||
|
}
|
||||||
|
}
|
||||||
114
broker.go
114
broker.go
|
|
@ -1,114 +0,0 @@
|
||||||
package consumer
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/aws/aws-sdk-go/aws"
|
|
||||||
"github.com/aws/aws-sdk-go/service/kinesis"
|
|
||||||
"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
|
|
||||||
)
|
|
||||||
|
|
||||||
func newBroker(
|
|
||||||
client kinesisiface.KinesisAPI,
|
|
||||||
streamName string,
|
|
||||||
shardc chan *kinesis.Shard,
|
|
||||||
logger Logger,
|
|
||||||
) *broker {
|
|
||||||
return &broker{
|
|
||||||
client: client,
|
|
||||||
shards: make(map[string]*kinesis.Shard),
|
|
||||||
streamName: streamName,
|
|
||||||
shardc: shardc,
|
|
||||||
logger: logger,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// broker caches a local list of the shards we are already processing
|
|
||||||
// and routinely polls the stream looking for new shards to process
|
|
||||||
type broker struct {
|
|
||||||
client kinesisiface.KinesisAPI
|
|
||||||
streamName string
|
|
||||||
shardc chan *kinesis.Shard
|
|
||||||
logger Logger
|
|
||||||
|
|
||||||
shardMu sync.Mutex
|
|
||||||
shards map[string]*kinesis.Shard
|
|
||||||
}
|
|
||||||
|
|
||||||
// start is a blocking operation which will loop and attempt to find new
|
|
||||||
// shards on a regular cadence.
|
|
||||||
func (b *broker) start(ctx context.Context) {
|
|
||||||
b.findNewShards()
|
|
||||||
ticker := time.NewTicker(30 * time.Second)
|
|
||||||
|
|
||||||
// Note: while ticker is a rather naive approach to this problem,
|
|
||||||
// it actually simplies a few things. i.e. If we miss a new shard while
|
|
||||||
// AWS is resharding we'll pick it up max 30 seconds later.
|
|
||||||
|
|
||||||
// It might be worth refactoring this flow to allow the consumer to
|
|
||||||
// to notify the broker when a shard is closed. However, shards don't
|
|
||||||
// necessarily close at the same time, so we could potentially get a
|
|
||||||
// thundering heard of notifications from the consumer.
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
ticker.Stop()
|
|
||||||
return
|
|
||||||
case <-ticker.C:
|
|
||||||
b.findNewShards()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// findNewShards pulls the list of shards from the Kinesis API
|
|
||||||
// and uses a local cache to determine if we are already processing
|
|
||||||
// a particular shard.
|
|
||||||
func (b *broker) findNewShards() {
|
|
||||||
b.shardMu.Lock()
|
|
||||||
defer b.shardMu.Unlock()
|
|
||||||
|
|
||||||
b.logger.Log("[BROKER]", "fetching shards")
|
|
||||||
|
|
||||||
shards, err := b.listShards()
|
|
||||||
if err != nil {
|
|
||||||
b.logger.Log("[BROKER]", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, shard := range shards {
|
|
||||||
if _, ok := b.shards[*shard.ShardId]; ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
b.shards[*shard.ShardId] = shard
|
|
||||||
b.shardc <- shard
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// listShards pulls a list of shard IDs from the kinesis api
|
|
||||||
func (b *broker) listShards() ([]*kinesis.Shard, error) {
|
|
||||||
var ss []*kinesis.Shard
|
|
||||||
var listShardsInput = &kinesis.ListShardsInput{
|
|
||||||
StreamName: aws.String(b.streamName),
|
|
||||||
}
|
|
||||||
|
|
||||||
for {
|
|
||||||
resp, err := b.client.ListShards(listShardsInput)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("ListShards error: %v", err)
|
|
||||||
}
|
|
||||||
ss = append(ss, resp.Shards...)
|
|
||||||
|
|
||||||
if resp.NextToken == nil {
|
|
||||||
return ss, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
listShardsInput = &kinesis.ListShardsInput{
|
|
||||||
NextToken: resp.NextToken,
|
|
||||||
StreamName: aws.String(b.streamName),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
32
consumer.go
32
consumer.go
|
|
@ -27,8 +27,8 @@ func New(streamName string, opts ...Option) (*Consumer, error) {
|
||||||
c := &Consumer{
|
c := &Consumer{
|
||||||
streamName: streamName,
|
streamName: streamName,
|
||||||
initialShardIteratorType: kinesis.ShardIteratorTypeLatest,
|
initialShardIteratorType: kinesis.ShardIteratorTypeLatest,
|
||||||
checkpoint: &noopCheckpoint{},
|
|
||||||
counter: &noopCounter{},
|
counter: &noopCounter{},
|
||||||
|
checkpoint: &noopCheckpoint{},
|
||||||
logger: &noopLogger{
|
logger: &noopLogger{
|
||||||
logger: log.New(ioutil.Discard, "", log.LstdFlags),
|
logger: log.New(ioutil.Discard, "", log.LstdFlags),
|
||||||
},
|
},
|
||||||
|
|
@ -48,6 +48,11 @@ func New(streamName string, opts ...Option) (*Consumer, error) {
|
||||||
c.client = kinesis.New(newSession)
|
c.client = kinesis.New(newSession)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// default group if none provided
|
||||||
|
if c.group == nil {
|
||||||
|
c.group = NewAllGroup(c.client, c.checkpoint, c.streamName, c.logger)
|
||||||
|
}
|
||||||
|
|
||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -57,6 +62,7 @@ type Consumer struct {
|
||||||
initialShardIteratorType string
|
initialShardIteratorType string
|
||||||
client kinesisiface.KinesisAPI
|
client kinesisiface.KinesisAPI
|
||||||
logger Logger
|
logger Logger
|
||||||
|
group Group
|
||||||
checkpoint Checkpoint
|
checkpoint Checkpoint
|
||||||
counter Counter
|
counter Counter
|
||||||
}
|
}
|
||||||
|
|
@ -64,7 +70,6 @@ type Consumer struct {
|
||||||
// ScanFunc is the type of the function called for each message read
|
// ScanFunc is the type of the function called for each message read
|
||||||
// from the stream. The record argument contains the original record
|
// from the stream. The record argument contains the original record
|
||||||
// returned from the AWS Kinesis library.
|
// returned from the AWS Kinesis library.
|
||||||
//
|
|
||||||
// If an error is returned, scanning stops. The sole exception is when the
|
// If an error is returned, scanning stops. The sole exception is when the
|
||||||
// function returns the special value SkipCheckpoint.
|
// function returns the special value SkipCheckpoint.
|
||||||
type ScanFunc func(*Record) error
|
type ScanFunc func(*Record) error
|
||||||
|
|
@ -78,18 +83,16 @@ var SkipCheckpoint = errors.New("skip checkpoint")
|
||||||
// is passed through to each of the goroutines and called with each message pulled from
|
// is passed through to each of the goroutines and called with each message pulled from
|
||||||
// the stream.
|
// the stream.
|
||||||
func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
|
func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
|
||||||
var (
|
|
||||||
errc = make(chan error, 1)
|
|
||||||
shardc = make(chan *kinesis.Shard, 1)
|
|
||||||
broker = newBroker(c.client, c.streamName, shardc, c.logger)
|
|
||||||
)
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
go broker.start(ctx)
|
var (
|
||||||
|
errc = make(chan error, 1)
|
||||||
|
shardc = make(chan *kinesis.Shard, 1)
|
||||||
|
)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
|
c.group.Start(ctx, shardc)
|
||||||
<-ctx.Done()
|
<-ctx.Done()
|
||||||
close(shardc)
|
close(shardc)
|
||||||
}()
|
}()
|
||||||
|
|
@ -110,7 +113,6 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
close(errc)
|
close(errc)
|
||||||
|
|
||||||
return <-errc
|
return <-errc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -118,7 +120,7 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
|
||||||
// for each record and checkpoints the progress of scan.
|
// for each record and checkpoints the progress of scan.
|
||||||
func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error {
|
func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error {
|
||||||
// get last seq number from checkpoint
|
// get last seq number from checkpoint
|
||||||
lastSeqNum, err := c.checkpoint.Get(c.streamName, shardID)
|
lastSeqNum, err := c.group.GetCheckpoint(c.streamName, shardID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("get checkpoint error: %v", err)
|
return fmt.Errorf("get checkpoint error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
@ -129,9 +131,9 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
|
||||||
return fmt.Errorf("get shard iterator error: %v", err)
|
return fmt.Errorf("get shard iterator error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.logger.Log("[START]\t", shardID, lastSeqNum)
|
c.logger.Log("[CONSUMER] start scan:", shardID, lastSeqNum)
|
||||||
defer func() {
|
defer func() {
|
||||||
c.logger.Log("[STOP]\t", shardID)
|
c.logger.Log("[CONSUMER] stop scan:", shardID)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
|
|
@ -164,7 +166,7 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != SkipCheckpoint {
|
if err != SkipCheckpoint {
|
||||||
if err := c.checkpoint.Set(c.streamName, shardID, *r.SequenceNumber); err != nil {
|
if err := c.group.SetCheckpoint(c.streamName, shardID, *r.SequenceNumber); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -175,7 +177,7 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
|
||||||
}
|
}
|
||||||
|
|
||||||
if isShardClosed(resp.NextShardIterator, shardIterator) {
|
if isShardClosed(resp.NextShardIterator, shardIterator) {
|
||||||
c.logger.Log("[CLOSED]\t", shardID)
|
c.logger.Log("[CONSUMER] shard closed:", shardID)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ Export the required environment vars for connecting to the Kinesis stream:
|
||||||
|
|
||||||
```
|
```
|
||||||
export AWS_PROFILE=
|
export AWS_PROFILE=
|
||||||
export AWS_REGION_NAME=
|
export AWS_REGION=
|
||||||
```
|
```
|
||||||
|
|
||||||
### Running the code
|
### Running the code
|
||||||
|
|
|
||||||
14
group.go
Normal file
14
group.go
Normal file
|
|
@ -0,0 +1,14 @@
|
||||||
|
package consumer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/aws/aws-sdk-go/service/kinesis"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Group interface used to manage which shard to process
|
||||||
|
type Group interface {
|
||||||
|
Start(ctx context.Context, shardc chan *kinesis.Shard)
|
||||||
|
GetCheckpoint(streamName, shardID string) (string, error)
|
||||||
|
SetCheckpoint(streamName, shardID, sequenceNumber string) error
|
||||||
|
}
|
||||||
34
kinesis.go
Normal file
34
kinesis.go
Normal file
|
|
@ -0,0 +1,34 @@
|
||||||
|
package consumer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/aws/aws-sdk-go/aws"
|
||||||
|
"github.com/aws/aws-sdk-go/service/kinesis"
|
||||||
|
"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
|
||||||
|
)
|
||||||
|
|
||||||
|
// listShards pulls a list of shard IDs from the kinesis api
|
||||||
|
func listShards(ksis kinesisiface.KinesisAPI, streamName string) ([]*kinesis.Shard, error) {
|
||||||
|
var ss []*kinesis.Shard
|
||||||
|
var listShardsInput = &kinesis.ListShardsInput{
|
||||||
|
StreamName: aws.String(streamName),
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
resp, err := ksis.ListShards(listShardsInput)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("ListShards error: %v", err)
|
||||||
|
}
|
||||||
|
ss = append(ss, resp.Shards...)
|
||||||
|
|
||||||
|
if resp.NextToken == nil {
|
||||||
|
return ss, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
listShardsInput = &kinesis.ListShardsInput{
|
||||||
|
NextToken: resp.NextToken,
|
||||||
|
StreamName: aws.String(streamName),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Reference in a new issue