Use shard broker to start processing new shards
The addition of a shard broker will allow the consumer to be notified when new shards are added to the stream so it can consume them. Fixes: https://github.com/harlow/kinesis-consumer/issues/36
This commit is contained in:
parent
c4f363a517
commit
7e72723168
5 changed files with 206 additions and 101 deletions
|
|
@ -255,8 +255,8 @@ The package defaults to `ioutil.Discard` so swallow all logs. This can be custom
|
||||||
|
|
||||||
```go
|
```go
|
||||||
// logger
|
// logger
|
||||||
log := &myLogger{
|
logger := &myLogger{
|
||||||
logger: log.New(os.Stdout, "consumer-example: ", log.LstdFlags)
|
logger: log.New(os.Stdout, "consumer-example: ", log.LstdFlags),
|
||||||
}
|
}
|
||||||
|
|
||||||
// consumer
|
// consumer
|
||||||
|
|
|
||||||
102
broker.go
Normal file
102
broker.go
Normal file
|
|
@ -0,0 +1,102 @@
|
||||||
|
package consumer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/aws/aws-sdk-go/aws"
|
||||||
|
"github.com/aws/aws-sdk-go/service/kinesis"
|
||||||
|
"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newBroker(
|
||||||
|
client kinesisiface.KinesisAPI,
|
||||||
|
streamName string,
|
||||||
|
shardc chan *kinesis.Shard,
|
||||||
|
) *broker {
|
||||||
|
return &broker{
|
||||||
|
client: client,
|
||||||
|
shards: make(map[string]*kinesis.Shard),
|
||||||
|
streamName: streamName,
|
||||||
|
shardc: shardc,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type broker struct {
|
||||||
|
client kinesisiface.KinesisAPI
|
||||||
|
streamName string
|
||||||
|
shardc chan *kinesis.Shard
|
||||||
|
|
||||||
|
shardMu sync.Mutex
|
||||||
|
shards map[string]*kinesis.Shard
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *broker) shardLoop(ctx context.Context) {
|
||||||
|
b.fetchShards()
|
||||||
|
|
||||||
|
// add ticker, and cancellation
|
||||||
|
// also add signal to re-pull?
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-time.After(30 * time.Second):
|
||||||
|
b.fetchShards()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *broker) fetchShards() {
|
||||||
|
shards, err := b.listShards()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, shard := range shards {
|
||||||
|
if b.takeLease(shard) {
|
||||||
|
b.shardc <- shard
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *broker) listShards() ([]*kinesis.Shard, error) {
|
||||||
|
var ss []*kinesis.Shard
|
||||||
|
var listShardsInput = &kinesis.ListShardsInput{
|
||||||
|
StreamName: aws.String(b.streamName),
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
resp, err := b.client.ListShards(listShardsInput)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("ListShards error: %v", err)
|
||||||
|
}
|
||||||
|
ss = append(ss, resp.Shards...)
|
||||||
|
|
||||||
|
if resp.NextToken == nil {
|
||||||
|
return ss, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
listShardsInput = &kinesis.ListShardsInput{
|
||||||
|
NextToken: resp.NextToken,
|
||||||
|
StreamName: aws.String(b.streamName),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *broker) takeLease(shard *kinesis.Shard) bool {
|
||||||
|
b.shardMu.Lock()
|
||||||
|
defer b.shardMu.Unlock()
|
||||||
|
|
||||||
|
if _, ok := b.shards[*shard.ShardId]; ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
b.shards[*shard.ShardId] = shard
|
||||||
|
return true
|
||||||
|
}
|
||||||
80
consumer.go
80
consumer.go
|
|
@ -6,7 +6,6 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"log"
|
"log"
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/aws/aws-sdk-go/aws"
|
"github.com/aws/aws-sdk-go/aws"
|
||||||
"github.com/aws/aws-sdk-go/aws/session"
|
"github.com/aws/aws-sdk-go/aws/session"
|
||||||
|
|
@ -70,7 +69,7 @@ type Consumer struct {
|
||||||
// function returns the special value SkipCheckpoint.
|
// function returns the special value SkipCheckpoint.
|
||||||
type ScanFunc func(*Record) error
|
type ScanFunc func(*Record) error
|
||||||
|
|
||||||
// SkipCheckpoint is used as a return value from ScanFuncs to indicate that
|
// SkipCheckpoint is used as a return value from ScanFunc to indicate that
|
||||||
// the current checkpoint should be skipped skipped. It is not returned
|
// the current checkpoint should be skipped skipped. It is not returned
|
||||||
// as an error by any function.
|
// as an error by any function.
|
||||||
var SkipCheckpoint = errors.New("skip checkpoint")
|
var SkipCheckpoint = errors.New("skip checkpoint")
|
||||||
|
|
@ -79,51 +78,45 @@ var SkipCheckpoint = errors.New("skip checkpoint")
|
||||||
// is passed through to each of the goroutines and called with each message pulled from
|
// is passed through to each of the goroutines and called with each message pulled from
|
||||||
// the stream.
|
// the stream.
|
||||||
func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
|
func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
|
||||||
|
var (
|
||||||
|
errc = make(chan error, 1)
|
||||||
|
shardc = make(chan *kinesis.Shard, 1)
|
||||||
|
broker = newBroker(c.client, c.streamName, shardc)
|
||||||
|
)
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
// get shard ids
|
go func() {
|
||||||
shardIDs, err := c.getShardIDs(c.streamName)
|
broker.shardLoop(ctx)
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("get shards error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(shardIDs) == 0 {
|
<-ctx.Done()
|
||||||
return fmt.Errorf("no shards available")
|
close(shardc)
|
||||||
}
|
}()
|
||||||
|
|
||||||
var (
|
// process each of the shards
|
||||||
wg sync.WaitGroup
|
for shard := range shardc {
|
||||||
errc = make(chan error, 1)
|
|
||||||
)
|
|
||||||
wg.Add(len(shardIDs))
|
|
||||||
|
|
||||||
// process each shard in a separate goroutine
|
|
||||||
for _, shardID := range shardIDs {
|
|
||||||
go func(shardID string) {
|
go func(shardID string) {
|
||||||
defer wg.Done()
|
|
||||||
|
|
||||||
if err := c.ScanShard(ctx, shardID, fn); err != nil {
|
if err := c.ScanShard(ctx, shardID, fn); err != nil {
|
||||||
cancel()
|
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case errc <- fmt.Errorf("shard %s error: %v", shardID, err):
|
case errc <- fmt.Errorf("shard %s error: %v", shardID, err):
|
||||||
// first error to occur
|
// first error to occur
|
||||||
|
cancel()
|
||||||
default:
|
default:
|
||||||
// error has already occured
|
// error has already occured
|
||||||
}
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}(shardID)
|
}(aws.StringValue(shard.ShardId))
|
||||||
}
|
}
|
||||||
|
|
||||||
wg.Wait()
|
|
||||||
close(errc)
|
close(errc)
|
||||||
|
|
||||||
return <-errc
|
return <-errc
|
||||||
}
|
}
|
||||||
|
|
||||||
// ScanShard loops over records on a specific shard, calls the ScanFunc callback
|
// ScanShard loops over records on a specific shard, calls the callback func
|
||||||
// func for each record and checkpoints the progress of scan.
|
// for each record and checkpoints the progress of scan.
|
||||||
func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error {
|
func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error {
|
||||||
// get last seq number from checkpoint
|
// get last seq number from checkpoint
|
||||||
lastSeqNum, err := c.checkpoint.Get(c.streamName, shardID)
|
lastSeqNum, err := c.checkpoint.Get(c.streamName, shardID)
|
||||||
|
|
@ -137,7 +130,10 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
|
||||||
return fmt.Errorf("get shard iterator error: %v", err)
|
return fmt.Errorf("get shard iterator error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.logger.Log("scanning", shardID, lastSeqNum)
|
c.logger.Log("[START]\t", shardID, lastSeqNum)
|
||||||
|
defer func() {
|
||||||
|
c.logger.Log("[STOP]\t", shardID)
|
||||||
|
}()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
|
|
@ -148,8 +144,7 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
|
||||||
ShardIterator: shardIterator,
|
ShardIterator: shardIterator,
|
||||||
})
|
})
|
||||||
|
|
||||||
// often we can recover from GetRecords error by getting a
|
// attempt to recover from GetRecords error by getting new shard iterator
|
||||||
// new shard iterator, else return error
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
shardIterator, err = c.getShardIterator(c.streamName, shardID, lastSeqNum)
|
shardIterator, err = c.getShardIterator(c.streamName, shardID, lastSeqNum)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -181,6 +176,7 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
|
||||||
}
|
}
|
||||||
|
|
||||||
if isShardClosed(resp.NextShardIterator, shardIterator) {
|
if isShardClosed(resp.NextShardIterator, shardIterator) {
|
||||||
|
c.logger.Log("[CLOSED]\t", shardID)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -193,32 +189,6 @@ func isShardClosed(nextShardIterator, currentShardIterator *string) bool {
|
||||||
return nextShardIterator == nil || currentShardIterator == nextShardIterator
|
return nextShardIterator == nil || currentShardIterator == nextShardIterator
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Consumer) getShardIDs(streamName string) ([]string, error) {
|
|
||||||
var ss []string
|
|
||||||
var listShardsInput = &kinesis.ListShardsInput{
|
|
||||||
StreamName: aws.String(streamName),
|
|
||||||
}
|
|
||||||
|
|
||||||
for {
|
|
||||||
resp, err := c.client.ListShards(listShardsInput)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("ListShards error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, shard := range resp.Shards {
|
|
||||||
ss = append(ss, *shard.ShardId)
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.NextToken == nil {
|
|
||||||
return ss, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
listShardsInput = &kinesis.ListShardsInput{
|
|
||||||
NextToken: resp.NextToken,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Consumer) getShardIterator(streamName, shardID, seqNum string) (*string, error) {
|
func (c *Consumer) getShardIterator(streamName, shardID, seqNum string) (*string, error) {
|
||||||
params := &kinesis.GetShardIteratorInput{
|
params := &kinesis.GetShardIteratorInput{
|
||||||
ShardId: aws.String(shardID),
|
ShardId: aws.String(shardID),
|
||||||
|
|
|
||||||
101
consumer_test.go
101
consumer_test.go
|
|
@ -63,24 +63,27 @@ func TestScan(t *testing.T) {
|
||||||
t.Fatalf("new consumer error: %v", err)
|
t.Fatalf("new consumer error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var resultData string
|
var (
|
||||||
var fnCallCounter int
|
ctx, cancel = context.WithCancel(context.Background())
|
||||||
|
res string
|
||||||
|
)
|
||||||
|
|
||||||
var fn = func(r *Record) error {
|
var fn = func(r *Record) error {
|
||||||
fnCallCounter++
|
res += string(r.Data)
|
||||||
resultData += string(r.Data)
|
|
||||||
|
if string(r.Data) == "lastData" {
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := c.Scan(context.Background(), fn); err != nil {
|
if err := c.Scan(ctx, fn); err != nil {
|
||||||
t.Errorf("scan shard error expected nil. got %v", err)
|
t.Errorf("scan returned unexpected error %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if resultData != "firstDatalastData" {
|
if res != "firstDatalastData" {
|
||||||
t.Errorf("callback error expected %s, got %s", "FirstLast", resultData)
|
t.Errorf("callback error expected %s, got %s", "firstDatalastData", res)
|
||||||
}
|
|
||||||
|
|
||||||
if fnCallCounter != 2 {
|
|
||||||
t.Errorf("the callback function expects %v, got %v", 2, fnCallCounter)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if val := ctr.counter; val != 2 {
|
if val := ctr.counter; val != 2 {
|
||||||
|
|
@ -146,15 +149,23 @@ func TestScanShard(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// callback fn appends record data
|
// callback fn appends record data
|
||||||
var res string
|
var (
|
||||||
|
ctx, cancel = context.WithCancel(context.Background())
|
||||||
|
res string
|
||||||
|
)
|
||||||
|
|
||||||
var fn = func(r *Record) error {
|
var fn = func(r *Record) error {
|
||||||
res += string(r.Data)
|
res += string(r.Data)
|
||||||
|
|
||||||
|
if string(r.Data) == "lastData" {
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// scan shard
|
if err := c.Scan(ctx, fn); err != nil {
|
||||||
if err := c.ScanShard(context.Background(), "myShard", fn); err != nil {
|
t.Errorf("scan returned unexpected error %v", err)
|
||||||
t.Fatalf("scan shard error: %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// runs callback func
|
// runs callback func
|
||||||
|
|
@ -236,14 +247,18 @@ func TestScanShard_SkipCheckpoint(t *testing.T) {
|
||||||
t.Fatalf("new consumer error: %v", err)
|
t.Fatalf("new consumer error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var ctx, cancel = context.WithCancel(context.Background())
|
||||||
|
|
||||||
var fn = func(r *Record) error {
|
var fn = func(r *Record) error {
|
||||||
if aws.StringValue(r.SequenceNumber) == "lastSeqNum" {
|
if aws.StringValue(r.SequenceNumber) == "lastSeqNum" {
|
||||||
|
cancel()
|
||||||
return SkipCheckpoint
|
return SkipCheckpoint
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
err = c.ScanShard(context.Background(), "myShard", fn)
|
err = c.ScanShard(ctx, "myShard", fn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("scan shard error: %v", err)
|
t.Fatalf("scan shard error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
@ -254,35 +269,35 @@ func TestScanShard_SkipCheckpoint(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestScanShard_ShardIsClosed(t *testing.T) {
|
// func TestScanShard_ShardIsClosed(t *testing.T) {
|
||||||
var client = &kinesisClientMock{
|
// var client = &kinesisClientMock{
|
||||||
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
|
// getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
|
||||||
return &kinesis.GetShardIteratorOutput{
|
// return &kinesis.GetShardIteratorOutput{
|
||||||
ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
|
// ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
|
||||||
}, nil
|
// }, nil
|
||||||
},
|
// },
|
||||||
getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
|
// getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
|
||||||
return &kinesis.GetRecordsOutput{
|
// return &kinesis.GetRecordsOutput{
|
||||||
NextShardIterator: nil,
|
// NextShardIterator: nil,
|
||||||
Records: make([]*Record, 0),
|
// Records: make([]*Record, 0),
|
||||||
}, nil
|
// }, nil
|
||||||
},
|
// },
|
||||||
}
|
// }
|
||||||
|
|
||||||
c, err := New("myStreamName", WithClient(client))
|
// c, err := New("myStreamName", WithClient(client))
|
||||||
if err != nil {
|
// if err != nil {
|
||||||
t.Fatalf("new consumer error: %v", err)
|
// t.Fatalf("new consumer error: %v", err)
|
||||||
}
|
// }
|
||||||
|
|
||||||
var fn = func(r *Record) error {
|
// var fn = func(r *Record) error {
|
||||||
return nil
|
// return nil
|
||||||
}
|
// }
|
||||||
|
|
||||||
err = c.ScanShard(context.Background(), "myShard", fn)
|
// err = c.ScanShard(context.Background(), "myShard", fn)
|
||||||
if err != nil {
|
// if err != nil {
|
||||||
t.Fatalf("scan shard error: %v", err)
|
// t.Fatalf("scan shard error: %v", err)
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
|
|
||||||
type kinesisClientMock struct {
|
type kinesisClientMock struct {
|
||||||
kinesisiface.KinesisAPI
|
kinesisiface.KinesisAPI
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,16 @@ import (
|
||||||
checkpoint "github.com/harlow/kinesis-consumer/checkpoint/redis"
|
checkpoint "github.com/harlow/kinesis-consumer/checkpoint/redis"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// A myLogger provides a minimalistic logger satisfying the Logger interface.
|
||||||
|
type myLogger struct {
|
||||||
|
logger *log.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// Log logs the parameters to the stdlib logger. See log.Println.
|
||||||
|
func (l *myLogger) Log(args ...interface{}) {
|
||||||
|
l.logger.Println(args...)
|
||||||
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
var (
|
var (
|
||||||
app = flag.String("app", "", "Consumer app name")
|
app = flag.String("app", "", "Consumer app name")
|
||||||
|
|
@ -25,9 +35,16 @@ func main() {
|
||||||
log.Fatalf("checkpoint error: %v", err)
|
log.Fatalf("checkpoint error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// logger
|
||||||
|
logger := &myLogger{
|
||||||
|
logger: log.New(os.Stdout, "consumer-example: ", log.LstdFlags),
|
||||||
|
}
|
||||||
|
|
||||||
// consumer
|
// consumer
|
||||||
c, err := consumer.New(
|
c, err := consumer.New(
|
||||||
*stream, consumer.WithCheckpoint(ck),
|
*stream,
|
||||||
|
consumer.WithCheckpoint(ck),
|
||||||
|
consumer.WithLogger(logger),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("consumer error: %v", err)
|
log.Fatalf("consumer error: %v", err)
|
||||||
|
|
@ -42,6 +59,7 @@ func main() {
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
<-signals
|
<-signals
|
||||||
|
fmt.Println("caught exit signal, cancelling context!")
|
||||||
cancel()
|
cancel()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue