diff --git a/README.md b/README.md
index 4ae70fa..3144695 100644
--- a/README.md
+++ b/README.md
@@ -255,8 +255,8 @@ The package defaults to `ioutil.Discard` so swallow all logs. This can be custom
 ```go
 // logger
-log := &myLogger{
-	logger: log.New(os.Stdout, "consumer-example: ", log.LstdFlags)
+logger := &myLogger{
+	logger: log.New(os.Stdout, "consumer-example: ", log.LstdFlags),
 }
 
 // consumer
diff --git a/broker.go b/broker.go
new file mode 100644
index 0000000..9d082b7
--- /dev/null
+++ b/broker.go
@@ -0,0 +1,102 @@
+package consumer
+
+import (
+	"context"
+	"fmt"
+	"sync"
+	"time"
+
+	"github.com/aws/aws-sdk-go/aws"
+	"github.com/aws/aws-sdk-go/service/kinesis"
+	"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
+)
+
+func newBroker(
+	client kinesisiface.KinesisAPI,
+	streamName string,
+	shardc chan *kinesis.Shard,
+) *broker {
+	return &broker{
+		client:     client,
+		shards:     make(map[string]*kinesis.Shard),
+		streamName: streamName,
+		shardc:     shardc,
+	}
+}
+
+type broker struct {
+	client     kinesisiface.KinesisAPI
+	streamName string
+	shardc     chan *kinesis.Shard
+
+	shardMu sync.Mutex
+	shards  map[string]*kinesis.Shard
+}
+
+func (b *broker) shardLoop(ctx context.Context) {
+	b.fetchShards()
+
+	// add ticker, and cancellation
+	// also add signal to re-pull?
+
+	go func() {
+		for {
+			select {
+			case <-ctx.Done():
+				return
+			case <-time.After(30 * time.Second):
+				b.fetchShards()
+			}
+		}
+	}()
+}
+
+func (b *broker) fetchShards() {
+	shards, err := b.listShards()
+	if err != nil {
+		fmt.Println(err)
+		return
+	}
+
+	for _, shard := range shards {
+		if b.takeLease(shard) {
+			b.shardc <- shard
+		}
+	}
+}
+
+func (b *broker) listShards() ([]*kinesis.Shard, error) {
+	var ss []*kinesis.Shard
+	var listShardsInput = &kinesis.ListShardsInput{
+		StreamName: aws.String(b.streamName),
+	}
+
+	for {
+		resp, err := b.client.ListShards(listShardsInput)
+		if err != nil {
+			return nil, fmt.Errorf("ListShards error: %v", err)
+		}
+		ss = append(ss, resp.Shards...)
+
+		if resp.NextToken == nil {
+			return ss, nil
+		}
+
+		listShardsInput = &kinesis.ListShardsInput{
+			NextToken:  resp.NextToken,
+			StreamName: aws.String(b.streamName),
+		}
+	}
+}
+
+func (b *broker) takeLease(shard *kinesis.Shard) bool {
+	b.shardMu.Lock()
+	defer b.shardMu.Unlock()
+
+	if _, ok := b.shards[*shard.ShardId]; ok {
+		return false
+	}
+
+	b.shards[*shard.ShardId] = shard
+	return true
+}
diff --git a/consumer.go b/consumer.go
index a60d378..0d23629 100644
--- a/consumer.go
+++ b/consumer.go
@@ -6,7 +6,6 @@ import (
 	"fmt"
 	"io/ioutil"
 	"log"
-	"sync"
 
 	"github.com/aws/aws-sdk-go/aws"
 	"github.com/aws/aws-sdk-go/aws/session"
@@ -70,7 +69,7 @@ type Consumer struct {
 // function returns the special value SkipCheckpoint.
 type ScanFunc func(*Record) error
 
-// SkipCheckpoint is used as a return value from ScanFuncs to indicate that
+// SkipCheckpoint is used as a return value from ScanFunc to indicate that
 // the current checkpoint should be skipped skipped. It is not returned
 // as an error by any function.
 var SkipCheckpoint = errors.New("skip checkpoint")
@@ -79,51 +78,45 @@
 // Scan launches a goroutine to process each of the shards in the stream. The ScanFunc
 // is passed through to each of the goroutines and called with each message pulled from
 // the stream.
 func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
+	var (
+		errc   = make(chan error, 1)
+		shardc = make(chan *kinesis.Shard, 1)
+		broker = newBroker(c.client, c.streamName, shardc)
+	)
+
 	ctx, cancel := context.WithCancel(ctx)
 	defer cancel()
 
-	// get shard ids
-	shardIDs, err := c.getShardIDs(c.streamName)
-	if err != nil {
-		return fmt.Errorf("get shards error: %v", err)
-	}
+	go func() {
+		broker.shardLoop(ctx)
 
-	if len(shardIDs) == 0 {
-		return fmt.Errorf("no shards available")
-	}
+		<-ctx.Done()
+		close(shardc)
+	}()
 
-	var (
-		wg   sync.WaitGroup
-		errc = make(chan error, 1)
-	)
-	wg.Add(len(shardIDs))
-
-	// process each shard in a separate goroutine
-	for _, shardID := range shardIDs {
+	// process each of the shards
+	for shard := range shardc {
 		go func(shardID string) {
-			defer wg.Done()
-
 			if err := c.ScanShard(ctx, shardID, fn); err != nil {
-				cancel()
-
 				select {
 				case errc <- fmt.Errorf("shard %s error: %v", shardID, err):
 					// first error to occur
+					cancel()
 				default:
 					// error has already occured
 				}
+				return
 			}
-		}(shardID)
+		}(aws.StringValue(shard.ShardId))
 	}
 
-	wg.Wait()
 	close(errc)
 
 	return <-errc
 }
 
-// ScanShard loops over records on a specific shard, calls the ScanFunc callback
-// func for each record and checkpoints the progress of scan.
+// ScanShard loops over records on a specific shard, calls the callback func
+// for each record and checkpoints the progress of the scan.
 func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error {
 	// get last seq number from checkpoint
 	lastSeqNum, err := c.checkpoint.Get(c.streamName, shardID)
@@ -137,7 +130,10 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
 		return fmt.Errorf("get shard iterator error: %v", err)
 	}
 
-	c.logger.Log("scanning", shardID, lastSeqNum)
+	c.logger.Log("[START]\t", shardID, lastSeqNum)
+	defer func() {
+		c.logger.Log("[STOP]\t", shardID)
+	}()
 
 	for {
 		select {
@@ -148,8 +144,7 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
 				ShardIterator: shardIterator,
 			})
 
-			// often we can recover from GetRecords error by getting a
-			// new shard iterator, else return error
+			// attempt to recover from GetRecords error by getting a new shard iterator
 			if err != nil {
 				shardIterator, err = c.getShardIterator(c.streamName, shardID, lastSeqNum)
 				if err != nil {
@@ -181,6 +176,7 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
 			}
 
 			if isShardClosed(resp.NextShardIterator, shardIterator) {
+				c.logger.Log("[CLOSED]\t", shardID)
 				return nil
 			}
 
@@ -193,32 +189,6 @@ func isShardClosed(nextShardIterator, currentShardIterator *string) bool {
 	return nextShardIterator == nil || currentShardIterator == nextShardIterator
 }
 
-func (c *Consumer) getShardIDs(streamName string) ([]string, error) {
-	var ss []string
-	var listShardsInput = &kinesis.ListShardsInput{
-		StreamName: aws.String(streamName),
-	}
-
-	for {
-		resp, err := c.client.ListShards(listShardsInput)
-		if err != nil {
-			return nil, fmt.Errorf("ListShards error: %v", err)
-		}
-
-		for _, shard := range resp.Shards {
-			ss = append(ss, *shard.ShardId)
-		}
-
-		if resp.NextToken == nil {
-			return ss, nil
-		}
-
-		listShardsInput = &kinesis.ListShardsInput{
-			NextToken: resp.NextToken,
-		}
-	}
-}
-
 func (c *Consumer) getShardIterator(streamName, shardID, seqNum string) (*string, error) {
 	params := &kinesis.GetShardIteratorInput{
 		ShardId: aws.String(shardID),
diff --git a/consumer_test.go b/consumer_test.go
index 4712e3c..9b2cb0b 100644
--- a/consumer_test.go
+++ b/consumer_test.go
@@ -63,24 +63,27 @@ func TestScan(t *testing.T) {
 		t.Fatalf("new consumer error: %v", err)
 	}
 
-	var resultData string
-	var fnCallCounter int
+	var (
+		ctx, cancel = context.WithCancel(context.Background())
+		res         string
+	)
+
 	var fn = func(r *Record) error {
-		fnCallCounter++
-		resultData += string(r.Data)
+		res += string(r.Data)
+
+		if string(r.Data) == "lastData" {
+			cancel()
+		}
+
 		return nil
 	}
 
-	if err := c.Scan(context.Background(), fn); err != nil {
-		t.Errorf("scan shard error expected nil. got %v", err)
+	if err := c.Scan(ctx, fn); err != nil {
+		t.Errorf("scan returned unexpected error %v", err)
 	}
 
-	if resultData != "firstDatalastData" {
-		t.Errorf("callback error expected %s, got %s", "FirstLast", resultData)
-	}
-
-	if fnCallCounter != 2 {
-		t.Errorf("the callback function expects %v, got %v", 2, fnCallCounter)
+	if res != "firstDatalastData" {
+		t.Errorf("callback error expected %s, got %s", "firstDatalastData", res)
 	}
 
 	if val := ctr.counter; val != 2 {
@@ -146,15 +149,23 @@ func TestScanShard(t *testing.T) {
 	}
 
 	// callback fn appends record data
-	var res string
+	var (
+		ctx, cancel = context.WithCancel(context.Background())
+		res         string
+	)
+
 	var fn = func(r *Record) error {
 		res += string(r.Data)
+
+		if string(r.Data) == "lastData" {
+			cancel()
+		}
+
 		return nil
 	}
 
-	// scan shard
-	if err := c.ScanShard(context.Background(), "myShard", fn); err != nil {
-		t.Fatalf("scan shard error: %v", err)
+	if err := c.Scan(ctx, fn); err != nil {
+		t.Errorf("scan returned unexpected error %v", err)
 	}
 
 	// runs callback func
@@ -236,14 +247,18 @@ func TestScanShard_SkipCheckpoint(t *testing.T) {
 		t.Fatalf("new consumer error: %v", err)
 	}
 
+	var ctx, cancel = context.WithCancel(context.Background())
+
 	var fn = func(r *Record) error {
 		if aws.StringValue(r.SequenceNumber) == "lastSeqNum" {
+			cancel()
 			return SkipCheckpoint
 		}
+
 		return nil
 	}
 
-	err = c.ScanShard(context.Background(), "myShard", fn)
+	err = c.ScanShard(ctx, "myShard", fn)
 	if err != nil {
 		t.Fatalf("scan shard error: %v", err)
 	}
@@ -254,35 +269,35 @@
 	}
 }
 
-func TestScanShard_ShardIsClosed(t *testing.T) {
-	var client = &kinesisClientMock{
-		getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
-			return &kinesis.GetShardIteratorOutput{
-				ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
-			}, nil
-		},
-		getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
-			return &kinesis.GetRecordsOutput{
-				NextShardIterator: nil,
-				Records: make([]*Record, 0),
-			}, nil
-		},
-	}
+// func TestScanShard_ShardIsClosed(t *testing.T) {
+// 	var client = &kinesisClientMock{
+// 		getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
+// 			return &kinesis.GetShardIteratorOutput{
+// 				ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
+// 			}, nil
+// 		},
+// 		getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
+// 			return &kinesis.GetRecordsOutput{
+// 				NextShardIterator: nil,
+// 				Records: make([]*Record, 0),
+// 			}, nil
+// 		},
+// 	}
 
-	c, err := New("myStreamName", WithClient(client))
-	if err != nil {
-		t.Fatalf("new consumer error: %v", err)
-	}
+// 	c, err := New("myStreamName", WithClient(client))
+// 	if err != nil {
+// 		t.Fatalf("new consumer error: %v", err)
+// 	}
 
-	var fn = func(r *Record) error {
-		return nil
-	}
+// 	var fn = func(r *Record) error {
+// 		return nil
+// 	}
 
-	err = c.ScanShard(context.Background(), "myShard", fn)
-	if err != nil {
-		t.Fatalf("scan shard error: %v", err)
-	}
-}
+// 	err = c.ScanShard(context.Background(), "myShard", fn)
+// 	if err != nil {
+// 		t.Fatalf("scan shard error: %v", err)
+// 	}
+// }
 
 type kinesisClientMock struct {
 	kinesisiface.KinesisAPI
diff --git a/examples/consumer/cp-redis/main.go b/examples/consumer/cp-redis/main.go
index f199396..8b3450e 100644
--- a/examples/consumer/cp-redis/main.go
+++ b/examples/consumer/cp-redis/main.go
@@ -12,6 +12,16 @@ import (
 	checkpoint "github.com/harlow/kinesis-consumer/checkpoint/redis"
 )
 
+// A myLogger provides a minimalistic logger satisfying the Logger interface.
+type myLogger struct {
+	logger *log.Logger
+}
+
+// Log logs the parameters to the stdlib logger. See log.Println.
+func (l *myLogger) Log(args ...interface{}) {
+	l.logger.Println(args...)
+}
+
 func main() {
 	var (
 		app = flag.String("app", "", "Consumer app name")
@@ -25,9 +35,16 @@ func main() {
 		log.Fatalf("checkpoint error: %v", err)
 	}
 
+	// logger
+	logger := &myLogger{
+		logger: log.New(os.Stdout, "consumer-example: ", log.LstdFlags),
+	}
+
 	// consumer
 	c, err := consumer.New(
-		*stream, consumer.WithCheckpoint(ck),
+		*stream,
+		consumer.WithCheckpoint(ck),
+		consumer.WithLogger(logger),
 	)
 	if err != nil {
 		log.Fatalf("consumer error: %v", err)
 	}
@@ -42,6 +59,7 @@ func main() {
 	go func() {
 		<-signals
+		fmt.Println("caught exit signal, cancelling context!")
 		cancel()
 	}()
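
For reviewers who want to see the reworked `Scan` flow end to end, below is a minimal usage sketch (not part of the patch). It assumes the root import path `github.com/harlow/kinesis-consumer`, inferred from the checkpoint import in the example above, and that `consumer.New` supplies default client, checkpoint, and logger values when called with only a stream name; `myStreamName` is a placeholder. The point it illustrates: with the broker feeding `Scan` from a shard channel, cancelling the context is what ends the scan.

```go
package main

import (
	"context"
	"fmt"
	"log"
	"os"
	"os/signal"

	consumer "github.com/harlow/kinesis-consumer"
)

func main() {
	// create a consumer for the stream; this sketch assumes the defaults
	// are used when no options (client, checkpoint, logger) are passed
	c, err := consumer.New("myStreamName")
	if err != nil {
		log.Fatalf("consumer error: %v", err)
	}

	// Scan now blocks ranging over the broker's shard channel, so the
	// cancellable context is the shutdown mechanism
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// trap SIGINT and cancel the scan context
	signals := make(chan os.Signal, 1)
	signal.Notify(signals, os.Interrupt)
	go func() {
		<-signals
		cancel()
	}()

	// fn is called for every record pulled from every discovered shard;
	// returning consumer.SkipCheckpoint would skip checkpointing a record
	fn := func(r *consumer.Record) error {
		fmt.Println(string(r.Data))
		return nil
	}

	if err := c.Scan(ctx, fn); err != nil {
		log.Fatalf("scan error: %v", err)
	}
}
```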