Return the shard scan errors to top-level caller

This commit is contained in:
Harlow Ward 2017-11-23 08:49:37 -08:00
parent b0245d688b
commit 86f1df782e
2 changed files with 45 additions and 35 deletions

View file

@ -3,8 +3,8 @@ package consumer
import (
"context"
"fmt"
"io/ioutil"
"log"
"os"
"sync"
"github.com/aws/aws-sdk-go/aws"
@ -72,7 +72,7 @@ func New(stream string, opts ...Option) (*Consumer, error) {
streamName: stream,
checkpoint: &noopCheckpoint{},
counter: &noopCounter{},
logger: log.New(os.Stderr, "kinesis-consumer: ", log.LstdFlags),
logger: log.New(ioutil.Discard, "", log.LstdFlags),
}
// set options
@ -112,48 +112,58 @@ func (c *Consumer) Scan(ctx context.Context, fn func(*kinesis.Record) bool) erro
},
)
if err != nil {
return err
return fmt.Errorf("describe stream error: %v", err)
}
var wg sync.WaitGroup
var (
wg sync.WaitGroup
errc = make(chan error, 1)
)
wg.Add(len(resp.StreamDescription.Shards))
// launch goroutine to process each of the shards
for _, shard := range resp.StreamDescription.Shards {
go func(shardID string) {
defer wg.Done()
c.ScanShard(ctx, shardID, fn)
err := c.ScanShard(ctx, shardID, fn)
if err != nil {
select {
case errc <- fmt.Errorf("shard %s error: %v", shardID, err):
// first error to occur
default:
// error has already occured
}
}
c.logger.Println("exiting", shardID)
cancel()
}(*shard.ShardId)
}
wg.Wait()
return nil
close(errc)
return <-errc
}
// ScanShard loops over records on a specific shard, calls the callback func
// for each record and checkpoints after each page is processed.
// Note: returning `false` from the callback func will end the scan.
func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn func(*kinesis.Record) bool) {
func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn func(*kinesis.Record) bool) error {
c.logger.Println("scanning", shardID)
lastSeqNum, err := c.checkpoint.Get(c.streamName, shardID)
if err != nil {
c.logger.Printf("get checkpoint error: %v", err)
return
return fmt.Errorf("get checkpoint error: %v", err)
}
shardIterator, err := c.getShardIterator(shardID, lastSeqNum)
if err != nil {
c.logger.Printf("get shard iterator error: %v", err)
return
return fmt.Errorf("get shard iterator error: %v", err)
}
c.logger.Println("scanning", shardID)
loop:
for {
select {
case <-ctx.Done():
break loop
return nil
default:
resp, err := c.client.GetRecords(
&kinesis.GetRecordsInput{
@ -164,8 +174,7 @@ loop:
if err != nil {
shardIterator, err = c.getShardIterator(shardID, lastSeqNum)
if err != nil {
c.logger.Printf("get shard iterator error: %v", err)
break loop
return fmt.Errorf("get shard iterator error: %v", err)
}
continue
}
@ -174,44 +183,45 @@ loop:
for _, r := range resp.Records {
select {
case <-ctx.Done():
break loop
return nil
default:
lastSeqNum = *r.SequenceNumber
c.counter.Add("records", 1)
if ok := fn(r); !ok {
break loop
if err := c.setCheckpoint(shardID, lastSeqNum); err != nil {
return err
}
return nil
}
}
}
if err := c.checkpoint.Set(c.streamName, shardID, lastSeqNum); err != nil {
c.logger.Printf("set checkpoint error: %v", err)
if err := c.setCheckpoint(shardID, lastSeqNum); err != nil {
return err
}
c.logger.Println("checkpoint", shardID, len(resp.Records))
c.counter.Add("checkpoints", 1)
}
if resp.NextShardIterator == nil || shardIterator == resp.NextShardIterator {
shardIterator, err = c.getShardIterator(shardID, lastSeqNum)
if err != nil {
c.logger.Printf("get shard iterator error: %v", err)
break loop
return fmt.Errorf("get shard iterator error: %v", err)
}
} else {
shardIterator = resp.NextShardIterator
}
}
}
}
if lastSeqNum == "" {
return
}
c.logger.Println("checkpointing", shardID)
if err := c.checkpoint.Set(c.streamName, shardID, lastSeqNum); err != nil {
c.logger.Printf("set checkpoint error: %v", err)
func (c *Consumer) setCheckpoint(shardID, lastSeqNum string) error {
err := c.checkpoint.Set(c.streamName, shardID, lastSeqNum)
if err != nil {
return fmt.Errorf("set checkpoint error: %v", err)
}
c.logger.Println("checkpoint", shardID)
c.counter.Add("checkpoints", 1)
return nil
}
func (c *Consumer) getShardIterator(shardID, lastSeqNum string) (*string, error) {

View file

@ -35,10 +35,10 @@ func main() {
var (
counter = expvar.NewMap("counters")
logger = log.New(os.Stdout, "consumer-example: ", log.LstdFlags)
logger = log.New(os.Stdout, "", log.LstdFlags)
)
// checkpoint
// redis checkpoint
ck, err := checkpoint.New(*app)
if err != nil {
log.Fatalf("checkpoint error: %v", err)