From 97fe4e66fff02edfb239358d493770a261168407 Mon Sep 17 00:00:00 2001 From: Harlow Ward Date: Tue, 9 Apr 2019 22:03:12 -0700 Subject: [PATCH] Use shard broker to monitor and process new shards (#85) * Use shard broker to start processing new shards The addition of a shard broker will allow the consumer to be notified when new shards are added to the stream so it can consume them. Fixes: https://github.com/harlow/kinesis-consumer/issues/36 --- README.md | 8 +- broker.go | 114 +++++++++++++++++++++++++++ consumer.go | 79 ++++++------------- consumer_test.go | 87 ++++++++++---------- examples/consumer/cp-redis/README.md | 5 +- examples/consumer/cp-redis/main.go | 20 ++++- examples/producer/README.md | 3 +- 7 files changed, 210 insertions(+), 106 deletions(-) create mode 100644 broker.go diff --git a/README.md b/README.md index 4ae70fa..bfbae2c 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,9 @@ Get the package source: The consumer leverages a handler func that accepts a Kinesis record. The `Scan` method will consume all shards concurrently and call the callback func as it receives records from the stream. -_Important: The default Log, Counter, and Checkpoint are no-op which means no logs, counts, or checkpoints will be emitted when scanning the stream. See the options below to override these defaults._ +_Important 1: The `Scan` func will also poll the stream to check for new shards; it will automatically start consuming new shards added to the stream._ + +_Important 2: The default Log, Counter, and Checkpoint are no-op which means no logs, counts, or checkpoints will be emitted when scanning the stream. See the options below to override these defaults._ ```go import( @@ -255,8 +257,8 @@ The package defaults to `ioutil.Discard` so swallow all logs. 
This can be custom ```go // logger -log := &myLogger{ - logger: log.New(os.Stdout, "consumer-example: ", log.LstdFlags) +logger := &myLogger{ + logger: log.New(os.Stdout, "consumer-example: ", log.LstdFlags), } // consumer diff --git a/broker.go b/broker.go new file mode 100644 index 0000000..ecf25a1 --- /dev/null +++ b/broker.go @@ -0,0 +1,114 @@ +package consumer + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/kinesis" + "github.com/aws/aws-sdk-go/service/kinesis/kinesisiface" +) + +func newBroker( + client kinesisiface.KinesisAPI, + streamName string, + shardc chan *kinesis.Shard, + logger Logger, +) *broker { + return &broker{ + client: client, + shards: make(map[string]*kinesis.Shard), + streamName: streamName, + shardc: shardc, + logger: logger, + } +} + +// broker caches a local list of the shards we are already processing +// and routinely polls the stream looking for new shards to process +type broker struct { + client kinesisiface.KinesisAPI + streamName string + shardc chan *kinesis.Shard + logger Logger + + shardMu sync.Mutex + shards map[string]*kinesis.Shard +} + +// start is a blocking operation which will loop and attempt to find new +// shards on a regular cadence. +func (b *broker) start(ctx context.Context) { + b.findNewShards() + ticker := time.NewTicker(30 * time.Second) + + // Note: while ticker is a rather naive approach to this problem, + // it actually simplifies a few things. i.e. If we miss a new shard while + // AWS is resharding we'll pick it up max 30 seconds later. + + // It might be worth refactoring this flow to allow the consumer to + // notify the broker when a shard is closed. However, shards don't + // necessarily close at the same time, so we could potentially get a + // thundering herd of notifications from the consumer. 
+ + for { + select { + case <-ctx.Done(): + ticker.Stop() + return + case <-ticker.C: + b.findNewShards() + } + } +} + +// findNewShards pulls the list of shards from the Kinesis API +// and uses a local cache to determine if we are already processing +// a particular shard. +func (b *broker) findNewShards() { + b.shardMu.Lock() + defer b.shardMu.Unlock() + + b.logger.Log("[BROKER]", "fetching shards") + + shards, err := b.listShards() + if err != nil { + b.logger.Log("[BROKER]", err) + return + } + + for _, shard := range shards { + if _, ok := b.shards[*shard.ShardId]; ok { + continue + } + b.shards[*shard.ShardId] = shard + b.shardc <- shard + } +} + +// listShards pulls a list of shard IDs from the kinesis api +func (b *broker) listShards() ([]*kinesis.Shard, error) { + var ss []*kinesis.Shard + var listShardsInput = &kinesis.ListShardsInput{ + StreamName: aws.String(b.streamName), + } + + for { + resp, err := b.client.ListShards(listShardsInput) + if err != nil { + return nil, fmt.Errorf("ListShards error: %v", err) + } + ss = append(ss, resp.Shards...) + + if resp.NextToken == nil { + return ss, nil + } + + listShardsInput = &kinesis.ListShardsInput{ + NextToken: resp.NextToken, + StreamName: aws.String(b.streamName), + } + } +} diff --git a/consumer.go b/consumer.go index a60d378..adbb000 100644 --- a/consumer.go +++ b/consumer.go @@ -6,7 +6,6 @@ import ( "fmt" "io/ioutil" "log" - "sync" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/session" @@ -70,7 +69,7 @@ type Consumer struct { // function returns the special value SkipCheckpoint. type ScanFunc func(*Record) error -// SkipCheckpoint is used as a return value from ScanFuncs to indicate that +// SkipCheckpoint is used as a return value from ScanFunc to indicate that // the current checkpoint should be skipped skipped. It is not returned // as an error by any function. 
var SkipCheckpoint = errors.New("skip checkpoint") @@ -79,51 +78,44 @@ var SkipCheckpoint = errors.New("skip checkpoint") // is passed through to each of the goroutines and called with each message pulled from // the stream. func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error { + var ( + errc = make(chan error, 1) + shardc = make(chan *kinesis.Shard, 1) + broker = newBroker(c.client, c.streamName, shardc, c.logger) + ) + ctx, cancel := context.WithCancel(ctx) defer cancel() - // get shard ids - shardIDs, err := c.getShardIDs(c.streamName) - if err != nil { - return fmt.Errorf("get shards error: %v", err) - } + go broker.start(ctx) - if len(shardIDs) == 0 { - return fmt.Errorf("no shards available") - } + go func() { + <-ctx.Done() + close(shardc) + }() - var ( - wg sync.WaitGroup - errc = make(chan error, 1) - ) - wg.Add(len(shardIDs)) - - // process each shard in a separate goroutine - for _, shardID := range shardIDs { + // process each of the shards + for shard := range shardc { go func(shardID string) { - defer wg.Done() - if err := c.ScanShard(ctx, shardID, fn); err != nil { - cancel() - select { case errc <- fmt.Errorf("shard %s error: %v", shardID, err): // first error to occur + cancel() default: // error has already occured } } - }(shardID) + }(aws.StringValue(shard.ShardId)) } - wg.Wait() close(errc) return <-errc } -// ScanShard loops over records on a specific shard, calls the ScanFunc callback -// func for each record and checkpoints the progress of scan. +// ScanShard loops over records on a specific shard, calls the callback func +// for each record and checkpoints the progress of scan. 
func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error { // get last seq number from checkpoint lastSeqNum, err := c.checkpoint.Get(c.streamName, shardID) @@ -137,7 +129,10 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e return fmt.Errorf("get shard iterator error: %v", err) } - c.logger.Log("scanning", shardID, lastSeqNum) + c.logger.Log("[START]\t", shardID, lastSeqNum) + defer func() { + c.logger.Log("[STOP]\t", shardID) + }() for { select { @@ -148,8 +143,7 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e ShardIterator: shardIterator, }) - // often we can recover from GetRecords error by getting a - // new shard iterator, else return error + // attempt to recover from GetRecords error by getting new shard iterator if err != nil { shardIterator, err = c.getShardIterator(c.streamName, shardID, lastSeqNum) if err != nil { @@ -181,6 +175,7 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e } if isShardClosed(resp.NextShardIterator, shardIterator) { + c.logger.Log("[CLOSED]\t", shardID) return nil } @@ -193,32 +188,6 @@ func isShardClosed(nextShardIterator, currentShardIterator *string) bool { return nextShardIterator == nil || currentShardIterator == nextShardIterator } -func (c *Consumer) getShardIDs(streamName string) ([]string, error) { - var ss []string - var listShardsInput = &kinesis.ListShardsInput{ - StreamName: aws.String(streamName), - } - - for { - resp, err := c.client.ListShards(listShardsInput) - if err != nil { - return nil, fmt.Errorf("ListShards error: %v", err) - } - - for _, shard := range resp.Shards { - ss = append(ss, *shard.ShardId) - } - - if resp.NextToken == nil { - return ss, nil - } - - listShardsInput = &kinesis.ListShardsInput{ - NextToken: resp.NextToken, - } - } -} - func (c *Consumer) getShardIterator(streamName, shardID, seqNum string) (*string, error) { params := &kinesis.GetShardIteratorInput{ 
ShardId: aws.String(shardID), diff --git a/consumer_test.go b/consumer_test.go index 4712e3c..caaa04b 100644 --- a/consumer_test.go +++ b/consumer_test.go @@ -63,27 +63,30 @@ func TestScan(t *testing.T) { t.Fatalf("new consumer error: %v", err) } - var resultData string - var fnCallCounter int + var ( + ctx, cancel = context.WithCancel(context.Background()) + res string + ) + var fn = func(r *Record) error { - fnCallCounter++ - resultData += string(r.Data) + res += string(r.Data) + + if string(r.Data) == "lastData" { + cancel() + } + return nil } - if err := c.Scan(context.Background(), fn); err != nil { - t.Errorf("scan shard error expected nil. got %v", err) + if err := c.Scan(ctx, fn); err != nil { + t.Errorf("scan returned unexpected error %v", err) } - if resultData != "firstDatalastData" { - t.Errorf("callback error expected %s, got %s", "FirstLast", resultData) + if res != "firstDatalastData" { + t.Errorf("callback error expected %s, got %s", "firstDatalastData", res) } - if fnCallCounter != 2 { - t.Errorf("the callback function expects %v, got %v", 2, fnCallCounter) - } - - if val := ctr.counter; val != 2 { + if val := ctr.Get(); val != 2 { t.Errorf("counter error expected %d, got %d", 2, val) } @@ -93,29 +96,6 @@ func TestScan(t *testing.T) { } } -func TestScan_NoShardsAvailable(t *testing.T) { - client := &kinesisClientMock{ - listShardsMock: func(input *kinesis.ListShardsInput) (*kinesis.ListShardsOutput, error) { - return &kinesis.ListShardsOutput{ - Shards: make([]*kinesis.Shard, 0), - }, nil - }, - } - - var fn = func(r *Record) error { - return nil - } - - c, err := New("myStreamName", WithClient(client)) - if err != nil { - t.Fatalf("new consumer error: %v", err) - } - - if err := c.Scan(context.Background(), fn); err == nil { - t.Errorf("scan shard error expected not nil. 
got %v", err) - } -} - func TestScanShard(t *testing.T) { var client = &kinesisClientMock{ getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) { @@ -146,15 +126,23 @@ func TestScanShard(t *testing.T) { } // callback fn appends record data - var res string + var ( + ctx, cancel = context.WithCancel(context.Background()) + res string + ) + var fn = func(r *Record) error { res += string(r.Data) + + if string(r.Data) == "lastData" { + cancel() + } + return nil } - // scan shard - if err := c.ScanShard(context.Background(), "myShard", fn); err != nil { - t.Fatalf("scan shard error: %v", err) + if err := c.ScanShard(ctx, "myShard", fn); err != nil { + t.Errorf("scan returned unexpected error %v", err) } // runs callback func @@ -163,7 +151,7 @@ func TestScanShard(t *testing.T) { } // increments counter - if val := ctr.counter; val != 2 { + if val := ctr.Get(); val != 2 { t.Fatalf("counter error expected %d, got %d", 2, val) } @@ -236,14 +224,18 @@ func TestScanShard_SkipCheckpoint(t *testing.T) { t.Fatalf("new consumer error: %v", err) } + var ctx, cancel = context.WithCancel(context.Background()) + var fn = func(r *Record) error { if aws.StringValue(r.SequenceNumber) == "lastSeqNum" { + cancel() return SkipCheckpoint } + return nil } - err = c.ScanShard(context.Background(), "myShard", fn) + err = c.ScanShard(ctx, "myShard", fn) if err != nil { t.Fatalf("scan shard error: %v", err) } @@ -329,8 +321,19 @@ func (fc *fakeCheckpoint) Get(streamName, shardID string) (string, error) { // implementation of counter type fakeCounter struct { counter int64 + mu sync.Mutex +} + +func (fc *fakeCounter) Get() int64 { + fc.mu.Lock() + defer fc.mu.Unlock() + + return fc.counter } func (fc *fakeCounter) Add(streamName string, count int64) { + fc.mu.Lock() + defer fc.mu.Unlock() + fc.counter += count } diff --git a/examples/consumer/cp-redis/README.md b/examples/consumer/cp-redis/README.md index a16d189..b6d30b3 100644 --- 
a/examples/consumer/cp-redis/README.md +++ b/examples/consumer/cp-redis/README.md @@ -7,12 +7,11 @@ Read records from the Kinesis stream Export the required environment vars for connecting to the Kinesis stream and Redis for checkpoint: ``` -export AWS_ACCESS_KEY= +export AWS_PROFILE= export AWS_REGION= -export AWS_SECRET_KEY= export REDIS_URL= ``` ### Run the consumer - $ go run main.go --app appName --stream streamName + $ go run main.go --app appName --stream streamName diff --git a/examples/consumer/cp-redis/main.go b/examples/consumer/cp-redis/main.go index f199396..8b3450e 100644 --- a/examples/consumer/cp-redis/main.go +++ b/examples/consumer/cp-redis/main.go @@ -12,6 +12,16 @@ import ( checkpoint "github.com/harlow/kinesis-consumer/checkpoint/redis" ) +// A myLogger provides a minimalistic logger satisfying the Logger interface. +type myLogger struct { + logger *log.Logger +} + +// Log logs the parameters to the stdlib logger. See log.Println. +func (l *myLogger) Log(args ...interface{}) { + l.logger.Println(args...) +} + func main() { var ( app = flag.String("app", "", "Consumer app name") @@ -25,9 +35,16 @@ func main() { log.Fatalf("checkpoint error: %v", err) } + // logger + logger := &myLogger{ + logger: log.New(os.Stdout, "consumer-example: ", log.LstdFlags), + } + // consumer c, err := consumer.New( - *stream, consumer.WithCheckpoint(ck), + *stream, + consumer.WithCheckpoint(ck), + consumer.WithLogger(logger), ) if err != nil { log.Fatalf("consumer error: %v", err) @@ -42,6 +59,7 @@ func main() { go func() { <-signals + fmt.Println("caught exit signal, cancelling context!") cancel() }() diff --git a/examples/producer/README.md b/examples/producer/README.md index d89b959..a620e95 100644 --- a/examples/producer/README.md +++ b/examples/producer/README.md @@ -7,9 +7,8 @@ A prepopulated file with JSON users is available on S3 for seeing the stream. 
Export the required environment vars for connecting to the Kinesis stream: ``` -export AWS_ACCESS_KEY= +export AWS_PROFILE= export AWS_REGION_NAME= -export AWS_SECRET_KEY= ``` ### Running the code