Compare commits
4 commits
master
...
hw-group-i
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0328cba5c9 | ||
|
|
4fd29c54ff | ||
|
|
2ab5ec4031 | ||
|
|
bd42663013 |
6 changed files with 156 additions and 130 deletions
90
allgroup.go
Normal file
90
allgroup.go
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
package consumer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/service/kinesis"
|
||||
"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
|
||||
)
|
||||
|
||||
func NewAllGroup(ksis kinesisiface.KinesisAPI, ck Checkpoint, streamName string, logger Logger) *AllGroup {
|
||||
return &AllGroup{
|
||||
ksis: ksis,
|
||||
shards: make(map[string]*kinesis.Shard),
|
||||
streamName: streamName,
|
||||
logger: logger,
|
||||
checkpoint: ck,
|
||||
}
|
||||
}
|
||||
|
||||
// AllGroup caches a local list of the shards we are already processing
|
||||
// and routinely polls the stream looking for new shards to process
|
||||
type AllGroup struct {
|
||||
ksis kinesisiface.KinesisAPI
|
||||
streamName string
|
||||
logger Logger
|
||||
checkpoint Checkpoint
|
||||
|
||||
shardMu sync.Mutex
|
||||
shards map[string]*kinesis.Shard
|
||||
}
|
||||
|
||||
// start is a blocking operation which will loop and attempt to find new
|
||||
// shards on a regular cadence.
|
||||
func (g *AllGroup) Start(ctx context.Context, shardc chan *kinesis.Shard) {
|
||||
var ticker = time.NewTicker(30 * time.Second)
|
||||
g.findNewShards(shardc)
|
||||
|
||||
// Note: while ticker is a rather naive approach to this problem,
|
||||
// it actually simplies a few things. i.e. If we miss a new shard while
|
||||
// AWS is resharding we'll pick it up max 30 seconds later.
|
||||
|
||||
// It might be worth refactoring this flow to allow the consumer to
|
||||
// to notify the broker when a shard is closed. However, shards don't
|
||||
// necessarily close at the same time, so we could potentially get a
|
||||
// thundering heard of notifications from the consumer.
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
ticker.Stop()
|
||||
return
|
||||
case <-ticker.C:
|
||||
g.findNewShards(shardc)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (g *AllGroup) GetCheckpoint(streamName, shardID string) (string, error) {
|
||||
return g.checkpoint.Get(streamName, shardID)
|
||||
}
|
||||
|
||||
func (g *AllGroup) SetCheckpoint(streamName, shardID, sequenceNumber string) error {
|
||||
return g.checkpoint.Set(streamName, shardID, sequenceNumber)
|
||||
}
|
||||
|
||||
// findNewShards pulls the list of shards from the Kinesis API
|
||||
// and uses a local cache to determine if we are already processing
|
||||
// a particular shard.
|
||||
func (g *AllGroup) findNewShards(shardc chan *kinesis.Shard) {
|
||||
g.shardMu.Lock()
|
||||
defer g.shardMu.Unlock()
|
||||
|
||||
g.logger.Log("[GROUP]", "fetching shards")
|
||||
|
||||
shards, err := listShards(g.ksis, g.streamName)
|
||||
if err != nil {
|
||||
g.logger.Log("[GROUP] error:", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, shard := range shards {
|
||||
if _, ok := g.shards[*shard.ShardId]; ok {
|
||||
continue
|
||||
}
|
||||
g.shards[*shard.ShardId] = shard
|
||||
shardc <- shard
|
||||
}
|
||||
}
|
||||
114
broker.go
114
broker.go
|
|
@ -1,114 +0,0 @@
|
|||
package consumer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/service/kinesis"
|
||||
"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
|
||||
)
|
||||
|
||||
func newBroker(
|
||||
client kinesisiface.KinesisAPI,
|
||||
streamName string,
|
||||
shardc chan *kinesis.Shard,
|
||||
logger Logger,
|
||||
) *broker {
|
||||
return &broker{
|
||||
client: client,
|
||||
shards: make(map[string]*kinesis.Shard),
|
||||
streamName: streamName,
|
||||
shardc: shardc,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// broker caches a local list of the shards we are already processing
|
||||
// and routinely polls the stream looking for new shards to process
|
||||
type broker struct {
|
||||
client kinesisiface.KinesisAPI
|
||||
streamName string
|
||||
shardc chan *kinesis.Shard
|
||||
logger Logger
|
||||
|
||||
shardMu sync.Mutex
|
||||
shards map[string]*kinesis.Shard
|
||||
}
|
||||
|
||||
// start is a blocking operation which will loop and attempt to find new
|
||||
// shards on a regular cadence.
|
||||
func (b *broker) start(ctx context.Context) {
|
||||
b.findNewShards()
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
|
||||
// Note: while ticker is a rather naive approach to this problem,
|
||||
// it actually simplies a few things. i.e. If we miss a new shard while
|
||||
// AWS is resharding we'll pick it up max 30 seconds later.
|
||||
|
||||
// It might be worth refactoring this flow to allow the consumer to
|
||||
// to notify the broker when a shard is closed. However, shards don't
|
||||
// necessarily close at the same time, so we could potentially get a
|
||||
// thundering heard of notifications from the consumer.
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
ticker.Stop()
|
||||
return
|
||||
case <-ticker.C:
|
||||
b.findNewShards()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// findNewShards pulls the list of shards from the Kinesis API
|
||||
// and uses a local cache to determine if we are already processing
|
||||
// a particular shard.
|
||||
func (b *broker) findNewShards() {
|
||||
b.shardMu.Lock()
|
||||
defer b.shardMu.Unlock()
|
||||
|
||||
b.logger.Log("[BROKER]", "fetching shards")
|
||||
|
||||
shards, err := b.listShards()
|
||||
if err != nil {
|
||||
b.logger.Log("[BROKER]", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, shard := range shards {
|
||||
if _, ok := b.shards[*shard.ShardId]; ok {
|
||||
continue
|
||||
}
|
||||
b.shards[*shard.ShardId] = shard
|
||||
b.shardc <- shard
|
||||
}
|
||||
}
|
||||
|
||||
// listShards pulls a list of shard IDs from the kinesis api
|
||||
func (b *broker) listShards() ([]*kinesis.Shard, error) {
|
||||
var ss []*kinesis.Shard
|
||||
var listShardsInput = &kinesis.ListShardsInput{
|
||||
StreamName: aws.String(b.streamName),
|
||||
}
|
||||
|
||||
for {
|
||||
resp, err := b.client.ListShards(listShardsInput)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ListShards error: %v", err)
|
||||
}
|
||||
ss = append(ss, resp.Shards...)
|
||||
|
||||
if resp.NextToken == nil {
|
||||
return ss, nil
|
||||
}
|
||||
|
||||
listShardsInput = &kinesis.ListShardsInput{
|
||||
NextToken: resp.NextToken,
|
||||
StreamName: aws.String(b.streamName),
|
||||
}
|
||||
}
|
||||
}
|
||||
32
consumer.go
32
consumer.go
|
|
@ -27,8 +27,8 @@ func New(streamName string, opts ...Option) (*Consumer, error) {
|
|||
c := &Consumer{
|
||||
streamName: streamName,
|
||||
initialShardIteratorType: kinesis.ShardIteratorTypeLatest,
|
||||
checkpoint: &noopCheckpoint{},
|
||||
counter: &noopCounter{},
|
||||
checkpoint: &noopCheckpoint{},
|
||||
logger: &noopLogger{
|
||||
logger: log.New(ioutil.Discard, "", log.LstdFlags),
|
||||
},
|
||||
|
|
@ -48,6 +48,11 @@ func New(streamName string, opts ...Option) (*Consumer, error) {
|
|||
c.client = kinesis.New(newSession)
|
||||
}
|
||||
|
||||
// default group if none provided
|
||||
if c.group == nil {
|
||||
c.group = NewAllGroup(c.client, c.checkpoint, c.streamName, c.logger)
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
|
|
@ -57,6 +62,7 @@ type Consumer struct {
|
|||
initialShardIteratorType string
|
||||
client kinesisiface.KinesisAPI
|
||||
logger Logger
|
||||
group Group
|
||||
checkpoint Checkpoint
|
||||
counter Counter
|
||||
}
|
||||
|
|
@ -64,7 +70,6 @@ type Consumer struct {
|
|||
// ScanFunc is the type of the function called for each message read
|
||||
// from the stream. The record argument contains the original record
|
||||
// returned from the AWS Kinesis library.
|
||||
//
|
||||
// If an error is returned, scanning stops. The sole exception is when the
|
||||
// function returns the special value SkipCheckpoint.
|
||||
type ScanFunc func(*Record) error
|
||||
|
|
@ -78,18 +83,16 @@ var SkipCheckpoint = errors.New("skip checkpoint")
|
|||
// is passed through to each of the goroutines and called with each message pulled from
|
||||
// the stream.
|
||||
func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
|
||||
var (
|
||||
errc = make(chan error, 1)
|
||||
shardc = make(chan *kinesis.Shard, 1)
|
||||
broker = newBroker(c.client, c.streamName, shardc, c.logger)
|
||||
)
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
go broker.start(ctx)
|
||||
var (
|
||||
errc = make(chan error, 1)
|
||||
shardc = make(chan *kinesis.Shard, 1)
|
||||
)
|
||||
|
||||
go func() {
|
||||
c.group.Start(ctx, shardc)
|
||||
<-ctx.Done()
|
||||
close(shardc)
|
||||
}()
|
||||
|
|
@ -110,7 +113,6 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
|
|||
}
|
||||
|
||||
close(errc)
|
||||
|
||||
return <-errc
|
||||
}
|
||||
|
||||
|
|
@ -118,7 +120,7 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
|
|||
// for each record and checkpoints the progress of scan.
|
||||
func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error {
|
||||
// get last seq number from checkpoint
|
||||
lastSeqNum, err := c.checkpoint.Get(c.streamName, shardID)
|
||||
lastSeqNum, err := c.group.GetCheckpoint(c.streamName, shardID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get checkpoint error: %v", err)
|
||||
}
|
||||
|
|
@ -129,9 +131,9 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
|
|||
return fmt.Errorf("get shard iterator error: %v", err)
|
||||
}
|
||||
|
||||
c.logger.Log("[START]\t", shardID, lastSeqNum)
|
||||
c.logger.Log("[CONSUMER] start scan:", shardID, lastSeqNum)
|
||||
defer func() {
|
||||
c.logger.Log("[STOP]\t", shardID)
|
||||
c.logger.Log("[CONSUMER] stop scan:", shardID)
|
||||
}()
|
||||
|
||||
for {
|
||||
|
|
@ -164,7 +166,7 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
|
|||
}
|
||||
|
||||
if err != SkipCheckpoint {
|
||||
if err := c.checkpoint.Set(c.streamName, shardID, *r.SequenceNumber); err != nil {
|
||||
if err := c.group.SetCheckpoint(c.streamName, shardID, *r.SequenceNumber); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
|
@ -175,7 +177,7 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
|
|||
}
|
||||
|
||||
if isShardClosed(resp.NextShardIterator, shardIterator) {
|
||||
c.logger.Log("[CLOSED]\t", shardID)
|
||||
c.logger.Log("[CONSUMER] shard closed:", shardID)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ Export the required environment vars for connecting to the Kinesis stream:
|
|||
|
||||
```
|
||||
export AWS_PROFILE=
|
||||
export AWS_REGION_NAME=
|
||||
export AWS_REGION=
|
||||
```
|
||||
|
||||
### Running the code
|
||||
|
|
|
|||
14
group.go
Normal file
14
group.go
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
package consumer
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/aws/aws-sdk-go/service/kinesis"
|
||||
)
|
||||
|
||||
// Group interface used to manage which shard to process
|
||||
type Group interface {
|
||||
Start(ctx context.Context, shardc chan *kinesis.Shard)
|
||||
GetCheckpoint(streamName, shardID string) (string, error)
|
||||
SetCheckpoint(streamName, shardID, sequenceNumber string) error
|
||||
}
|
||||
34
kinesis.go
Normal file
34
kinesis.go
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
package consumer
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/service/kinesis"
|
||||
"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
|
||||
)
|
||||
|
||||
// listShards pulls a list of shard IDs from the kinesis api
|
||||
func listShards(ksis kinesisiface.KinesisAPI, streamName string) ([]*kinesis.Shard, error) {
|
||||
var ss []*kinesis.Shard
|
||||
var listShardsInput = &kinesis.ListShardsInput{
|
||||
StreamName: aws.String(streamName),
|
||||
}
|
||||
|
||||
for {
|
||||
resp, err := ksis.ListShards(listShardsInput)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ListShards error: %v", err)
|
||||
}
|
||||
ss = append(ss, resp.Shards...)
|
||||
|
||||
if resp.NextToken == nil {
|
||||
return ss, nil
|
||||
}
|
||||
|
||||
listShardsInput = &kinesis.ListShardsInput{
|
||||
NextToken: resp.NextToken,
|
||||
StreamName: aws.String(streamName),
|
||||
}
|
||||
}
|
||||
}
|
||||
Loading…
Reference in a new issue