Merge remote-tracking branch 'origin/master'

Edward Tsang 2018-09-10 12:24:22 -07:00
commit 37fc2cd212
19 changed files with 961 additions and 366 deletions

13
.gitignore vendored

@ -3,12 +3,6 @@
*.a
*.so
# Environment vars
.env
# Seed data
users.txt
# Folders
_obj
_test
@ -42,3 +36,10 @@ vendor/**
# Benchmark files
prof.cpu
prof.mem
# VSCode files
/.vscode
/**/debug
# Goland files
.idea/

CHANGELOG.md

@ -4,25 +4,26 @@ All notable changes to this project will be documented in this file.
## [Unreleased (`master`)][unreleased]
**Breaking changes to consumer library**
Major changes:
* Use [functional options][options] for config
* Remove intermediate batching of kinesis records
* Call the callback func with each record
* Use dep for vendoring dependencies
* Add DDB as storage layer for checkpoints
* Remove the concept of `Client`; it was confusing as it wasn't a direct stand-in for a Kinesis client.
* Rename `ScanError` to `ScanStatus` as it's not always an error.
Minor changes:
* Remove unused buffer and emitter code
* Update tests to use Kinesis mock
[unreleased]: https://github.com/harlow/kinesis-consumer/compare/v0.1.0...HEAD
[options]: https://dave.cheney.net/2014/10/17/functional-options-for-friendly-apis
## v0.2.0 - 2018-07-28
This is the last stable release that ships a separate Client; the Client has caused confusion and will be removed going forward.
https://github.com/harlow/kinesis-consumer/releases/tag/v0.2.0
## v0.1.0 - 2017-11-20
This is the last stable release of the consumer which aggregated records in `batch` before calling the callback func.
https://github.com/harlow/kinesis-consumer/releases/tag/v0.1.0
[unreleased]: https://github.com/harlow/kinesis-consumer/compare/v0.2.0...HEAD
[options]: https://dave.cheney.net/2014/10/17/functional-options-for-friendly-apis

21
Gopkg.lock generated

@ -23,6 +23,7 @@
"aws/credentials/ec2rolecreds",
"aws/credentials/endpointcreds",
"aws/credentials/stscreds",
"aws/csm",
"aws/defaults",
"aws/ec2metadata",
"aws/endpoints",
@ -31,6 +32,7 @@
"aws/signer/v4",
"internal/sdkio",
"internal/sdkrand",
"internal/sdkuri",
"internal/shareddefaults",
"private/protocol",
"private/protocol/json/jsonutil",
@ -46,20 +48,29 @@
"service/kinesis/kinesisiface",
"service/sts"
]
revision = "827e7eac8c2680d5bdea7bc3ef29c596eabe1eae"
version = "v1.13.59"
revision = "8475c414b1bd58b8cc214873a8854e3a621e67d7"
version = "v1.15.0"
[[projects]]
name = "github.com/go-ini/ini"
packages = ["."]
revision = "7e7da451323b6766da368f8a1e8ec9a88a16b4a0"
version = "v1.31.1"
revision = "358ee7663966325963d4e8b2e1fbd570c5195153"
version = "v1.38.1"
[[projects]]
name = "github.com/jmespath/go-jmespath"
packages = ["."]
revision = "0b12d6b5"
[[projects]]
branch = "master"
name = "github.com/lib/pq"
packages = [
".",
"oid"
]
revision = "90697d60dd844d5ef6ff15135d0203f65d2f53b8"
[[projects]]
name = "github.com/pkg/errors"
packages = ["."]
@ -82,6 +93,6 @@
[solve-meta]
analyzer-name = "dep"
analyzer-version = 1
inputs-digest = "1b40486b645b81bc2215d0153631e9e002e534ba86713ba55500ce62c07cbad8"
inputs-digest = "6b3044ce1b075f919471f2457f32450efaa36518381fd84164641860c296de5a"
solver-name = "gps-cdcl"
solver-version = 1

Gopkg.toml

@ -1,4 +1,3 @@
# Gopkg.toml example
#
# Refer to https://github.com/golang/dep/blob/master/docs/Gopkg.toml.md
@ -17,8 +16,13 @@
# source = "github.com/myfork/project2"
#
# [[override]]
# name = "github.com/x/y"
# version = "2.4.0"
# name = "github.com/x/y"
# version = "2.4.0"
#
# [prune]
# non-go = false
# go-tests = true
# unused-packages = true
[[constraint]]
@ -27,24 +31,16 @@
[[constraint]]
name = "github.com/aws/aws-sdk-go"
version = "1.12.30"
[[constraint]]
branch = "master"
name = "github.com/bmizerany/assert"
[[constraint]]
branch = "master"
name = "github.com/crowdmob/goamz"
version = "1.15.0"
[[constraint]]
branch = "master"
name = "github.com/lib/pq"
[[constraint]]
branch = "master"
name = "github.com/tj/go-kinesis"
[[constraint]]
name = "gopkg.in/redis.v5"
version = "5.2.9"
[prune]
go-tests = true
unused-packages = true

106
README.md

@ -37,14 +37,14 @@ func main() {
log.Fatalf("consumer error: %v", err)
}
// start
err = c.Scan(context.TODO(), func(r *consumer.Record) consumer.ScanError {
// start scan
err = c.Scan(context.TODO(), func(r *consumer.Record) consumer.ScanStatus {
fmt.Println(string(r.Data))
// continue scanning
return consumer.ScanError{
StopScan: false, // true to stop scan
SkipCheckpoint: false, // true to skip checkpoint
}
return consumer.ScanStatus{
StopScan: false, // true to stop scan
SkipCheckpoint: false, // true to skip checkpoint
}
})
if err != nil {
log.Fatalf("scan error: %v", err)
@ -55,6 +55,24 @@ func main() {
}
```
## Scan status
The scan func returns a `consumer.ScanStatus`; the struct allows some basic flow control.
```go
// continue scanning
return consumer.ScanStatus{}
// continue scanning, skip saving checkpoint
return consumer.ScanStatus{SkipCheckpoint: true}
// stop scanning, return nil
return consumer.ScanStatus{StopScan: true}
// stop scanning, return error
return consumer.ScanStatus{Error: err}
```
## Checkpoint
To record the progress of the consumer in the stream we use a checkpoint to store the last sequence number the consumer has read from a particular shard. The boolean field `SkipCheckpoint` of `consumer.ScanStatus` determines whether the checkpoint will be saved. `ScanStatus` is returned by the record-processing callback.
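For instance, a callback can skip checkpointing for records it could not handle, so they are re-read after a restart. A minimal sketch (the `isValid` and `process` helpers are hypothetical):
```go
err = c.Scan(context.TODO(), func(r *consumer.Record) consumer.ScanStatus {
	if !isValid(r.Data) {
		// don't advance the checkpoint past this record
		return consumer.ScanStatus{SkipCheckpoint: true}
	}
	process(r.Data)

	// record handled; the checkpoint is saved with this sequence number
	return consumer.ScanStatus{}
})
```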
@ -107,8 +125,9 @@ myDynamoDbClient := dynamodb.New(session.New(aws.NewConfig()))
ck, err := checkpoint.New(*app, *table, checkpoint.WithDynamoClient(myDynamoDbClient))
if err != nil {
log.Fatalf("new checkpoint error: %v", err)
log.Fatalf("new checkpoint error: %v", err)
}
// Or you can provide your own Retryer to customize what triggers a retry inside the checkpoint
// See code in examples
// ck, err := checkpoint.New(*app, *table, checkpoint.WithDynamoClient(myDynamoDbClient), checkpoint.WithRetryer(&MyRetryer{}))
@ -123,11 +142,39 @@ Sort key: shard_id
<img width="727" alt="screen shot 2017-11-22 at 7 59 36 pm" src="https://user-images.githubusercontent.com/739782/33158557-b90e4228-cfbf-11e7-9a99-73b56a446f5f.png">
### Postgres Checkpoint
The Postgres checkpoint requires Table Name, App Name, Stream Name and ConnectionString:
```go
import checkpoint "github.com/harlow/kinesis-consumer/checkpoint/postgres"
// postgres checkpoint
ck, err := checkpoint.New(app, table, connStr)
if err != nil {
log.Fatalf("new checkpoint error: %v", err)
}
```
To leverage the Postgres checkpoint we'll also need to create a table:
```sql
CREATE TABLE kinesis_consumer (
namespace text NOT NULL,
shard_id text NOT NULL,
sequence_number numeric NOT NULL,
CONSTRAINT kinesis_consumer_pk PRIMARY KEY (namespace, shard_id)
);
```
The table name has to match the one you specify when creating the checkpoint. The composite primary key on namespace and shard_id is mandatory for the checkpoint to run without issues and to ensure data integrity.
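For instance, with the table above the checkpoint could be constructed as follows (a sketch; the app name and connection string are placeholders):
```go
// "kinesis_consumer" matches the table created by the SQL statement above
ck, err := checkpoint.New("myApp", "kinesis_consumer", "postgres://user:pass@localhost/mydb?sslmode=disable")
if err != nil {
	log.Fatalf("new checkpoint error: %v", err)
}
```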
## Options
The consumer allows the following optional overrides.
### Client
### Kinesis Client
Override the Kinesis client if there is any special config needed:
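A minimal sketch of such an override, assuming a custom region (here `us-west-2`) and the `streamName` variable from the earlier example:
```go
// build a Kinesis client with custom config and hand it to the consumer
ksis := kinesis.New(session.New(aws.NewConfig()), &aws.Config{
	Region: aws.String("us-west-2"),
})

c, err := consumer.New(streamName, consumer.WithClient(ksis))
```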
@ -162,16 +209,55 @@ The [expvar package](https://golang.org/pkg/expvar/) will display consumer count
### Logging
Logging supports the basic built-in logging library or a third-party external one, so long as it implements the Logger interface.
For example, to use the built-in logging package, we wrap it with the myLogger struct.
```go
// A myLogger provides a minimalistic logger satisfying the Logger interface.
type myLogger struct {
logger *log.Logger
}
// Log logs the parameters to the stdlib logger. See log.Println.
func (l *myLogger) Log(args ...interface{}) {
l.logger.Println(args...)
}
```
The package defaults to `ioutil.Discard`, so all logs are swallowed. This can be customized with the preferred logging strategy:
```go
// logger
log := &myLogger{
logger: log.New(os.Stdout, "consumer-example: ", log.LstdFlags),
}
// consumer
c, err := consumer.New(streamName, consumer.WithLogger(log))
```
To use a more complicated logging library, e.g. apex log:
```go
type myLogger struct {
logger *log.Logger
}
func (l *myLogger) Log(args ...interface{}) {
l.logger.Infof("producer", args...)
}
func main() {
log := &myLogger{
logger: alog.Logger{
Handler: text.New(os.Stderr),
Level: alog.DebugLevel,
},
}
```
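The apex-backed logger can then be wired into the consumer the same way (a sketch reusing the `log` and `streamName` variables above):
```go
c, err := consumer.New(streamName, consumer.WithLogger(log))
if err != nil {
	log.Log("consumer error: ", err)
}
```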
## Contributing
Please see [CONTRIBUTING.md] for more information. Thank you, [contributors]!

13
checkpoint.go Normal file

@ -0,0 +1,13 @@
package consumer
// Checkpoint interface used to track consumer progress in the stream
type Checkpoint interface {
Get(streamName, shardID string) (string, error)
Set(streamName, shardID, sequenceNumber string) error
}
// noopCheckpoint implements the checkpoint interface with discard
type noopCheckpoint struct{}
func (n noopCheckpoint) Set(string, string, string) error { return nil }
func (n noopCheckpoint) Get(string, string) (string, error) { return "", nil }


@ -0,0 +1,161 @@
package postgres
import (
"database/sql"
"errors"
"fmt"
"sync"
"time"
// blank import registers the postgres driver with database/sql
_ "github.com/lib/pq"
)
var getCheckpointQuery = `SELECT sequence_number
FROM %s
WHERE namespace=$1 AND shard_id=$2`
var upsertCheckpoint = `INSERT INTO %s (namespace, shard_id, sequence_number)
VALUES($1, $2, $3)
ON CONFLICT (namespace, shard_id)
DO
UPDATE
SET sequence_number= $3`
type key struct {
streamName string
shardID string
}
// Option is used to override defaults when creating a new Checkpoint
type Option func(*Checkpoint)
// WithMaxInterval sets the flush interval
func WithMaxInterval(maxInterval time.Duration) Option {
return func(c *Checkpoint) {
c.maxInterval = maxInterval
}
}
// Checkpoint stores and retrieves the last processed sequence number for each shard in Postgres
type Checkpoint struct {
appName string
conn *sql.DB
mu *sync.Mutex // protects the checkpoints
done chan struct{}
checkpoints map[key]string
maxInterval time.Duration
}
// New returns a checkpoint that uses PostgreSQL for underlying storage
// Accepting a connection string makes it flexible to pass specific DB configs
func New(appName, tableName, connectionStr string, opts ...Option) (*Checkpoint, error) {
if tableName == "" {
return nil, errors.New("Table name not defined")
}
conn, err := sql.Open("postgres", connectionStr)
if err != nil {
return nil, err
}
getCheckpointQuery = fmt.Sprintf(getCheckpointQuery, tableName)
upsertCheckpoint = fmt.Sprintf(upsertCheckpoint, tableName)
ck := &Checkpoint{
conn: conn,
appName: appName,
done: make(chan struct{}),
maxInterval: time.Duration(1 * time.Minute),
mu: new(sync.Mutex),
checkpoints: map[key]string{},
}
for _, opt := range opts {
opt(ck)
}
go ck.loop()
return ck, nil
}
// Get determines if a checkpoint for a particular Shard exists.
// Typically used to determine whether we should start processing the shard with
// TRIM_HORIZON or AFTER_SEQUENCE_NUMBER (if checkpoint exists).
func (c *Checkpoint) Get(streamName, shardID string) (string, error) {
namespace := fmt.Sprintf("%s-%s", c.appName, streamName)
var sequenceNumber string
err := c.conn.QueryRow(getCheckpointQuery, namespace, shardID).Scan(&sequenceNumber)
if err != nil {
if err == sql.ErrNoRows {
return "", nil
}
return "", err
}
return sequenceNumber, nil
}
// Set stores a checkpoint for a shard (e.g. sequence number of last record processed by application).
// Upon failover, record processing is resumed from this point.
func (c *Checkpoint) Set(streamName, shardID, sequenceNumber string) error {
c.mu.Lock()
defer c.mu.Unlock()
if sequenceNumber == "" {
return fmt.Errorf("sequence number should not be empty")
}
key := key{
streamName: streamName,
shardID: shardID,
}
c.checkpoints[key] = sequenceNumber
return nil
}
// Shutdown the checkpoint. Save any in-flight data.
func (c *Checkpoint) Shutdown() error {
defer c.conn.Close()
c.done <- struct{}{}
return c.save()
}
func (c *Checkpoint) loop() {
tick := time.NewTicker(c.maxInterval)
defer tick.Stop()
defer close(c.done)
for {
select {
case <-tick.C:
c.save()
case <-c.done:
return
}
}
}
func (c *Checkpoint) save() error {
c.mu.Lock()
defer c.mu.Unlock()
for key, sequenceNumber := range c.checkpoints {
if _, err := c.conn.Exec(upsertCheckpoint, fmt.Sprintf("%s-%s", c.appName, key.streamName), key.shardID, sequenceNumber); err != nil {
return err
}
}
return nil
}

143
client.go

@ -1,143 +0,0 @@
package consumer
import (
"context"
"fmt"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/kinesis"
"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
)
// ClientOption is used to override defaults when creating a KinesisClient
type ClientOption func(*KinesisClient)
// WithKinesis overrides the default Kinesis client
func WithKinesis(svc kinesisiface.KinesisAPI) ClientOption {
return func(kc *KinesisClient) {
kc.svc = svc
}
}
// NewKinesisClient returns client to interface with Kinesis stream
func NewKinesisClient(opts ...ClientOption) *KinesisClient {
kc := &KinesisClient{}
for _, opt := range opts {
opt(kc)
}
if kc.svc == nil {
kc.svc = kinesis.New(session.New(aws.NewConfig()))
}
return kc
}
// KinesisClient acts as wrapper around Kinesis client
type KinesisClient struct {
svc kinesisiface.KinesisAPI
}
// GetShardIDs returns shard ids in a given stream
func (c *KinesisClient) GetShardIDs(streamName string) ([]string, error) {
resp, err := c.svc.DescribeStream(
&kinesis.DescribeStreamInput{
StreamName: aws.String(streamName),
},
)
if err != nil {
return nil, fmt.Errorf("describe stream error: %v", err)
}
ss := []string{}
for _, shard := range resp.StreamDescription.Shards {
ss = append(ss, *shard.ShardId)
}
return ss, nil
}
// GetRecords returns a chan Record from a Shard of the Stream
func (c *KinesisClient) GetRecords(ctx context.Context, streamName, shardID, lastSeqNum string) (<-chan *Record, <-chan error, error) {
shardIterator, err := c.getShardIterator(streamName, shardID, lastSeqNum)
if err != nil {
return nil, nil, fmt.Errorf("get shard iterator error: %v", err)
}
var (
recc = make(chan *Record, 10000)
errc = make(chan error, 1)
)
go func() {
defer func() {
close(recc)
close(errc)
}()
for {
select {
case <-ctx.Done():
return
default:
resp, err := c.svc.GetRecords(
&kinesis.GetRecordsInput{
ShardIterator: shardIterator,
},
)
if err != nil {
shardIterator, err = c.getShardIterator(streamName, shardID, lastSeqNum)
if err != nil {
errc <- fmt.Errorf("get shard iterator error: %v", err)
return
}
continue
}
for _, r := range resp.Records {
select {
case <-ctx.Done():
return
case recc <- r:
lastSeqNum = *r.SequenceNumber
}
}
if resp.NextShardIterator == nil || shardIterator == resp.NextShardIterator {
shardIterator, err = c.getShardIterator(streamName, shardID, lastSeqNum)
if err != nil {
errc <- fmt.Errorf("get shard iterator error: %v", err)
return
}
} else {
shardIterator = resp.NextShardIterator
}
}
}
}()
return recc, errc, nil
}
func (c *KinesisClient) getShardIterator(streamName, shardID, lastSeqNum string) (*string, error) {
params := &kinesis.GetShardIteratorInput{
ShardId: aws.String(shardID),
StreamName: aws.String(streamName),
}
if lastSeqNum != "" {
params.ShardIteratorType = aws.String("AFTER_SEQUENCE_NUMBER")
params.StartingSequenceNumber = aws.String(lastSeqNum)
} else {
params.ShardIteratorType = aws.String("TRIM_HORIZON")
}
resp, err := c.svc.GetShardIterator(params)
if err != nil {
return nil, err
}
return resp.ShardIterator, nil
}


@ -7,81 +7,54 @@ import (
"log"
"sync"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/kinesis"
"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
)
// ScanError signals the consumer if we should continue scanning for next record
// and whether to checkpoint.
type ScanError struct {
Error error
StopScan bool
SkipCheckpoint bool
}
// Record is an alias of record returned from kinesis library
type Record = kinesis.Record
// Client interface is used for interacting with kinesis stream
type Client interface {
GetShardIDs(string) ([]string, error)
GetRecords(ctx context.Context, streamName, shardID, lastSeqNum string) (<-chan *Record, <-chan error, error)
}
// Counter interface is used for exposing basic metrics from the scanner
type Counter interface {
Add(string, int64)
}
type noopCounter struct{}
func (n noopCounter) Add(string, int64) {}
// Checkpoint interface used track consumer progress in the stream
type Checkpoint interface {
Get(streamName, shardID string) (string, error)
Set(streamName, shardID, sequenceNumber string) error
}
type noopCheckpoint struct{}
func (n noopCheckpoint) Set(string, string, string) error { return nil }
func (n noopCheckpoint) Get(string, string) (string, error) { return "", nil }
// Option is used to override defaults when creating a new Consumer
type Option func(*Consumer) error
type Option func(*Consumer)
// WithCheckpoint overrides the default checkpoint
func WithCheckpoint(checkpoint Checkpoint) Option {
return func(c *Consumer) error {
return func(c *Consumer) {
c.checkpoint = checkpoint
return nil
}
}
// WithLogger overrides the default logger
func WithLogger(logger *log.Logger) Option {
return func(c *Consumer) error {
func WithLogger(logger Logger) Option {
return func(c *Consumer) {
c.logger = logger
return nil
}
}
// WithCounter overrides the default counter
func WithCounter(counter Counter) Option {
return func(c *Consumer) error {
return func(c *Consumer) {
c.counter = counter
return nil
}
}
// WithClient overrides the default client
func WithClient(client Client) Option {
return func(c *Consumer) error {
func WithClient(client kinesisiface.KinesisAPI) Option {
return func(c *Consumer) {
c.client = client
return nil
}
}
// ScanStatus signals the consumer whether it should continue scanning for the next record
// and whether to checkpoint.
type ScanStatus struct {
Error error
StopScan bool
SkipCheckpoint bool
}
// New creates a kinesis consumer with default settings. Use Option to override
// any of the optional attributes.
func New(streamName string, opts ...Option) (*Consumer, error) {
@ -94,15 +67,23 @@ func New(streamName string, opts ...Option) (*Consumer, error) {
streamName: streamName,
checkpoint: &noopCheckpoint{},
counter: &noopCounter{},
logger: log.New(ioutil.Discard, "", log.LstdFlags),
client: NewKinesisClient(),
logger: &noopLogger{
logger: log.New(ioutil.Discard, "", log.LstdFlags),
},
}
// override defaults
for _, opt := range opts {
if err := opt(c); err != nil {
opt(c)
}
// default client if none provided
if c.client == nil {
newSession, err := session.NewSession(aws.NewConfig())
if err != nil {
return nil, err
}
c.client = kinesis.New(newSession)
}
return c, nil
@ -111,16 +92,20 @@ func New(streamName string, opts ...Option) (*Consumer, error) {
// Consumer wraps the interaction with the Kinesis stream
type Consumer struct {
streamName string
client Client
logger *log.Logger
client kinesisiface.KinesisAPI
logger Logger
checkpoint Checkpoint
counter Counter
}
// Scan scans each of the shards of the stream, calls the callback
// func with each of the kinesis records.
func (c *Consumer) Scan(ctx context.Context, fn func(*Record) ScanError) error {
shardIDs, err := c.client.GetShardIDs(c.streamName)
func (c *Consumer) Scan(ctx context.Context, fn func(*Record) ScanStatus) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
// get shard ids
shardIDs, err := c.getShardIDs(c.streamName)
if err != nil {
return fmt.Errorf("get shards error: %v", err)
}
@ -129,16 +114,13 @@ func (c *Consumer) Scan(ctx context.Context, fn func(*Record) ScanError) error {
return fmt.Errorf("no shards available")
}
ctx, cancel := context.WithCancel(ctx)
defer cancel()
var (
wg sync.WaitGroup
errc = make(chan error, 1)
)
wg.Add(len(shardIDs))
// process each shard in goroutine
// process each shard in a separate goroutine
for _, shardID := range shardIDs {
go func(shardID string) {
defer wg.Done()
@ -158,48 +140,130 @@ func (c *Consumer) Scan(ctx context.Context, fn func(*Record) ScanError) error {
wg.Wait()
close(errc)
return <-errc
}
// ScanShard loops over records on a specific shard, calls the callback func
// for each record and checkpoints the progress of scan.
// Note: returning a ScanStatus with StopScan set will end the scan.
func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn func(*Record) ScanError) (err error) {
func (c *Consumer) ScanShard(
ctx context.Context,
shardID string,
fn func(*Record) ScanStatus,
) error {
// get checkpoint
lastSeqNum, err := c.checkpoint.Get(c.streamName, shardID)
if err != nil {
return fmt.Errorf("get checkpoint error: %v", err)
}
c.logger.Println("scanning", shardID, lastSeqNum)
// get records
recc, errc, err := c.client.GetRecords(ctx, c.streamName, shardID, lastSeqNum)
// get shard iterator
shardIterator, err := c.getShardIterator(c.streamName, shardID, lastSeqNum)
if err != nil {
return fmt.Errorf("get records error: %v", err)
}
// loop records
for r := range recc {
scanError := fn(r)
// It will be nicer if this can be reported with checkpoint error
err = scanError.Error
// Skip invalid state
if scanError.StopScan && scanError.SkipCheckpoint {
continue
}
if scanError.StopScan {
break
}
if !scanError.SkipCheckpoint {
c.counter.Add("records", 1)
err := c.checkpoint.Set(c.streamName, shardID, *r.SequenceNumber)
if err != nil {
return fmt.Errorf("set checkpoint error: %v", err)
}
}
return fmt.Errorf("get shard iterator error: %v", err)
}
c.logger.Println("exiting", shardID)
return <-errc
c.logger.Log("scanning", shardID, lastSeqNum)
return c.scanPagesOfShard(ctx, shardID, lastSeqNum, shardIterator, fn)
}
func (c *Consumer) scanPagesOfShard(ctx context.Context, shardID, lastSeqNum string, shardIterator *string, fn func(*Record) ScanStatus) error {
for {
select {
case <-ctx.Done():
return nil
default:
resp, err := c.client.GetRecords(&kinesis.GetRecordsInput{
ShardIterator: shardIterator,
})
if err != nil {
shardIterator, err = c.getShardIterator(c.streamName, shardID, lastSeqNum)
if err != nil {
return fmt.Errorf("get shard iterator error: %v", err)
}
continue
}
// loop records of page
for _, r := range resp.Records {
isScanStopped, err := c.handleRecord(shardID, r, fn)
if err != nil {
return err
}
if isScanStopped {
return nil
}
lastSeqNum = *r.SequenceNumber
}
if isShardClosed(resp.NextShardIterator, shardIterator) {
return nil
}
shardIterator = resp.NextShardIterator
}
}
}
func isShardClosed(nextShardIterator, currentShardIterator *string) bool {
return nextShardIterator == nil || currentShardIterator == nextShardIterator
}
func (c *Consumer) handleRecord(shardID string, r *Record, fn func(*Record) ScanStatus) (isScanStopped bool, err error) {
status := fn(r)
if !status.SkipCheckpoint {
if err := c.checkpoint.Set(c.streamName, shardID, *r.SequenceNumber); err != nil {
return false, err
}
}
if err := status.Error; err != nil {
return false, err
}
c.counter.Add("records", 1)
if status.StopScan {
return true, nil
}
return false, nil
}
func (c *Consumer) getShardIDs(streamName string) ([]string, error) {
resp, err := c.client.DescribeStream(
&kinesis.DescribeStreamInput{
StreamName: aws.String(streamName),
},
)
if err != nil {
return nil, fmt.Errorf("describe stream error: %v", err)
}
var ss []string
for _, shard := range resp.StreamDescription.Shards {
ss = append(ss, *shard.ShardId)
}
return ss, nil
}
func (c *Consumer) getShardIterator(streamName, shardID, lastSeqNum string) (*string, error) {
params := &kinesis.GetShardIteratorInput{
ShardId: aws.String(shardID),
StreamName: aws.String(streamName),
}
if lastSeqNum != "" {
params.ShardIteratorType = aws.String("AFTER_SEQUENCE_NUMBER")
params.StartingSequenceNumber = aws.String(lastSeqNum)
} else {
params.ShardIteratorType = aws.String("TRIM_HORIZON")
}
resp, err := c.client.GetShardIterator(params)
if err != nil {
return nil, err
}
return resp.ShardIterator, nil
}


@ -3,12 +3,12 @@ package consumer
import (
"context"
"fmt"
"io/ioutil"
"log"
"sync"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/kinesis"
"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
)
func TestNew(t *testing.T) {
@ -18,97 +18,292 @@ func TestNew(t *testing.T) {
}
}
func TestScanShard(t *testing.T) {
func TestConsumer_Scan(t *testing.T) {
records := []*kinesis.Record{
{
Data: []byte("firstData"),
SequenceNumber: aws.String("firstSeqNum"),
},
{
Data: []byte("lastData"),
SequenceNumber: aws.String("lastSeqNum"),
},
}
client := &kinesisClientMock{
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
return &kinesis.GetShardIteratorOutput{
ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
}, nil
},
getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
return &kinesis.GetRecordsOutput{
NextShardIterator: nil,
Records: records,
}, nil
},
describeStreamMock: func(input *kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) {
return &kinesis.DescribeStreamOutput{
StreamDescription: &kinesis.StreamDescription{
Shards: []*kinesis.Shard{
{ShardId: aws.String("myShard")},
},
},
}, nil
},
}
var (
ckp = &fakeCheckpoint{cache: map[string]string{}}
ctr = &fakeCounter{}
client = newFakeClient(
&Record{
Data: []byte("firstData"),
SequenceNumber: aws.String("firstSeqNum"),
},
&Record{
Data: []byte("lastData"),
SequenceNumber: aws.String("lastSeqNum"),
},
)
cp = &fakeCheckpoint{cache: map[string]string{}}
ctr = &fakeCounter{}
)
c := &Consumer{
streamName: "myStreamName",
client: client,
checkpoint: ckp,
counter: ctr,
logger: log.New(ioutil.Discard, "", log.LstdFlags),
c, err := New("myStreamName",
WithClient(client),
WithCounter(ctr),
WithCheckpoint(cp),
)
if err != nil {
t.Fatalf("new consumer error: %v", err)
}
// callback fn simply appends the record data to result string
var resultData string
var fnCallCounter int
var fn = func(r *Record) ScanStatus {
fnCallCounter++
resultData += string(r.Data)
return ScanStatus{}
}
if err := c.Scan(context.Background(), fn); err != nil {
t.Errorf("scan shard error expected nil. got %v", err)
}
if resultData != "firstDatalastData" {
t.Errorf("callback error expected %s, got %s", "firstDatalastData", resultData)
}
if fnCallCounter != 2 {
t.Errorf("the callback function expects %v, got %v", 2, fnCallCounter)
}
if val := ctr.counter; val != 2 {
t.Errorf("counter error expected %d, got %d", 2, val)
}
val, err := cp.Get("myStreamName", "myShard")
if err != nil && val != "lastSeqNum" {
t.Errorf("checkout error expected %s, got %s", "lastSeqNum", val)
}
}
func TestConsumer_Scan_NoShardsAvailable(t *testing.T) {
client := &kinesisClientMock{
describeStreamMock: func(input *kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) {
return &kinesis.DescribeStreamOutput{
StreamDescription: &kinesis.StreamDescription{
Shards: make([]*kinesis.Shard, 0),
},
}, nil
},
}
var (
resultData string
fn = func(r *Record) ScanError {
resultData += string(r.Data)
err := errors.New("some error happened")
return ScanError{
Error: err,
StopScan: false,
SkipCheckpoint: false,
}
}
cp = &fakeCheckpoint{cache: map[string]string{}}
ctr = &fakeCounter{}
)
// scan shard
err := c.ScanShard(context.Background(), "myShard", fn)
c, err := New("myStreamName",
WithClient(client),
WithCounter(ctr),
WithCheckpoint(cp),
)
if err != nil {
t.Fatalf("new consumer error: %v", err)
}
var fnCallCounter int
var fn = func(r *Record) ScanStatus {
fnCallCounter++
return ScanStatus{}
}
if err := c.Scan(context.Background(), fn); err == nil {
t.Errorf("scan shard error expected not nil. got %v", err)
}
if fnCallCounter != 0 {
t.Errorf("the callback function expects %v, got %v", 0, fnCallCounter)
}
if val := ctr.counter; val != 0 {
t.Errorf("counter error expected %d, got %d", 0, val)
}
val, err := cp.Get("myStreamName", "myShard")
if err != nil && val != "" {
t.Errorf("checkout error expected %s, got %s", "", val)
}
}
func TestScanShard(t *testing.T) {
var records = []*kinesis.Record{
{
Data: []byte("firstData"),
SequenceNumber: aws.String("firstSeqNum"),
},
{
Data: []byte("lastData"),
SequenceNumber: aws.String("lastSeqNum"),
},
}
var client = &kinesisClientMock{
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
return &kinesis.GetShardIteratorOutput{
ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
}, nil
},
getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
return &kinesis.GetRecordsOutput{
NextShardIterator: nil,
Records: records,
}, nil
},
}
var (
cp = &fakeCheckpoint{cache: map[string]string{}}
ctr = &fakeCounter{}
)
c, err := New("myStreamName",
WithClient(client),
WithCounter(ctr),
WithCheckpoint(cp),
)
if err != nil {
t.Fatalf("new consumer error: %v", err)
}
// callback fn appends record data
var resultData string
var fn = func(r *Record) ScanStatus {
resultData += string(r.Data)
return ScanStatus{}
}
// scan shard
if err := c.ScanShard(context.Background(), "myShard", fn); err != nil {
t.Fatalf("scan shard error: %v", err)
}
// runs callback func
if resultData != "firstDatalastData" {
t.Fatalf("callback error expected %s, got %s", "firstDatalastData", resultData)
}
// increments counter
if val := ctr.counter; val != 2 {
t.Fatalf("counter error expected %d, got %d", 2, val)
}
// sets checkpoint
val, err := ckp.Get("myStreamName", "myShard")
val, err := cp.Get("myStreamName", "myShard")
if err != nil && val != "lastSeqNum" {
t.Fatalf("checkout error expected %s, got %s", "lastSeqNum", val)
}
}
// calls callback func
if resultData != "firstDatalastData" {
t.Fatalf("callback error expected %s, got %s", "firstDatalastData", val)
func TestScanShard_StopScan(t *testing.T) {
var records = []*kinesis.Record{
{
Data: []byte("firstData"),
SequenceNumber: aws.String("firstSeqNum"),
},
{
Data: []byte("lastData"),
SequenceNumber: aws.String("lastSeqNum"),
},
}
var client = &kinesisClientMock{
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
return &kinesis.GetShardIteratorOutput{
ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
}, nil
},
getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
return &kinesis.GetRecordsOutput{
NextShardIterator: nil,
Records: records,
}, nil
},
}
c, err := New("myStreamName", WithClient(client))
if err != nil {
t.Fatalf("new consumer error: %v", err)
}
// callback fn appends record data
var resultData string
var fn = func(r *Record) ScanStatus {
resultData += string(r.Data)
return ScanStatus{StopScan: true}
}
if err := c.ScanShard(context.Background(), "myShard", fn); err != nil {
t.Fatalf("scan shard error: %v", err)
}
if resultData != "firstData" {
t.Fatalf("callback error expected %s, got %s", "firstData", resultData)
}
}
func newFakeClient(rs ...*Record) *fakeClient {
fc := &fakeClient{
recc: make(chan *Record, len(rs)),
errc: make(chan error),
func TestScanShard_ShardIsClosed(t *testing.T) {
var client = &kinesisClientMock{
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
return &kinesis.GetShardIteratorOutput{
ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
}, nil
},
getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
return &kinesis.GetRecordsOutput{
NextShardIterator: nil,
Records: make([]*Record, 0),
}, nil
},
}
for _, r := range rs {
fc.recc <- r
c, err := New("myStreamName", WithClient(client))
if err != nil {
t.Fatalf("new consumer error: %v", err)
}
close(fc.errc)
close(fc.recc)
var fn = func(r *Record) ScanStatus {
return ScanStatus{}
}
return fc
if err := c.ScanShard(context.Background(), "myShard", fn); err != nil {
t.Fatalf("scan shard error: %v", err)
}
}
type fakeClient struct {
shardIDs []string
recc chan *Record
errc chan error
type kinesisClientMock struct {
kinesisiface.KinesisAPI
getShardIteratorMock func(*kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error)
getRecordsMock func(*kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error)
describeStreamMock func(*kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error)
}
func (fc *fakeClient) GetShardIDs(string) ([]string, error) {
return fc.shardIDs, nil
func (c *kinesisClientMock) GetRecords(in *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
return c.getRecordsMock(in)
}
func (fc *fakeClient) GetRecords(ctx context.Context, streamName, shardID, lastSeqNum string) (<-chan *Record, <-chan error, error) {
return fc.recc, fc.errc, nil
func (c *kinesisClientMock) GetShardIterator(in *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
return c.getShardIteratorMock(in)
}
func (c *kinesisClientMock) DescribeStream(in *kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) {
return c.describeStreamMock(in)
}
// implementation of checkpoint
type fakeCheckpoint struct {
cache map[string]string
mu sync.Mutex
@ -131,6 +326,7 @@ func (fc *fakeCheckpoint) Get(streamName, shardID string) (string, error) {
return fc.cache[key], nil
}
// implementation of counter
type fakeCounter struct {
counter int64
}

11
counter.go Normal file

@ -0,0 +1,11 @@
package consumer
// Counter interface is used for exposing basic metrics from the scanner
type Counter interface {
Add(string, int64)
}
// noopCounter implements counter interface with discard
type noopCounter struct{}
func (n noopCounter) Add(string, int64) {}


@ -2,7 +2,6 @@ package main
import (
"context"
"errors"
"expvar"
"flag"
"fmt"
@ -18,6 +17,9 @@ import (
"github.com/aws/aws-sdk-go/service/dynamodb"
"github.com/aws/aws-sdk-go/service/kinesis"
alog "github.com/apex/log"
"github.com/apex/log/handlers/text"
consumer "github.com/harlow/kinesis-consumer"
checkpoint "github.com/harlow/kinesis-consumer/checkpoint/ddb"
)
@ -26,7 +28,7 @@ import (
func init() {
sock, err := net.Listen("tcp", "localhost:8080")
if err != nil {
log.Fatalf("net listen error: %v", err)
log.Printf("net listen error: %v", err)
}
go func() {
fmt.Println("Metrics available at http://localhost:8080/debug/vars")
@ -34,7 +36,25 @@ func init() {
}()
}
// A myLogger provides a minimalistic logger satisfying the Logger interface.
type myLogger struct {
logger alog.Logger
}
// Log logs the parameters to the wrapped apex logger.
func (l *myLogger) Log(args ...interface{}) {
l.logger.Infof("producer", args...)
}
func main() {
// Wrap the apex logger in myLogger
log := &myLogger{
logger: alog.Logger{
Handler: text.New(os.Stdout),
Level: alog.DebugLevel,
},
}
var (
app = flag.String("app", "", "App name")
stream = flag.String("stream", "", "Stream name")
@ -45,36 +65,37 @@ func main() {
// The following will override the default dynamodb client
// Older versions of the aws sdk do not pick up the aws config properly.
// You probably need to update the aws sdk version. Tested the following with 1.13.59
myDynamoDbClient := dynamodb.New(session.New(aws.NewConfig()), &aws.Config{
Region: aws.String("us-west-2"),
})
myDynamoDbClient := dynamodb.New(
session.New(aws.NewConfig()), &aws.Config{
Region: aws.String("us-west-2"),
},
)
// ddb checkpoint
ck, err := checkpoint.New(*app, *table, checkpoint.WithDynamoClient(myDynamoDbClient), checkpoint.WithRetryer(&MyRetryer{}))
if err != nil {
log.Fatalf("checkpoint error: %v", err)
log.Log("checkpoint error: %v", err)
}
var (
counter = expvar.NewMap("counters")
logger = log.New(os.Stdout, "", log.LstdFlags)
)
var counter = expvar.NewMap("counters")
// The following will override the default kinesis client
myKinesisClient := kinesis.New(session.New(aws.NewConfig()), &aws.Config{
Region: aws.String("us-west-2"),
})
newKclient := consumer.NewKinesisClient(consumer.WithKinesis(myKinesisClient))
ksis := kinesis.New(
session.New(aws.NewConfig()), &aws.Config{
Region: aws.String("us-west-2"),
},
)
// consumer
c, err := consumer.New(
*stream,
consumer.WithCheckpoint(ck),
consumer.WithLogger(logger),
consumer.WithLogger(log),
consumer.WithCounter(counter),
consumer.WithClient(newKclient),
consumer.WithClient(ksis),
)
if err != nil {
log.Fatalf("consumer error: %v", err)
log.Log("consumer error: %v", err)
}
// use cancel func to signal shutdown
@ -90,29 +111,27 @@ func main() {
}()
// scan stream
err = c.Scan(ctx, func(r *consumer.Record) consumer.ScanError {
err = c.Scan(ctx, func(r *consumer.Record) consumer.ScanStatus {
fmt.Println(string(r.Data))
err := errors.New("some error happened")
// continue scanning
return consumer.ScanError{
Error: err,
StopScan: true,
SkipCheckpoint: false,
}
return consumer.ScanStatus{}
})
if err != nil {
log.Fatalf("scan error: %v", err)
log.Log("scan error: %v", err)
}
if err := ck.Shutdown(); err != nil {
log.Fatalf("checkpoint shutdown error: %v", err)
log.Log("checkpoint shutdown error: %v", err)
}
}
// MyRetryer used for checkpointing
type MyRetryer struct {
checkpoint.Retryer
}
// ShouldRetry implements custom logic for when a checkpoint should retry
func (r *MyRetryer) ShouldRetry(err error) bool {
if awsErr, ok := err.(awserr.Error); ok {
switch awsErr.Code() {


@ -0,0 +1,17 @@
# Consumer with postgres checkpoint
Read records from the Kinesis stream using Postgres as the checkpoint
## Environment Variables
Export the required environment vars for connecting to the Kinesis stream:
```shell
export AWS_ACCESS_KEY=
export AWS_REGION=
export AWS_SECRET_KEY=
```
## Run the consumer
go run main.go --app appName --stream streamName --table tableName --connection connectionString


@ -0,0 +1,70 @@
package main
import (
"context"
"expvar"
"flag"
"fmt"
"log"
"os"
"os/signal"
consumer "github.com/harlow/kinesis-consumer"
checkpoint "github.com/harlow/kinesis-consumer/checkpoint/postgres"
)
func main() {
var (
app = flag.String("app", "", "App name")
stream = flag.String("stream", "", "Stream name")
table = flag.String("table", "", "Table name")
connStr = flag.String("connection", "", "Connection Str")
)
flag.Parse()
// postgres checkpoint
ck, err := checkpoint.New(*app, *table, *connStr)
if err != nil {
log.Fatalf("checkpoint error: %v", err)
}
var counter = expvar.NewMap("counters")
// consumer
c, err := consumer.New(
*stream,
consumer.WithCheckpoint(ck),
consumer.WithCounter(counter),
)
if err != nil {
log.Fatalf("consumer error: %v", err)
}
// use cancel func to signal shutdown
ctx, cancel := context.WithCancel(context.Background())
// trap SIGINT, wait to trigger shutdown
signals := make(chan os.Signal, 1)
signal.Notify(signals, os.Interrupt)
go func() {
<-signals
cancel()
}()
// scan stream
err = c.Scan(ctx, func(r *consumer.Record) consumer.ScanStatus {
fmt.Println(string(r.Data))
// continue scanning
return consumer.ScanStatus{}
})
if err != nil {
log.Fatalf("scan error: %v", err)
}
if err := ck.Shutdown(); err != nil {
log.Fatalf("checkpoint shutdown error: %v", err)
}
}


@ -0,0 +1,18 @@
# Consumer
Read records from the Kinesis stream
### Environment Variables
Export the required environment vars for connecting to the Kinesis stream and to Redis for the checkpoint:
```
export AWS_ACCESS_KEY=
export AWS_REGION=
export AWS_SECRET_KEY=
export REDIS_URL=
```
### Run the consumer
$ go run main.go --app appName --stream streamName


@ -0,0 +1,58 @@
package main
import (
"context"
"flag"
"fmt"
"log"
"os"
"os/signal"
consumer "github.com/harlow/kinesis-consumer"
checkpoint "github.com/harlow/kinesis-consumer/checkpoint/redis"
)
func main() {
var (
app = flag.String("app", "", "App name")
stream = flag.String("stream", "", "Stream name")
)
flag.Parse()
// redis checkpoint
ck, err := checkpoint.New(*app)
if err != nil {
log.Fatalf("checkpoint error: %v", err)
}
// consumer
c, err := consumer.New(
*stream, consumer.WithCheckpoint(ck),
)
if err != nil {
log.Fatalf("consumer error: %v", err)
}
// use cancel func to signal shutdown
ctx, cancel := context.WithCancel(context.Background())
// trap SIGINT, wait to trigger shutdown
signals := make(chan os.Signal, 1)
signal.Notify(signals, os.Interrupt)
go func() {
<-signals
cancel()
}()
// scan stream
err = c.Scan(ctx, func(r *consumer.Record) consumer.ScanStatus {
fmt.Println(string(r.Data))
// continue scanning
return consumer.ScanStatus{}
})
if err != nil {
log.Fatalf("scan error: %v", err)
}
}


@ -4,24 +4,20 @@ import (
"bufio"
"flag"
"fmt"
"log"
"os"
"time"
"github.com/apex/log"
"github.com/apex/log/handlers/text"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/kinesis"
)
var svc = kinesis.New(session.New(), &aws.Config{
Region: aws.String("us-west-2"),
Region: aws.String("us-west-1"),
})
func main() {
log.SetHandler(text.New(os.Stderr))
log.SetLevel(log.DebugLevel)
var streamName = flag.String("stream", "", "Stream name")
flag.Parse()
@ -60,7 +56,7 @@ func putRecords(streamName *string, records []*kinesis.PutRecordsRequestEntry) {
Records: records,
})
if err != nil {
log.Fatal("error putting records")
log.Fatalf("error putting records: %v", err)
}
fmt.Print(".")
}

20
logger.go Normal file

@ -0,0 +1,20 @@
package consumer
import (
"log"
)
// A Logger is a minimal interface used as an adaptor between an external logging library and the consumer
type Logger interface {
Log(...interface{})
}
// noopLogger implements logger interface with discard
type noopLogger struct {
logger *log.Logger
}
// Log using stdlib logger. See log.Println.
func (l noopLogger) Log(args ...interface{}) {
l.logger.Println(args...)
}