Remove the client wrapper (#58)

Having an additional Client has added some confusion (https://github.com/harlow/kinesis-consumer/issues/45) on how to provide a
custom kinesis client. Allowing `WithClient` to accept a Kinesis client
directly cleans up the interface.

Major changes:

* Remove the Client wrapper; prefer using kinesis client directly
* Change `ScanError` to `ScanStatus` as the return value isn't necessarily an error

Note: these are breaking changes, if you need last stable release please see here: https://github.com/harlow/kinesis-consumer/releases/tag/v0.2.0
This commit is contained in:
Harlow Ward 2018-07-28 22:53:33 -07:00 committed by GitHub
parent 049445e259
commit fb98fbe244
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 358 additions and 527 deletions

View file

@ -4,25 +4,26 @@ All notable changes to this project will be documented in this file.
## [Unreleased (`master`)][unreleased]
** Breaking changes to consumer library **
Major changes:
* Use [functional options][options] for config
* Remove intermediate batching of kinesis records
* Call the callback func with each record
* Use dep for vendoring dependencies
* Add DDB as storage layer for checkpoints
* Remove the concept of `Client`; it was confusing as it wasn't a direct stand-in for a Kinesis client.
* Rename `ScanError` to `ScanStatus` as it's not always an error.
Minor changes:
* remove unused buffer and emitter code
* Update tests to use Kinesis mock
[unreleased]: https://github.com/harlow/kinesis-consumer/compare/v0.1.0...HEAD
[options]: https://dave.cheney.net/2014/10/17/functional-options-for-friendly-apis
## v0.2.0 - 2018-07-28
This is the last stable release from which there is a separate Client. It has caused confusion and will be removed going forward.
https://github.com/harlow/kinesis-consumer/releases/tag/v0.2.0
## v0.1.0 - 2017-11-20
This is the last stable release of the consumer which aggregated records in `batch` before calling the callback func.
This is the last stable release of the consumer which aggregated records in `batch` before calling the callback func.
https://github.com/harlow/kinesis-consumer/releases/tag/v0.1.0
[unreleased]: https://github.com/harlow/kinesis-consumer/compare/v0.2.0...HEAD
[options]: https://dave.cheney.net/2014/10/17/functional-options-for-friendly-apis

View file

@ -37,14 +37,14 @@ func main() {
log.Fatalf("consumer error: %v", err)
}
// start
err = c.Scan(context.TODO(), func(r *consumer.Record) consumer.ScanError {
// start scan
err = c.Scan(context.TODO(), func(r *consumer.Record) consumer.ScanStatus {
fmt.Println(string(r.Data))
// continue scanning
return consumer.ScanError{
StopScan: false, // true to stop scan
SkipCheckpoint: false, // true to skip checkpoint
}
return consumer.ScanStatus{
StopScan: false, // true to stop scan
SkipCheckpoint: false, // true to skip checkpoint
}
})
if err != nil {
log.Fatalf("scan error: %v", err)
@ -55,6 +55,24 @@ func main() {
}
```
## Scan status
The scan func returns a `consumer.ScanStatus`; the struct allows some basic flow control.
```go
// continue scanning
return consumer.ScanStatus{}
// continue scanning, skip saving checkpoint
return consumer.ScanStatus{SkipCheckpoint: true}
// stop scanning, return nil
return consumer.ScanStatus{StopScan: true}
// stop scanning, return error
return consumer.ScanStatus{Error: err}
```
## Checkpoint
To record the progress of the consumer in the stream we use a checkpoint to store the last sequence number the consumer has read from a particular shard. The boolean value `SkipCheckpoint` of `consumer.ScanStatus` determines whether the checkpoint is saved for a record. `ScanStatus` is returned by the record processing callback.
@ -107,8 +125,9 @@ myDynamoDbClient := dynamodb.New(session.New(aws.NewConfig()))
ck, err := checkpoint.New(*app, *table, checkpoint.WithDynamoClient(myDynamoDbClient))
if err != nil {
log.Fatalf("new checkpoint error: %v", err)
log.Fatalf("new checkpoint error: %v", err)
}
// Or we can provide your own Retryer to customize what triggers a retry inside checkpoint
// See code in examples
// ck, err := checkpoint.New(*app, *table, checkpoint.WithDynamoClient(myDynamoDbClient), checkpoint.WithRetryer(&MyRetryer{}))
@ -133,7 +152,7 @@ import checkpoint "github.com/harlow/kinesis-consumer/checkpoint/postgres"
// postgres checkpoint
ck, err := checkpoint.New(app, table, connStr)
if err != nil {
log.Fatalf("new checkpoint error: %v", err)
log.Fatalf("new checkpoint error: %v", err)
}
```
@ -155,7 +174,7 @@ The table name has to be the same that you specify when creating the checkpoint.
The consumer allows the following optional overrides.
### Client
### Kinesis Client
Override the Kinesis client if there is any special config needed:
@ -189,6 +208,7 @@ The [expvar package](https://golang.org/pkg/expvar/) will display consumer count
```
### Logging
Logging supports the basic built-in logging library or a third-party external one, so long as
it implements the Logger interface.
@ -197,12 +217,12 @@ For example, to use the builtin logging package, we wrap it with myLogger struct
```
// A myLogger provides a minimalistic logger satisfying the Logger interface.
type myLogger struct {
logger *log.Logger
logger *log.Logger
}
// Log logs the parameters to the stdlib logger. See log.Println.
func (l *myLogger) Log(args ...interface{}) {
l.logger.Println(args...)
l.logger.Println(args...)
}
```
@ -210,29 +230,32 @@ The package defaults to `ioutil.Discard` so swallow all logs. This can be custom
```go
// logger
log := &myLogger{ logger : log.New(os.Stdout, "consumer-example: ", log.LstdFlags),}
log := &myLogger{
logger: log.New(os.Stdout, "consumer-example: ", log.LstdFlags)
}
// consumer
c, err := consumer.New(streamName, consumer.WithLogger(logger))
```
To use a more complicated logging library, e.g. apex log
```
type myLogger struct {
logger *log.Logger
logger *log.Logger
}
func (l *myLogger) Log(args ...interface{}) {
l.logger.Infof("producer", args...)
l.logger.Infof("producer", args...)
}
func main() {
log := &myLogger{
logger: alog.Logger{
Handler: text.New(os.Stderr),
Level: alog.DebugLevel,
},
}
log := &myLogger{
logger: alog.Logger{
Handler: text.New(os.Stderr),
Level: alog.DebugLevel,
},
}
```
## Contributing

12
checkpoint.go Normal file
View file

@ -0,0 +1,12 @@
package consumer
// Checkpoint interface is used to track consumer progress in the stream.
type Checkpoint interface {
// Get returns the sequence number stored for the given stream and shard.
Get(streamName, shardID string) (string, error)
// Set stores the sequence number for the given stream and shard.
Set(streamName, shardID, sequenceNumber string) error
}
// noopCheckpoint is a checkpoint that persists nothing: every Set succeeds
// without storing the value and every Get reports an empty sequence number.
type noopCheckpoint struct{}

// Set ignores its arguments and always returns a nil error.
func (noopCheckpoint) Set(_, _, _ string) error {
	return nil
}

// Get ignores its arguments and always returns an empty sequence number
// together with a nil error.
func (noopCheckpoint) Get(_, _ string) (string, error) {
	var seq string
	return seq, nil
}

152
client.go
View file

@ -1,152 +0,0 @@
package consumer
import (
"context"
"fmt"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/kinesis"
"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
)
// ClientOption is used to override defaults when creating a KinesisClient.
// Each option is a func that mutates the KinesisClient it is given.
type ClientOption func(*KinesisClient)
// WithKinesis overrides the default Kinesis client by setting svc as the
// service used by the KinesisClient.
func WithKinesis(svc kinesisiface.KinesisAPI) ClientOption {
return func(kc *KinesisClient) {
kc.svc = svc
}
}
// WithStartFromLatest will make sure the client starts consuming
// events starting from the most recent event in kinesis. This
// option discards the checkpoints: it sets the fromLatest flag on the
// KinesisClient.
func WithStartFromLatest() ClientOption {
return func(kc *KinesisClient) {
kc.fromLatest = true
}
}
// NewKinesisClient returns a client to interface with a Kinesis stream.
//
// The options are applied in order. If none of them supplies a Kinesis
// service (see WithKinesis), a default one is built from a new AWS session.
// The session is only created when it is actually needed, so a caller that
// injects its own service never fails because of session setup.
//
// It returns an error only when the default AWS session cannot be created.
func NewKinesisClient(opts ...ClientOption) (*KinesisClient, error) {
	kc := &KinesisClient{}

	for _, opt := range opts {
		opt(kc)
	}

	// A service was provided by the caller; no session is required.
	if kc.svc != nil {
		return kc, nil
	}

	newSession, err := session.NewSession(aws.NewConfig())
	if err != nil {
		return nil, err
	}
	kc.svc = kinesis.New(newSession)

	return kc, nil
}
// KinesisClient acts as wrapper around Kinesis client
type KinesisClient struct {
// svc is the Kinesis API implementation every request is sent through.
svc kinesisiface.KinesisAPI
// fromLatest, when true, makes getShardIterator request a LATEST iterator
// regardless of the last sequence number passed in.
fromLatest bool
}
// GetShardIDs returns the ids of all shards in the given stream.
//
// DescribeStream returns shards one page at a time, so the call is repeated
// with ExclusiveStartShardId set to the last shard seen for as long as the
// service reports HasMoreShards. Without this, streams with more shards than
// fit in a single page would be silently truncated.
//
// It returns a wrapped error if any DescribeStream call fails.
func (c *KinesisClient) GetShardIDs(streamName string) ([]string, error) {
	var (
		ss    []string
		input = &kinesis.DescribeStreamInput{
			StreamName: aws.String(streamName),
		}
	)

	for {
		resp, err := c.svc.DescribeStream(input)
		if err != nil {
			return nil, fmt.Errorf("describe stream error: %v", err)
		}

		shards := resp.StreamDescription.Shards
		for _, shard := range shards {
			ss = append(ss, *shard.ShardId)
		}

		// Stop on the last page; the empty-page guard prevents looping
		// forever if the service claims more shards but returns none.
		if len(shards) == 0 || !aws.BoolValue(resp.StreamDescription.HasMoreShards) {
			return ss, nil
		}
		input.ExclusiveStartShardId = shards[len(shards)-1].ShardId
	}
}
// GetRecords returns a chan Record from a Shard of the Stream.
//
// A shard iterator is obtained first; if that fails the error is returned
// directly and both channels are nil. Otherwise a goroutine is started that
// repeatedly calls GetRecords on the service and forwards every record to the
// returned record channel. The goroutine stops, closing both channels, when:
//
//   - ctx is cancelled,
//   - a replacement shard iterator cannot be obtained after a failed
//     GetRecords call (the error is sent on the error channel), or
//   - the service returns no next iterator, or an iterator whose value equals
//     the current one (an error describing this is sent on the error channel).
//
// lastSeqNum is advanced as records are delivered so that a refreshed
// iterator resumes after the last delivered record.
func (c *KinesisClient) GetRecords(ctx context.Context, streamName, shardID, lastSeqNum string) (<-chan *Record, <-chan error, error) {
	shardIterator, err := c.getShardIterator(streamName, shardID, lastSeqNum)
	if err != nil {
		return nil, nil, fmt.Errorf("get shard iterator error: %v", err)
	}

	var (
		recc = make(chan *Record, 10000)
		errc = make(chan error, 1)
	)

	go func() {
		defer func() {
			close(recc)
			close(errc)
		}()

		for {
			select {
			case <-ctx.Done():
				return
			default:
				resp, err := c.svc.GetRecords(
					&kinesis.GetRecordsInput{
						ShardIterator: shardIterator,
					},
				)

				if err != nil {
					// The iterator may have expired; fetch a new one that
					// resumes after the last delivered record and retry.
					shardIterator, err = c.getShardIterator(streamName, shardID, lastSeqNum)
					if err != nil {
						errc <- fmt.Errorf("get shard iterator error: %v", err)
						return
					}
					continue
				}

				for _, r := range resp.Records {
					select {
					case <-ctx.Done():
						return
					case recc <- r:
						lastSeqNum = *r.SequenceNumber
					}
				}

				// Compare iterator values, not pointer addresses: two distinct
				// *string values can hold the same iterator.
				next := resp.NextShardIterator
				if next == nil || aws.StringValue(next) == aws.StringValue(shardIterator) {
					// err is nil on this path, so report the actual condition
					// instead of formatting a nil error.
					errc <- fmt.Errorf("shard %s has no further shard iterator", shardID)
					return
				}

				shardIterator = next
			}
		}
	}()

	return recc, errc, nil
}
// getShardIterator asks the service for an iterator on the given shard.
//
// The iterator type is chosen in this order: LATEST when the client was
// configured with fromLatest, AFTER_SEQUENCE_NUMBER when lastSeqNum is
// non-empty, and TRIM_HORIZON otherwise. Any service error is returned
// unchanged.
func (c *KinesisClient) getShardIterator(streamName, shardID, lastSeqNum string) (*string, error) {
	input := &kinesis.GetShardIteratorInput{
		ShardId:    aws.String(shardID),
		StreamName: aws.String(streamName),
	}

	switch {
	case c.fromLatest:
		input.ShardIteratorType = aws.String("LATEST")
	case lastSeqNum != "":
		input.ShardIteratorType = aws.String("AFTER_SEQUENCE_NUMBER")
		input.StartingSequenceNumber = aws.String(lastSeqNum)
	default:
		input.ShardIteratorType = aws.String("TRIM_HORIZON")
	}

	out, err := c.svc.GetShardIterator(input)
	if err != nil {
		return nil, err
	}

	return out.ShardIterator, nil
}

View file

@ -1,150 +0,0 @@
package consumer_test
import (
"testing"
"context"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/kinesis"
"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
"github.com/harlow/kinesis-consumer"
)
// TestKinesisClient_GetRecords_SuccessfullyRun checks that GetRecords hands
// back a non-nil records channel, a non-nil errors channel and a nil error
// when the mocked service returns a shard iterator.
func TestKinesisClient_GetRecords_SuccessfullyRun(t *testing.T) {
	kinesisClient := &kinesisClientMock{
		getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
			return &kinesis.GetShardIteratorOutput{
				ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
			}, nil
		},
		getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
			return &kinesis.GetRecordsOutput{
				NextShardIterator: nil,
				Records:           make([]*kinesis.Record, 0),
			}, nil
		},
	}

	kinesisClientOpt := consumer.WithKinesis(kinesisClient)

	c, err := consumer.NewKinesisClient(kinesisClientOpt)
	if err != nil {
		t.Fatalf("New kinesis client error: %v", err)
	}

	ctx, cancelFunc := context.WithCancel(context.Background())
	// Deferred so the context is cancelled on every exit path.
	defer cancelFunc()

	recordsChan, errorsChan, err := c.GetRecords(ctx, "myStream", "shardId-000000000000", "")

	if recordsChan == nil {
		t.Errorf("records channel expected not nil, got %v", recordsChan)
	}
	if errorsChan == nil {
		// Report the channel under test, not the records channel.
		t.Errorf("errors channel expected not nil, got %v", errorsChan)
	}
	if err != nil {
		t.Errorf("error expected nil, got %v", err)
	}
}
// TestKinesisClient_GetRecords_SuccessfullyRetrievesThreeRecordsAtOnce feeds
// three records through the mocked service in a single GetRecords response
// and checks that the same three records, in the same order, come out of the
// records channel.
func TestKinesisClient_GetRecords_SuccessfullyRetrievesThreeRecordsAtOnce(t *testing.T) {
	expectedResults := []*kinesis.Record{
		{SequenceNumber: aws.String("49578481031144599192696750682534686652010819674221576195")},
		{SequenceNumber: aws.String("49578481031144599192696750682534686652010819674221576196")},
		{SequenceNumber: aws.String("49578481031144599192696750682534686652010819674221576197")},
	}

	iteratorFn := func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
		out := &kinesis.GetShardIteratorOutput{
			ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
		}
		return out, nil
	}
	recordsFn := func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
		out := &kinesis.GetRecordsOutput{
			NextShardIterator: nil,
			Records:           expectedResults,
		}
		return out, nil
	}
	kinesisClient := &kinesisClientMock{
		getShardIteratorMock: iteratorFn,
		getRecordsMock:       recordsFn,
	}

	c, err := consumer.NewKinesisClient(consumer.WithKinesis(kinesisClient))
	if err != nil {
		t.Fatalf("new kinesis client error: %v", err)
	}

	ctx, cancelFunc := context.WithCancel(context.Background())

	recordsChan, _, err := c.GetRecords(ctx, "TestStream", "shardId-000000000000", "")
	if recordsChan == nil {
		t.Fatalf("records channel expected not nil, got %v", recordsChan)
	}
	if err != nil {
		t.Fatalf("error expected nil, got %v", err)
	}

	// Receive exactly three records, one per expected result.
	var results []*consumer.Record
	for range expectedResults {
		results = append(results, <-recordsChan)
	}

	if got := len(results); got != 3 {
		t.Errorf("number of records expected 3, got %v", got)
	}
	for i := range results {
		if results[i] != expectedResults[i] {
			t.Errorf("record expected %v, got %v", expectedResults[i], results[i])
		}
	}

	cancelFunc()
}
// TestKinesisClient_GetRecords_ShardIsClosed checks that when the mocked
// service returns no next shard iterator, GetRecords itself succeeds but an
// error is delivered on the errors channel.
func TestKinesisClient_GetRecords_ShardIsClosed(t *testing.T) {
	kinesisClient := &kinesisClientMock{
		getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
			return &kinesis.GetShardIteratorOutput{
				ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
			}, nil
		},
		getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
			return &kinesis.GetRecordsOutput{
				NextShardIterator: nil,
				Records:           make([]*consumer.Record, 0),
			}, nil
		},
	}

	kinesisClientOpt := consumer.WithKinesis(kinesisClient)

	c, err := consumer.NewKinesisClient(kinesisClientOpt)
	if err != nil {
		t.Fatalf("new kinesis client error: %v", err)
	}

	ctx, cancelFunc := context.WithCancel(context.Background())
	// Deferred so the context is cancelled even when a Fatalf aborts the test.
	defer cancelFunc()

	_, errorsChan, err := c.GetRecords(ctx, "TestStream", "shardId-000000000000", "")
	if errorsChan == nil {
		t.Fatalf("errors channel expected not nil, got %v", errorsChan)
	}
	if err != nil {
		// The call itself must succeed; the failure is expected on the channel.
		t.Fatalf("error expected nil, got %v", err)
	}

	err = <-errorsChan
	if err == nil {
		t.Errorf("error expected, got %v", err)
	}
}
// kinesisClientMock embeds kinesisiface.KinesisAPI and overrides GetRecords
// and GetShardIterator with caller-supplied funcs so tests can script the
// service responses.
type kinesisClientMock struct {
kinesisiface.KinesisAPI
// getShardIteratorMock is invoked by GetShardIterator.
getShardIteratorMock func(*kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error)
// getRecordsMock is invoked by GetRecords.
getRecordsMock func(*kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error)
}
// GetRecords delegates to getRecordsMock.
func (c *kinesisClientMock) GetRecords(in *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
return c.getRecordsMock(in)
}
// GetShardIterator delegates to getShardIteratorMock.
func (c *kinesisClientMock) GetShardIterator(in *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
return c.getShardIteratorMock(in)
}

View file

@ -3,83 +3,58 @@ package consumer
import (
"context"
"fmt"
"io/ioutil"
"log"
"sync"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/kinesis"
"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
)
// ScanError signals the consumer if we should continue scanning for next record
// and whether to checkpoint.
type ScanError struct {
Error error
StopScan bool
SkipCheckpoint bool
}
// Record is an alias of record returned from kinesis library
type Record = kinesis.Record
// Client interface is used for interacting with kinesis stream
type Client interface {
GetShardIDs(string) ([]string, error)
GetRecords(ctx context.Context, streamName, shardID, lastSeqNum string) (<-chan *Record, <-chan error, error)
}
// Counter interface is used for exposing basic metrics from the scanner
type Counter interface {
Add(string, int64)
}
type noopCounter struct{}
func (n noopCounter) Add(string, int64) {}
// Checkpoint interface used track consumer progress in the stream
type Checkpoint interface {
Get(streamName, shardID string) (string, error)
Set(streamName, shardID, sequenceNumber string) error
}
type noopCheckpoint struct{}
func (n noopCheckpoint) Set(string, string, string) error { return nil }
func (n noopCheckpoint) Get(string, string) (string, error) { return "", nil }
// Option is used to override defaults when creating a new Consumer
type Option func(*Consumer) error
type Option func(*Consumer)
// WithCheckpoint overrides the default checkpoint
func WithCheckpoint(checkpoint Checkpoint) Option {
return func(c *Consumer) error {
return func(c *Consumer) {
c.checkpoint = checkpoint
return nil
}
}
// WithLogger overrides the default logger
func WithLogger(logger Logger) Option {
return func(c *Consumer) error {
return func(c *Consumer) {
c.logger = logger
return nil
}
}
// WithCounter overrides the default counter
func WithCounter(counter Counter) Option {
return func(c *Consumer) error {
return func(c *Consumer) {
c.counter = counter
return nil
}
}
// WithClient overrides the default client
func WithClient(client Client) Option {
return func(c *Consumer) error {
func WithClient(client kinesisiface.KinesisAPI) Option {
return func(c *Consumer) {
c.client = client
return nil
}
}
// ScanStatus signals the consumer if we should continue scanning for next record
// and whether to checkpoint.
type ScanStatus struct {
// Error, when non-nil, ends the shard scan and is returned to the caller.
Error error
// StopScan, when true, ends the shard scan with a nil error.
StopScan bool
// SkipCheckpoint, when true, skips saving the record's sequence number
// as the checkpoint.
SkipCheckpoint bool
}
// New creates a kinesis consumer with default settings. Use Option to override
// any of the optional attributes.
func New(streamName string, opts ...Option) (*Consumer, error) {
@ -87,25 +62,24 @@ func New(streamName string, opts ...Option) (*Consumer, error) {
return nil, fmt.Errorf("must provide stream name")
}
kc, err := NewKinesisClient()
if err != nil {
return nil, err
}
// new consumer with no-op checkpoint, counter, and logger
c := &Consumer{
streamName: streamName,
checkpoint: &noopCheckpoint{},
counter: &noopCounter{},
logger: NewDefaultLogger(),
client: kc,
logger: &noopLogger{
logger: log.New(ioutil.Discard, "", log.LstdFlags),
},
}
// override defaults
for _, opt := range opts {
if err := opt(c); err != nil {
return nil, err
}
opt(c)
}
// default client if none provided
if c.client == nil {
c.client = kinesis.New(session.New(aws.NewConfig()))
}
return c, nil
@ -114,7 +88,7 @@ func New(streamName string, opts ...Option) (*Consumer, error) {
// Consumer wraps the interaction with the Kinesis stream
type Consumer struct {
streamName string
client Client
client kinesisiface.KinesisAPI
logger Logger
checkpoint Checkpoint
counter Counter
@ -122,8 +96,12 @@ type Consumer struct {
// Scan scans each of the shards of the stream, calls the callback
// func with each of the kinesis records.
func (c *Consumer) Scan(ctx context.Context, fn func(*Record) ScanError) error {
shardIDs, err := c.client.GetShardIDs(c.streamName)
func (c *Consumer) Scan(ctx context.Context, fn func(*Record) ScanStatus) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
// get shard ids
shardIDs, err := c.getShardIDs(c.streamName)
if err != nil {
return fmt.Errorf("get shards error: %v", err)
}
@ -132,16 +110,13 @@ func (c *Consumer) Scan(ctx context.Context, fn func(*Record) ScanError) error {
return fmt.Errorf("no shards available")
}
ctx, cancel := context.WithCancel(ctx)
defer cancel()
var (
wg sync.WaitGroup
errc = make(chan error, 1)
)
wg.Add(len(shardIDs))
// process each shard in goroutine
// process each shard in a separate goroutine
for _, shardID := range shardIDs {
go func(shardID string) {
defer wg.Done()
@ -161,47 +136,113 @@ func (c *Consumer) Scan(ctx context.Context, fn func(*Record) ScanError) error {
wg.Wait()
close(errc)
return <-errc
}
// ScanShard loops over records on a specific shard, calls the callback func
// for each record and checkpoints the progress of scan.
// Note: Returning a status with `StopScan` set to true from the callback func will end the scan.
func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn func(*Record) ScanError) (err error) {
func (c *Consumer) ScanShard(
ctx context.Context,
shardID string,
fn func(*Record) ScanStatus,
) error {
// get checkpoint
lastSeqNum, err := c.checkpoint.Get(c.streamName, shardID)
if err != nil {
return fmt.Errorf("get checkpoint error: %v", err)
}
// get shard iterator
shardIterator, err := c.getShardIterator(c.streamName, shardID, lastSeqNum)
if err != nil {
return fmt.Errorf("get shard iterator error: %v", err)
}
c.logger.Log("scanning", shardID, lastSeqNum)
// get records
recc, errc, err := c.client.GetRecords(ctx, c.streamName, shardID, lastSeqNum)
if err != nil {
return fmt.Errorf("get records error: %v", err)
}
// loop records
for r := range recc {
scanError := fn(r)
// scan pages of shard
for {
select {
case <-ctx.Done():
return nil
default:
resp, err := c.client.GetRecords(&kinesis.GetRecordsInput{
ShardIterator: shardIterator,
})
// Skip invalid state
if scanError.StopScan && scanError.SkipCheckpoint {
continue
}
if scanError.StopScan {
break
}
if !scanError.SkipCheckpoint {
c.counter.Add("records", 1)
err := c.checkpoint.Set(c.streamName, shardID, *r.SequenceNumber)
if err != nil {
return fmt.Errorf("set checkpoint error: %v", err)
shardIterator, err = c.getShardIterator(c.streamName, shardID, lastSeqNum)
if err != nil {
return fmt.Errorf("get shard iterator error: %v", err)
}
continue
}
// loop records of page
for _, r := range resp.Records {
status := fn(r)
if !status.SkipCheckpoint {
lastSeqNum = *r.SequenceNumber
if err := c.checkpoint.Set(c.streamName, shardID, lastSeqNum); err != nil {
return err
}
}
if err := status.Error; err != nil {
return err
}
c.counter.Add("records", 1)
if status.StopScan {
return nil
}
}
if resp.NextShardIterator == nil || shardIterator == resp.NextShardIterator {
return nil
}
shardIterator = resp.NextShardIterator
}
}
c.logger.Log("exiting", shardID)
return <-errc
}
// getShardIDs returns the ids of all shards in the given stream.
//
// DescribeStream returns shards one page at a time, so the call is repeated
// with ExclusiveStartShardId set to the last shard seen for as long as the
// service reports HasMoreShards. Without this, streams with more shards than
// fit in a single page would only be partially scanned.
//
// It returns a wrapped error if any DescribeStream call fails.
func (c *Consumer) getShardIDs(streamName string) ([]string, error) {
	var (
		ss    = []string{}
		input = &kinesis.DescribeStreamInput{
			StreamName: aws.String(streamName),
		}
	)

	for {
		resp, err := c.client.DescribeStream(input)
		if err != nil {
			return nil, fmt.Errorf("describe stream error: %v", err)
		}

		shards := resp.StreamDescription.Shards
		for _, shard := range shards {
			ss = append(ss, *shard.ShardId)
		}

		// Stop on the last page; the empty-page guard prevents looping
		// forever if the service claims more shards but returns none.
		if len(shards) == 0 || !aws.BoolValue(resp.StreamDescription.HasMoreShards) {
			return ss, nil
		}
		input.ExclusiveStartShardId = shards[len(shards)-1].ShardId
	}
}
// getShardIterator asks the Kinesis client for an iterator on the given shard.
//
// With a non-empty lastSeqNum the iterator is positioned after that sequence
// number (AFTER_SEQUENCE_NUMBER); otherwise it starts at TRIM_HORIZON. Any
// error from the client is returned unchanged.
func (c *Consumer) getShardIterator(streamName, shardID, lastSeqNum string) (*string, error) {
	input := &kinesis.GetShardIteratorInput{
		ShardId:           aws.String(shardID),
		StreamName:        aws.String(streamName),
		ShardIteratorType: aws.String("TRIM_HORIZON"),
	}

	// Resume after the known position when there is one.
	if lastSeqNum != "" {
		input.ShardIteratorType = aws.String("AFTER_SEQUENCE_NUMBER")
		input.StartingSequenceNumber = aws.String(lastSeqNum)
	}

	out, err := c.client.GetShardIterator(input)
	if err != nil {
		return nil, err
	}

	return out.ShardIterator, nil
}

View file

@ -3,12 +3,14 @@ package consumer
import (
"context"
"fmt"
"io/ioutil"
"log"
"sync"
"testing"
"errors"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/kinesis"
"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
)
func TestNew(t *testing.T) {
@ -20,41 +22,36 @@ func TestNew(t *testing.T) {
func TestScanShard(t *testing.T) {
var (
ckp = &fakeCheckpoint{cache: map[string]string{}}
ctr = &fakeCounter{}
client = newFakeClient(
&Record{
Data: []byte("firstData"),
SequenceNumber: aws.String("firstSeqNum"),
},
&Record{
Data: []byte("lastData"),
SequenceNumber: aws.String("lastSeqNum"),
},
)
resultData string
ckp = &fakeCheckpoint{cache: map[string]string{}}
ctr = &fakeCounter{}
mockSvc = &mockKinesisClient{}
logger = &noopLogger{
logger: log.New(ioutil.Discard, "", log.LstdFlags),
}
)
c := &Consumer{
streamName: "myStreamName",
client: client,
client: mockSvc,
checkpoint: ckp,
counter: ctr,
logger: NewDefaultLogger(),
logger: logger,
}
var recordNum = 0
// callback fn simply appends the record data to result string
var (
resultData string
fn = func(r *Record) ScanError {
resultData += string(r.Data)
err := errors.New("some error happened")
return ScanError{
Error: err,
StopScan: false,
SkipCheckpoint: false,
}
var fn = func(r *Record) ScanStatus {
resultData += string(r.Data)
recordNum++
stopScan := recordNum == 2
return ScanStatus{
StopScan: stopScan,
SkipCheckpoint: false,
}
)
}
// scan shard
err := c.ScanShard(context.Background(), "myShard", fn)
@ -79,34 +76,30 @@ func TestScanShard(t *testing.T) {
}
}
func newFakeClient(rs ...*Record) *fakeClient {
fc := &fakeClient{
recc: make(chan *Record, len(rs)),
errc: make(chan error),
}
for _, r := range rs {
fc.recc <- r
}
close(fc.errc)
close(fc.recc)
return fc
type mockKinesisClient struct {
kinesisiface.KinesisAPI
}
type fakeClient struct {
shardIDs []string
recc chan *Record
errc chan error
func (m *mockKinesisClient) GetRecords(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
return &kinesis.GetRecordsOutput{
Records: []*kinesis.Record{
&kinesis.Record{
Data: []byte("firstData"),
SequenceNumber: aws.String("firstSeqNum"),
},
&kinesis.Record{
Data: []byte("lastData"),
SequenceNumber: aws.String("lastSeqNum"),
},
},
}, nil
}
func (fc *fakeClient) GetShardIDs(string) ([]string, error) {
return fc.shardIDs, nil
}
func (fc *fakeClient) GetRecords(ctx context.Context, streamName, shardID, lastSeqNum string) (<-chan *Record, <-chan error, error) {
return fc.recc, fc.errc, nil
func (m *mockKinesisClient) GetShardIterator(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
return &kinesis.GetShardIteratorOutput{
ShardIterator: aws.String("myshard"),
}, nil
}
type fakeCheckpoint struct {

10
counter.go Normal file
View file

@ -0,0 +1,10 @@
package consumer
// Counter interface is used for exposing basic metrics from the scanner.
type Counter interface {
// Add is called with a metric name and an int64 delta for that metric.
Add(string, int64)
}
// noopCounter is a counter that records nothing.
type noopCounter struct{}

// Add ignores both the metric name and the delta.
func (noopCounter) Add(_ string, _ int64) {
}

View file

@ -2,7 +2,6 @@ package main
import (
"context"
"errors"
"expvar"
"flag"
"fmt"
@ -29,7 +28,7 @@ import (
func init() {
sock, err := net.Listen("tcp", "localhost:8080")
if err != nil {
log.Println("net listen error: %v", err)
log.Printf("net listen error: %v", err)
}
go func() {
fmt.Println("Metrics available at http://localhost:8080/debug/vars")
@ -66,24 +65,26 @@ func main() {
// Following will overwrite the default dynamodb client
// Older versions of the aws sdk do not pick up the aws config properly.
// You probably need to update the aws sdk version. Tested the following with 1.13.59
myDynamoDbClient := dynamodb.New(session.New(aws.NewConfig()), &aws.Config{
Region: aws.String("us-west-2"),
})
myDynamoDbClient := dynamodb.New(
session.New(aws.NewConfig()), &aws.Config{
Region: aws.String("us-west-2"),
},
)
// ddb checkpoint
ck, err := checkpoint.New(*app, *table, checkpoint.WithDynamoClient(myDynamoDbClient), checkpoint.WithRetryer(&MyRetryer{}))
if err != nil {
log.Log("checkpoint error: %v", err)
}
var (
counter = expvar.NewMap("counters")
)
var counter = expvar.NewMap("counters")
// The following 2 lines will overwrite the default kinesis client
myKinesisClient := kinesis.New(session.New(aws.NewConfig()), &aws.Config{
Region: aws.String("us-west-2"),
})
newKclient := consumer.NewKinesisClient(consumer.WithKinesis(myKinesisClient))
ksis := kinesis.New(
session.New(aws.NewConfig()), &aws.Config{
Region: aws.String("us-west-2"),
},
)
// consumer
c, err := consumer.New(
@ -91,7 +92,7 @@ func main() {
consumer.WithCheckpoint(ck),
consumer.WithLogger(log),
consumer.WithCounter(counter),
consumer.WithClient(newKclient),
consumer.WithClient(ksis),
)
if err != nil {
log.Log("consumer error: %v", err)
@ -110,15 +111,11 @@ func main() {
}()
// scan stream
err = c.Scan(ctx, func(r *consumer.Record) consumer.ScanError {
err = c.Scan(ctx, func(r *consumer.Record) consumer.ScanStatus {
fmt.Println(string(r.Data))
err := errors.New("some error happened")
// continue scanning
return consumer.ScanError{
Error: err,
StopScan: true,
SkipCheckpoint: false,
}
return consumer.ScanStatus{}
})
if err != nil {
log.Log("scan error: %v", err)
@ -129,10 +126,12 @@ func main() {
}
}
// MyRetryer used for checkpointing
type MyRetryer struct {
checkpoint.Retryer
}
// ShouldRetry implements custom logic for when a checkpoint should retry
func (r *MyRetryer) ShouldRetry(err error) bool {
if awsErr, ok := err.(awserr.Error); ok {
switch awsErr.Code() {

View file

@ -2,7 +2,6 @@ package main
import (
"context"
"errors"
"expvar"
"flag"
"fmt"
@ -29,24 +28,19 @@ func main() {
log.Fatalf("checkpoint error: %v", err)
}
var (
counter = expvar.NewMap("counters")
)
newKclient := consumer.NewKinesisClient()
var counter = expvar.NewMap("counters")
// consumer
c, err := consumer.New(
*stream,
consumer.WithCheckpoint(ck),
consumer.WithCounter(counter),
consumer.WithClient(newKclient),
)
if err != nil {
log.Fatalf("consumer error: %v", err)
}
// use cancel \func to signal shutdown
// use cancel func to signal shutdown
ctx, cancel := context.WithCancel(context.Background())
// trap SIGINT, wait to trigger shutdown
@ -59,15 +53,11 @@ func main() {
}()
// scan stream
err = c.Scan(ctx, func(r *consumer.Record) consumer.ScanError {
err = c.Scan(ctx, func(r *consumer.Record) consumer.ScanStatus {
fmt.Println(string(r.Data))
err := errors.New("some error happened")
// continue scanning
return consumer.ScanError{
Error: err,
StopScan: false,
SkipCheckpoint: false,
}
return consumer.ScanStatus{}
})
if err != nil {

View file

@ -0,0 +1,18 @@
# Consumer
Read records from the Kinesis stream
### Environment Variables
Export the required environment vars for connecting to the Kinesis stream and Redis for checkpoint:
```
export AWS_ACCESS_KEY=
export AWS_REGION=
export AWS_SECRET_KEY=
export REDIS_URL=
```
### Run the consumer
$ go run main.go --app appName --stream streamName

View file

@ -0,0 +1,58 @@
package main
import (
"context"
"flag"
"fmt"
"log"
"os"
"os/signal"
consumer "github.com/harlow/kinesis-consumer"
checkpoint "github.com/harlow/kinesis-consumer/checkpoint/redis"
)
// main builds a consumer that checkpoints through the redis checkpoint
// package and prints the data of every record it scans. The scan context is
// cancelled when the process receives os.Interrupt.
func main() {
	appName := flag.String("app", "", "App name")
	streamName := flag.String("stream", "", "Stream name")
	flag.Parse()

	// redis checkpoint
	ck, err := checkpoint.New(*appName)
	if err != nil {
		log.Fatalf("checkpoint error: %v", err)
	}

	// consumer
	c, err := consumer.New(*streamName, consumer.WithCheckpoint(ck))
	if err != nil {
		log.Fatalf("consumer error: %v", err)
	}

	// cancelling ctx signals the scan to shut down
	ctx, cancel := context.WithCancel(context.Background())

	// wait for SIGINT in the background, then trigger shutdown
	sigc := make(chan os.Signal, 1)
	signal.Notify(sigc, os.Interrupt)
	go func() {
		<-sigc
		cancel()
	}()

	// printRecord writes the record data to stdout and keeps scanning
	printRecord := func(r *consumer.Record) consumer.ScanStatus {
		fmt.Println(string(r.Data))
		return consumer.ScanStatus{}
	}

	// scan stream
	if err = c.Scan(ctx, printRecord); err != nil {
		log.Fatalf("scan error: %v", err)
	}
}

View file

@ -4,24 +4,20 @@ import (
"bufio"
"flag"
"fmt"
"log"
"os"
"time"
"github.com/apex/log"
"github.com/apex/log/handlers/text"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/kinesis"
)
var svc = kinesis.New(session.New(), &aws.Config{
Region: aws.String("us-west-2"),
Region: aws.String("us-west-1"),
})
func main() {
log.SetHandler(text.New(os.Stderr))
log.SetLevel(log.DebugLevel)
var streamName = flag.String("stream", "", "Stream name")
flag.Parse()
@ -60,7 +56,7 @@ func putRecords(streamName *string, records []*kinesis.PutRecordsRequestEntry) {
Records: records,
})
if err != nil {
log.Fatal("error putting records")
log.Fatalf("error putting records: %v", err)
}
fmt.Print(".")
}

View file

@ -1,7 +1,6 @@
package consumer
import (
"io/ioutil"
"log"
)
@ -12,19 +11,12 @@ type Logger interface {
type LoggerFunc func(...interface{})
// NewDefaultLogger returns a Logger which discards messages.
func NewDefaultLogger() Logger {
return &defaultLogger{
logger: log.New(ioutil.Discard, "", log.LstdFlags),
}
}
// A defaultLogger provides a logging instance when none is provided.
type defaultLogger struct {
// noopLogger implements logger interface with discard
type noopLogger struct {
logger *log.Logger
}
// Log using stdlib logger. See log.Println.
func (l defaultLogger) Log(args ...interface{}) {
func (l noopLogger) Log(args ...interface{}) {
l.logger.Println(args...)
}