Use shard broker to monitor and process new shards (#85)
* Use shard broker to start processing new shards The addition of a shard broker will allow the consumer to be notified when new shards are added to the stream so it can consume them. Fixes: https://github.com/harlow/kinesis-consumer/issues/36
This commit is contained in:
parent
c4f363a517
commit
97fe4e66ff
7 changed files with 210 additions and 106 deletions
|
|
@ -20,7 +20,9 @@ Get the package source:
|
|||
|
||||
The consumer leverages a handler func that accepts a Kinesis record. The `Scan` method will consume all shards concurrently and call the callback func as it receives records from the stream.
|
||||
|
||||
_Important: The default Log, Counter, and Checkpoint are no-op which means no logs, counts, or checkpoints will be emitted when scanning the stream. See the options below to override these defaults._
|
||||
_Important 1: The `Scan` func will also poll the stream to check for new shards; it will automatically start consuming new shards added to the stream._
|
||||
|
||||
_Important 2: The default Log, Counter, and Checkpoint are no-op which means no logs, counts, or checkpoints will be emitted when scanning the stream. See the options below to override these defaults._
|
||||
|
||||
```go
|
||||
import(
|
||||
|
|
@ -255,8 +257,8 @@ The package defaults to `ioutil.Discard` to swallow all logs. This can be custom
|
|||
|
||||
```go
|
||||
// logger
|
||||
log := &myLogger{
|
||||
logger: log.New(os.Stdout, "consumer-example: ", log.LstdFlags)
|
||||
logger := &myLogger{
|
||||
logger: log.New(os.Stdout, "consumer-example: ", log.LstdFlags),
|
||||
}
|
||||
|
||||
// consumer
|
||||
|
|
|
|||
114
broker.go
Normal file
114
broker.go
Normal file
|
|
@ -0,0 +1,114 @@
|
|||
package consumer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/service/kinesis"
|
||||
"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
|
||||
)
|
||||
|
||||
func newBroker(
|
||||
client kinesisiface.KinesisAPI,
|
||||
streamName string,
|
||||
shardc chan *kinesis.Shard,
|
||||
logger Logger,
|
||||
) *broker {
|
||||
return &broker{
|
||||
client: client,
|
||||
shards: make(map[string]*kinesis.Shard),
|
||||
streamName: streamName,
|
||||
shardc: shardc,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// broker caches a local list of the shards we are already processing
|
||||
// and routinely polls the stream looking for new shards to process
|
||||
type broker struct {
|
||||
client kinesisiface.KinesisAPI
|
||||
streamName string
|
||||
shardc chan *kinesis.Shard
|
||||
logger Logger
|
||||
|
||||
shardMu sync.Mutex
|
||||
shards map[string]*kinesis.Shard
|
||||
}
|
||||
|
||||
// start is a blocking operation which will loop and attempt to find new
|
||||
// shards on a regular cadence.
|
||||
func (b *broker) start(ctx context.Context) {
|
||||
b.findNewShards()
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
|
||||
// Note: while ticker is a rather naive approach to this problem,
|
||||
// it actually simplies a few things. i.e. If we miss a new shard while
|
||||
// AWS is resharding we'll pick it up max 30 seconds later.
|
||||
|
||||
// It might be worth refactoring this flow to allow the consumer to
|
||||
// to notify the broker when a shard is closed. However, shards don't
|
||||
// necessarily close at the same time, so we could potentially get a
|
||||
// thundering heard of notifications from the consumer.
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
ticker.Stop()
|
||||
return
|
||||
case <-ticker.C:
|
||||
b.findNewShards()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// findNewShards pulls the list of shards from the Kinesis API
|
||||
// and uses a local cache to determine if we are already processing
|
||||
// a particular shard.
|
||||
func (b *broker) findNewShards() {
|
||||
b.shardMu.Lock()
|
||||
defer b.shardMu.Unlock()
|
||||
|
||||
b.logger.Log("[BROKER]", "fetching shards")
|
||||
|
||||
shards, err := b.listShards()
|
||||
if err != nil {
|
||||
b.logger.Log("[BROKER]", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, shard := range shards {
|
||||
if _, ok := b.shards[*shard.ShardId]; ok {
|
||||
continue
|
||||
}
|
||||
b.shards[*shard.ShardId] = shard
|
||||
b.shardc <- shard
|
||||
}
|
||||
}
|
||||
|
||||
// listShards pulls a list of shard IDs from the kinesis api
|
||||
func (b *broker) listShards() ([]*kinesis.Shard, error) {
|
||||
var ss []*kinesis.Shard
|
||||
var listShardsInput = &kinesis.ListShardsInput{
|
||||
StreamName: aws.String(b.streamName),
|
||||
}
|
||||
|
||||
for {
|
||||
resp, err := b.client.ListShards(listShardsInput)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ListShards error: %v", err)
|
||||
}
|
||||
ss = append(ss, resp.Shards...)
|
||||
|
||||
if resp.NextToken == nil {
|
||||
return ss, nil
|
||||
}
|
||||
|
||||
listShardsInput = &kinesis.ListShardsInput{
|
||||
NextToken: resp.NextToken,
|
||||
StreamName: aws.String(b.streamName),
|
||||
}
|
||||
}
|
||||
}
|
||||
79
consumer.go
79
consumer.go
|
|
@ -6,7 +6,6 @@ import (
|
|||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"sync"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/session"
|
||||
|
|
@ -70,7 +69,7 @@ type Consumer struct {
|
|||
// function returns the special value SkipCheckpoint.
|
||||
type ScanFunc func(*Record) error
|
||||
|
||||
// SkipCheckpoint is used as a return value from ScanFuncs to indicate that
|
||||
// SkipCheckpoint is used as a return value from ScanFunc to indicate that
|
||||
// the current checkpoint should be skipped. It is not returned
|
||||
// as an error by any function.
|
||||
var SkipCheckpoint = errors.New("skip checkpoint")
|
||||
|
|
@ -79,51 +78,44 @@ var SkipCheckpoint = errors.New("skip checkpoint")
|
|||
// is passed through to each of the goroutines and called with each message pulled from
|
||||
// the stream.
|
||||
func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
|
||||
var (
|
||||
errc = make(chan error, 1)
|
||||
shardc = make(chan *kinesis.Shard, 1)
|
||||
broker = newBroker(c.client, c.streamName, shardc, c.logger)
|
||||
)
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
// get shard ids
|
||||
shardIDs, err := c.getShardIDs(c.streamName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get shards error: %v", err)
|
||||
}
|
||||
go broker.start(ctx)
|
||||
|
||||
if len(shardIDs) == 0 {
|
||||
return fmt.Errorf("no shards available")
|
||||
}
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
close(shardc)
|
||||
}()
|
||||
|
||||
var (
|
||||
wg sync.WaitGroup
|
||||
errc = make(chan error, 1)
|
||||
)
|
||||
wg.Add(len(shardIDs))
|
||||
|
||||
// process each shard in a separate goroutine
|
||||
for _, shardID := range shardIDs {
|
||||
// process each of the shards
|
||||
for shard := range shardc {
|
||||
go func(shardID string) {
|
||||
defer wg.Done()
|
||||
|
||||
if err := c.ScanShard(ctx, shardID, fn); err != nil {
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case errc <- fmt.Errorf("shard %s error: %v", shardID, err):
|
||||
// first error to occur
|
||||
cancel()
|
||||
default:
|
||||
// error has already occurred
|
||||
}
|
||||
}
|
||||
}(shardID)
|
||||
}(aws.StringValue(shard.ShardId))
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errc)
|
||||
|
||||
return <-errc
|
||||
}
|
||||
|
||||
// ScanShard loops over records on a specific shard, calls the ScanFunc callback
|
||||
// func for each record and checkpoints the progress of scan.
|
||||
// ScanShard loops over records on a specific shard, calls the callback func
|
||||
// for each record and checkpoints the progress of scan.
|
||||
func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error {
|
||||
// get last seq number from checkpoint
|
||||
lastSeqNum, err := c.checkpoint.Get(c.streamName, shardID)
|
||||
|
|
@ -137,7 +129,10 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
|
|||
return fmt.Errorf("get shard iterator error: %v", err)
|
||||
}
|
||||
|
||||
c.logger.Log("scanning", shardID, lastSeqNum)
|
||||
c.logger.Log("[START]\t", shardID, lastSeqNum)
|
||||
defer func() {
|
||||
c.logger.Log("[STOP]\t", shardID)
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
|
|
@ -148,8 +143,7 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
|
|||
ShardIterator: shardIterator,
|
||||
})
|
||||
|
||||
// often we can recover from GetRecords error by getting a
|
||||
// new shard iterator, else return error
|
||||
// attempt to recover from GetRecords error by getting new shard iterator
|
||||
if err != nil {
|
||||
shardIterator, err = c.getShardIterator(c.streamName, shardID, lastSeqNum)
|
||||
if err != nil {
|
||||
|
|
@ -181,6 +175,7 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
|
|||
}
|
||||
|
||||
if isShardClosed(resp.NextShardIterator, shardIterator) {
|
||||
c.logger.Log("[CLOSED]\t", shardID)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -193,32 +188,6 @@ func isShardClosed(nextShardIterator, currentShardIterator *string) bool {
|
|||
return nextShardIterator == nil || currentShardIterator == nextShardIterator
|
||||
}
|
||||
|
||||
func (c *Consumer) getShardIDs(streamName string) ([]string, error) {
|
||||
var ss []string
|
||||
var listShardsInput = &kinesis.ListShardsInput{
|
||||
StreamName: aws.String(streamName),
|
||||
}
|
||||
|
||||
for {
|
||||
resp, err := c.client.ListShards(listShardsInput)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ListShards error: %v", err)
|
||||
}
|
||||
|
||||
for _, shard := range resp.Shards {
|
||||
ss = append(ss, *shard.ShardId)
|
||||
}
|
||||
|
||||
if resp.NextToken == nil {
|
||||
return ss, nil
|
||||
}
|
||||
|
||||
listShardsInput = &kinesis.ListShardsInput{
|
||||
NextToken: resp.NextToken,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Consumer) getShardIterator(streamName, shardID, seqNum string) (*string, error) {
|
||||
params := &kinesis.GetShardIteratorInput{
|
||||
ShardId: aws.String(shardID),
|
||||
|
|
|
|||
|
|
@ -63,27 +63,30 @@ func TestScan(t *testing.T) {
|
|||
t.Fatalf("new consumer error: %v", err)
|
||||
}
|
||||
|
||||
var resultData string
|
||||
var fnCallCounter int
|
||||
var (
|
||||
ctx, cancel = context.WithCancel(context.Background())
|
||||
res string
|
||||
)
|
||||
|
||||
var fn = func(r *Record) error {
|
||||
fnCallCounter++
|
||||
resultData += string(r.Data)
|
||||
res += string(r.Data)
|
||||
|
||||
if string(r.Data) == "lastData" {
|
||||
cancel()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := c.Scan(context.Background(), fn); err != nil {
|
||||
t.Errorf("scan shard error expected nil. got %v", err)
|
||||
if err := c.Scan(ctx, fn); err != nil {
|
||||
t.Errorf("scan returned unexpected error %v", err)
|
||||
}
|
||||
|
||||
if resultData != "firstDatalastData" {
|
||||
t.Errorf("callback error expected %s, got %s", "FirstLast", resultData)
|
||||
if res != "firstDatalastData" {
|
||||
t.Errorf("callback error expected %s, got %s", "firstDatalastData", res)
|
||||
}
|
||||
|
||||
if fnCallCounter != 2 {
|
||||
t.Errorf("the callback function expects %v, got %v", 2, fnCallCounter)
|
||||
}
|
||||
|
||||
if val := ctr.counter; val != 2 {
|
||||
if val := ctr.Get(); val != 2 {
|
||||
t.Errorf("counter error expected %d, got %d", 2, val)
|
||||
}
|
||||
|
||||
|
|
@ -93,29 +96,6 @@ func TestScan(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestScan_NoShardsAvailable(t *testing.T) {
|
||||
client := &kinesisClientMock{
|
||||
listShardsMock: func(input *kinesis.ListShardsInput) (*kinesis.ListShardsOutput, error) {
|
||||
return &kinesis.ListShardsOutput{
|
||||
Shards: make([]*kinesis.Shard, 0),
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
var fn = func(r *Record) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
c, err := New("myStreamName", WithClient(client))
|
||||
if err != nil {
|
||||
t.Fatalf("new consumer error: %v", err)
|
||||
}
|
||||
|
||||
if err := c.Scan(context.Background(), fn); err == nil {
|
||||
t.Errorf("scan shard error expected not nil. got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScanShard(t *testing.T) {
|
||||
var client = &kinesisClientMock{
|
||||
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
|
||||
|
|
@ -146,15 +126,23 @@ func TestScanShard(t *testing.T) {
|
|||
}
|
||||
|
||||
// callback fn appends record data
|
||||
var res string
|
||||
var (
|
||||
ctx, cancel = context.WithCancel(context.Background())
|
||||
res string
|
||||
)
|
||||
|
||||
var fn = func(r *Record) error {
|
||||
res += string(r.Data)
|
||||
|
||||
if string(r.Data) == "lastData" {
|
||||
cancel()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// scan shard
|
||||
if err := c.ScanShard(context.Background(), "myShard", fn); err != nil {
|
||||
t.Fatalf("scan shard error: %v", err)
|
||||
if err := c.ScanShard(ctx, "myShard", fn); err != nil {
|
||||
t.Errorf("scan returned unexpected error %v", err)
|
||||
}
|
||||
|
||||
// runs callback func
|
||||
|
|
@ -163,7 +151,7 @@ func TestScanShard(t *testing.T) {
|
|||
}
|
||||
|
||||
// increments counter
|
||||
if val := ctr.counter; val != 2 {
|
||||
if val := ctr.Get(); val != 2 {
|
||||
t.Fatalf("counter error expected %d, got %d", 2, val)
|
||||
}
|
||||
|
||||
|
|
@ -236,14 +224,18 @@ func TestScanShard_SkipCheckpoint(t *testing.T) {
|
|||
t.Fatalf("new consumer error: %v", err)
|
||||
}
|
||||
|
||||
var ctx, cancel = context.WithCancel(context.Background())
|
||||
|
||||
var fn = func(r *Record) error {
|
||||
if aws.StringValue(r.SequenceNumber) == "lastSeqNum" {
|
||||
cancel()
|
||||
return SkipCheckpoint
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
err = c.ScanShard(context.Background(), "myShard", fn)
|
||||
err = c.ScanShard(ctx, "myShard", fn)
|
||||
if err != nil {
|
||||
t.Fatalf("scan shard error: %v", err)
|
||||
}
|
||||
|
|
@ -329,8 +321,19 @@ func (fc *fakeCheckpoint) Get(streamName, shardID string) (string, error) {
|
|||
// fakeCounter is a counter implementation that keeps a running total in
// memory. Reads and updates of the total are serialized by mu.
type fakeCounter struct {
	counter int64      // running total of everything passed to Add
	mu      sync.Mutex // guards counter
}

// Get returns the current total.
func (fc *fakeCounter) Get() int64 {
	fc.mu.Lock()
	total := fc.counter
	fc.mu.Unlock()

	return total
}

// Add increases the total by count. The streamName argument is not used.
func (fc *fakeCounter) Add(streamName string, count int64) {
	fc.mu.Lock()
	fc.counter += count
	fc.mu.Unlock()
}
|
||||
|
|
|
|||
|
|
@ -7,12 +7,11 @@ Read records from the Kinesis stream
|
|||
Export the required environment vars for connecting to the Kinesis stream and Redis for checkpoint:
|
||||
|
||||
```
|
||||
export AWS_ACCESS_KEY=
|
||||
export AWS_PROFILE=
|
||||
export AWS_REGION=
|
||||
export AWS_SECRET_KEY=
|
||||
export REDIS_URL=
|
||||
```
|
||||
|
||||
### Run the consumer
|
||||
|
||||
$ go run main.go --app appName --stream streamName
|
||||
$ go run main.go --app appName --stream streamName
|
||||
|
|
|
|||
|
|
@ -12,6 +12,16 @@ import (
|
|||
checkpoint "github.com/harlow/kinesis-consumer/checkpoint/redis"
|
||||
)
|
||||
|
||||
// myLogger is a minimal adapter that lets a stdlib *log.Logger satisfy the
// Logger interface.
type myLogger struct {
	logger *log.Logger // destination for everything passed to Log
}

// Log hands args to the wrapped stdlib logger's Println, so operands are
// separated by spaces and the entry ends with a newline.
func (l *myLogger) Log(args ...interface{}) {
	std := l.logger
	std.Println(args...)
}
|
||||
|
||||
func main() {
|
||||
var (
|
||||
app = flag.String("app", "", "Consumer app name")
|
||||
|
|
@ -25,9 +35,16 @@ func main() {
|
|||
log.Fatalf("checkpoint error: %v", err)
|
||||
}
|
||||
|
||||
// logger
|
||||
logger := &myLogger{
|
||||
logger: log.New(os.Stdout, "consumer-example: ", log.LstdFlags),
|
||||
}
|
||||
|
||||
// consumer
|
||||
c, err := consumer.New(
|
||||
*stream, consumer.WithCheckpoint(ck),
|
||||
*stream,
|
||||
consumer.WithCheckpoint(ck),
|
||||
consumer.WithLogger(logger),
|
||||
)
|
||||
if err != nil {
|
||||
log.Fatalf("consumer error: %v", err)
|
||||
|
|
@ -42,6 +59,7 @@ func main() {
|
|||
|
||||
go func() {
|
||||
<-signals
|
||||
fmt.Println("caught exit signal, cancelling context!")
|
||||
cancel()
|
||||
}()
|
||||
|
||||
|
|
|
|||
|
|
@ -7,9 +7,8 @@ A prepopulated file with JSON users is available on S3 for seeding the stream.
|
|||
Export the required environment vars for connecting to the Kinesis stream:
|
||||
|
||||
```
|
||||
export AWS_ACCESS_KEY=
|
||||
export AWS_PROFILE=
|
||||
export AWS_REGION_NAME=
|
||||
export AWS_SECRET_KEY=
|
||||
```
|
||||
|
||||
### Running the code
|
||||
|
|
|
|||
Loading…
Reference in a new issue