Use shard broker to start processing new shards

The addition of a shard broker will allow the consumer to be notified
when new shards are added to the stream so it can consume them.

Fixes: https://github.com/harlow/kinesis-consumer/issues/36
This commit is contained in:
Harlow Ward 2019-01-03 22:46:13 -08:00
parent c4f363a517
commit 7e72723168
5 changed files with 206 additions and 101 deletions

View file

@ -255,8 +255,8 @@ The package defaults to `ioutil.Discard` so swallow all logs. This can be custom
```go ```go
// logger // logger
log := &myLogger{ logger := &myLogger{
logger: log.New(os.Stdout, "consumer-example: ", log.LstdFlags) logger: log.New(os.Stdout, "consumer-example: ", log.LstdFlags),
} }
// consumer // consumer

102
broker.go Normal file
View file

@ -0,0 +1,102 @@
package consumer
import (
"context"
"fmt"
"sync"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/kinesis"
"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
)
func newBroker(
client kinesisiface.KinesisAPI,
streamName string,
shardc chan *kinesis.Shard,
) *broker {
return &broker{
client: client,
shards: make(map[string]*kinesis.Shard),
streamName: streamName,
shardc: shardc,
}
}
type broker struct {
client kinesisiface.KinesisAPI
streamName string
shardc chan *kinesis.Shard
shardMu sync.Mutex
shards map[string]*kinesis.Shard
}
func (b *broker) shardLoop(ctx context.Context) {
b.fetchShards()
// add ticker, and cancellation
// also add signal to re-pull?
go func() {
for {
select {
case <-ctx.Done():
return
case <-time.After(30 * time.Second):
b.fetchShards()
}
}
}()
}
func (b *broker) fetchShards() {
shards, err := b.listShards()
if err != nil {
fmt.Println(err)
return
}
for _, shard := range shards {
if b.takeLease(shard) {
b.shardc <- shard
}
}
}
func (b *broker) listShards() ([]*kinesis.Shard, error) {
var ss []*kinesis.Shard
var listShardsInput = &kinesis.ListShardsInput{
StreamName: aws.String(b.streamName),
}
for {
resp, err := b.client.ListShards(listShardsInput)
if err != nil {
return nil, fmt.Errorf("ListShards error: %v", err)
}
ss = append(ss, resp.Shards...)
if resp.NextToken == nil {
return ss, nil
}
listShardsInput = &kinesis.ListShardsInput{
NextToken: resp.NextToken,
StreamName: aws.String(b.streamName),
}
}
}
func (b *broker) takeLease(shard *kinesis.Shard) bool {
b.shardMu.Lock()
defer b.shardMu.Unlock()
if _, ok := b.shards[*shard.ShardId]; ok {
return false
}
b.shards[*shard.ShardId] = shard
return true
}

View file

@ -6,7 +6,6 @@ import (
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"log" "log"
"sync"
"github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/aws/session"
@ -70,7 +69,7 @@ type Consumer struct {
// function returns the special value SkipCheckpoint. // function returns the special value SkipCheckpoint.
type ScanFunc func(*Record) error type ScanFunc func(*Record) error
// SkipCheckpoint is used as a return value from ScanFuncs to indicate that // SkipCheckpoint is used as a return value from ScanFunc to indicate that
// the current checkpoint should be skipped skipped. It is not returned // the current checkpoint should be skipped skipped. It is not returned
// as an error by any function. // as an error by any function.
var SkipCheckpoint = errors.New("skip checkpoint") var SkipCheckpoint = errors.New("skip checkpoint")
@ -79,51 +78,45 @@ var SkipCheckpoint = errors.New("skip checkpoint")
// is passed through to each of the goroutines and called with each message pulled from // is passed through to each of the goroutines and called with each message pulled from
// the stream. // the stream.
func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error { func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
var (
errc = make(chan error, 1)
shardc = make(chan *kinesis.Shard, 1)
broker = newBroker(c.client, c.streamName, shardc)
)
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
// get shard ids go func() {
shardIDs, err := c.getShardIDs(c.streamName) broker.shardLoop(ctx)
if err != nil {
return fmt.Errorf("get shards error: %v", err)
}
if len(shardIDs) == 0 { <-ctx.Done()
return fmt.Errorf("no shards available") close(shardc)
} }()
var ( // process each of the shards
wg sync.WaitGroup for shard := range shardc {
errc = make(chan error, 1)
)
wg.Add(len(shardIDs))
// process each shard in a separate goroutine
for _, shardID := range shardIDs {
go func(shardID string) { go func(shardID string) {
defer wg.Done()
if err := c.ScanShard(ctx, shardID, fn); err != nil { if err := c.ScanShard(ctx, shardID, fn); err != nil {
cancel()
select { select {
case errc <- fmt.Errorf("shard %s error: %v", shardID, err): case errc <- fmt.Errorf("shard %s error: %v", shardID, err):
// first error to occur // first error to occur
cancel()
default: default:
// error has already occured // error has already occured
} }
return
} }
}(shardID) }(aws.StringValue(shard.ShardId))
} }
wg.Wait()
close(errc) close(errc)
return <-errc return <-errc
} }
// ScanShard loops over records on a specific shard, calls the ScanFunc callback // ScanShard loops over records on a specific shard, calls the callback func
// func for each record and checkpoints the progress of scan. // for each record and checkpoints the progress of scan.
func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error { func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error {
// get last seq number from checkpoint // get last seq number from checkpoint
lastSeqNum, err := c.checkpoint.Get(c.streamName, shardID) lastSeqNum, err := c.checkpoint.Get(c.streamName, shardID)
@ -137,7 +130,10 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
return fmt.Errorf("get shard iterator error: %v", err) return fmt.Errorf("get shard iterator error: %v", err)
} }
c.logger.Log("scanning", shardID, lastSeqNum) c.logger.Log("[START]\t", shardID, lastSeqNum)
defer func() {
c.logger.Log("[STOP]\t", shardID)
}()
for { for {
select { select {
@ -148,8 +144,7 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
ShardIterator: shardIterator, ShardIterator: shardIterator,
}) })
// often we can recover from GetRecords error by getting a // attempt to recover from GetRecords error by getting new shard iterator
// new shard iterator, else return error
if err != nil { if err != nil {
shardIterator, err = c.getShardIterator(c.streamName, shardID, lastSeqNum) shardIterator, err = c.getShardIterator(c.streamName, shardID, lastSeqNum)
if err != nil { if err != nil {
@ -181,6 +176,7 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
} }
if isShardClosed(resp.NextShardIterator, shardIterator) { if isShardClosed(resp.NextShardIterator, shardIterator) {
c.logger.Log("[CLOSED]\t", shardID)
return nil return nil
} }
@ -193,32 +189,6 @@ func isShardClosed(nextShardIterator, currentShardIterator *string) bool {
return nextShardIterator == nil || currentShardIterator == nextShardIterator return nextShardIterator == nil || currentShardIterator == nextShardIterator
} }
func (c *Consumer) getShardIDs(streamName string) ([]string, error) {
var ss []string
var listShardsInput = &kinesis.ListShardsInput{
StreamName: aws.String(streamName),
}
for {
resp, err := c.client.ListShards(listShardsInput)
if err != nil {
return nil, fmt.Errorf("ListShards error: %v", err)
}
for _, shard := range resp.Shards {
ss = append(ss, *shard.ShardId)
}
if resp.NextToken == nil {
return ss, nil
}
listShardsInput = &kinesis.ListShardsInput{
NextToken: resp.NextToken,
}
}
}
func (c *Consumer) getShardIterator(streamName, shardID, seqNum string) (*string, error) { func (c *Consumer) getShardIterator(streamName, shardID, seqNum string) (*string, error) {
params := &kinesis.GetShardIteratorInput{ params := &kinesis.GetShardIteratorInput{
ShardId: aws.String(shardID), ShardId: aws.String(shardID),

View file

@ -63,24 +63,27 @@ func TestScan(t *testing.T) {
t.Fatalf("new consumer error: %v", err) t.Fatalf("new consumer error: %v", err)
} }
var resultData string var (
var fnCallCounter int ctx, cancel = context.WithCancel(context.Background())
res string
)
var fn = func(r *Record) error { var fn = func(r *Record) error {
fnCallCounter++ res += string(r.Data)
resultData += string(r.Data)
if string(r.Data) == "lastData" {
cancel()
}
return nil return nil
} }
if err := c.Scan(context.Background(), fn); err != nil { if err := c.Scan(ctx, fn); err != nil {
t.Errorf("scan shard error expected nil. got %v", err) t.Errorf("scan returned unexpected error %v", err)
} }
if resultData != "firstDatalastData" { if res != "firstDatalastData" {
t.Errorf("callback error expected %s, got %s", "FirstLast", resultData) t.Errorf("callback error expected %s, got %s", "firstDatalastData", res)
}
if fnCallCounter != 2 {
t.Errorf("the callback function expects %v, got %v", 2, fnCallCounter)
} }
if val := ctr.counter; val != 2 { if val := ctr.counter; val != 2 {
@ -146,15 +149,23 @@ func TestScanShard(t *testing.T) {
} }
// callback fn appends record data // callback fn appends record data
var res string var (
ctx, cancel = context.WithCancel(context.Background())
res string
)
var fn = func(r *Record) error { var fn = func(r *Record) error {
res += string(r.Data) res += string(r.Data)
if string(r.Data) == "lastData" {
cancel()
}
return nil return nil
} }
// scan shard if err := c.Scan(ctx, fn); err != nil {
if err := c.ScanShard(context.Background(), "myShard", fn); err != nil { t.Errorf("scan returned unexpected error %v", err)
t.Fatalf("scan shard error: %v", err)
} }
// runs callback func // runs callback func
@ -236,14 +247,18 @@ func TestScanShard_SkipCheckpoint(t *testing.T) {
t.Fatalf("new consumer error: %v", err) t.Fatalf("new consumer error: %v", err)
} }
var ctx, cancel = context.WithCancel(context.Background())
var fn = func(r *Record) error { var fn = func(r *Record) error {
if aws.StringValue(r.SequenceNumber) == "lastSeqNum" { if aws.StringValue(r.SequenceNumber) == "lastSeqNum" {
cancel()
return SkipCheckpoint return SkipCheckpoint
} }
return nil return nil
} }
err = c.ScanShard(context.Background(), "myShard", fn) err = c.ScanShard(ctx, "myShard", fn)
if err != nil { if err != nil {
t.Fatalf("scan shard error: %v", err) t.Fatalf("scan shard error: %v", err)
} }
@ -254,35 +269,35 @@ func TestScanShard_SkipCheckpoint(t *testing.T) {
} }
} }
func TestScanShard_ShardIsClosed(t *testing.T) { // func TestScanShard_ShardIsClosed(t *testing.T) {
var client = &kinesisClientMock{ // var client = &kinesisClientMock{
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) { // getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
return &kinesis.GetShardIteratorOutput{ // return &kinesis.GetShardIteratorOutput{
ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"), // ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
}, nil // }, nil
}, // },
getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) { // getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
return &kinesis.GetRecordsOutput{ // return &kinesis.GetRecordsOutput{
NextShardIterator: nil, // NextShardIterator: nil,
Records: make([]*Record, 0), // Records: make([]*Record, 0),
}, nil // }, nil
}, // },
} // }
c, err := New("myStreamName", WithClient(client)) // c, err := New("myStreamName", WithClient(client))
if err != nil { // if err != nil {
t.Fatalf("new consumer error: %v", err) // t.Fatalf("new consumer error: %v", err)
} // }
var fn = func(r *Record) error { // var fn = func(r *Record) error {
return nil // return nil
} // }
err = c.ScanShard(context.Background(), "myShard", fn) // err = c.ScanShard(context.Background(), "myShard", fn)
if err != nil { // if err != nil {
t.Fatalf("scan shard error: %v", err) // t.Fatalf("scan shard error: %v", err)
} // }
} // }
type kinesisClientMock struct { type kinesisClientMock struct {
kinesisiface.KinesisAPI kinesisiface.KinesisAPI

View file

@ -12,6 +12,16 @@ import (
checkpoint "github.com/harlow/kinesis-consumer/checkpoint/redis" checkpoint "github.com/harlow/kinesis-consumer/checkpoint/redis"
) )
// A myLogger provides a minimalistic logger satisfying the Logger interface.
type myLogger struct {
logger *log.Logger
}
// Log logs the parameters to the stdlib logger. See log.Println.
func (l *myLogger) Log(args ...interface{}) {
l.logger.Println(args...)
}
func main() { func main() {
var ( var (
app = flag.String("app", "", "Consumer app name") app = flag.String("app", "", "Consumer app name")
@ -25,9 +35,16 @@ func main() {
log.Fatalf("checkpoint error: %v", err) log.Fatalf("checkpoint error: %v", err)
} }
// logger
logger := &myLogger{
logger: log.New(os.Stdout, "consumer-example: ", log.LstdFlags),
}
// consumer // consumer
c, err := consumer.New( c, err := consumer.New(
*stream, consumer.WithCheckpoint(ck), *stream,
consumer.WithCheckpoint(ck),
consumer.WithLogger(logger),
) )
if err != nil { if err != nil {
log.Fatalf("consumer error: %v", err) log.Fatalf("consumer error: %v", err)
@ -42,6 +59,7 @@ func main() {
go func() { go func() {
<-signals <-signals
fmt.Println("caught exit signal, cancelling context!")
cancel() cancel()
}() }()