Return the shard scan errors to top-level caller

This commit is contained in:
Harlow Ward 2017-11-23 08:49:37 -08:00
parent b0245d688b
commit 86f1df782e
2 changed files with 45 additions and 35 deletions

View file

@ -3,8 +3,8 @@ package consumer
import ( import (
"context" "context"
"fmt" "fmt"
"io/ioutil"
"log" "log"
"os"
"sync" "sync"
"github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws"
@ -72,7 +72,7 @@ func New(stream string, opts ...Option) (*Consumer, error) {
streamName: stream, streamName: stream,
checkpoint: &noopCheckpoint{}, checkpoint: &noopCheckpoint{},
counter: &noopCounter{}, counter: &noopCounter{},
logger: log.New(os.Stderr, "kinesis-consumer: ", log.LstdFlags), logger: log.New(ioutil.Discard, "", log.LstdFlags),
} }
// set options // set options
@ -112,48 +112,58 @@ func (c *Consumer) Scan(ctx context.Context, fn func(*kinesis.Record) bool) erro
}, },
) )
if err != nil { if err != nil {
return err return fmt.Errorf("describe stream error: %v", err)
} }
var wg sync.WaitGroup var (
wg sync.WaitGroup
errc = make(chan error, 1)
)
wg.Add(len(resp.StreamDescription.Shards)) wg.Add(len(resp.StreamDescription.Shards))
// launch goroutine to process each of the shards // launch goroutine to process each of the shards
for _, shard := range resp.StreamDescription.Shards { for _, shard := range resp.StreamDescription.Shards {
go func(shardID string) { go func(shardID string) {
defer wg.Done() defer wg.Done()
c.ScanShard(ctx, shardID, fn) err := c.ScanShard(ctx, shardID, fn)
if err != nil {
select {
case errc <- fmt.Errorf("shard %s error: %v", shardID, err):
// first error to occur
default:
// error has already occurred
}
}
c.logger.Println("exiting", shardID)
cancel() cancel()
}(*shard.ShardId) }(*shard.ShardId)
} }
wg.Wait() wg.Wait()
return nil close(errc)
return <-errc
} }
// ScanShard loops over records on a specific shard, calls the callback func // ScanShard loops over records on a specific shard, calls the callback func
// for each record and checkpoints after each page is processed. // for each record and checkpoints after each page is processed.
// Note: returning `false` from the callback func will end the scan. // Note: returning `false` from the callback func will end the scan.
func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn func(*kinesis.Record) bool) { func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn func(*kinesis.Record) bool) error {
c.logger.Println("scanning", shardID)
lastSeqNum, err := c.checkpoint.Get(c.streamName, shardID) lastSeqNum, err := c.checkpoint.Get(c.streamName, shardID)
if err != nil { if err != nil {
c.logger.Printf("get checkpoint error: %v", err) return fmt.Errorf("get checkpoint error: %v", err)
return
} }
shardIterator, err := c.getShardIterator(shardID, lastSeqNum) shardIterator, err := c.getShardIterator(shardID, lastSeqNum)
if err != nil { if err != nil {
c.logger.Printf("get shard iterator error: %v", err) return fmt.Errorf("get shard iterator error: %v", err)
return
} }
c.logger.Println("scanning", shardID)
loop:
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
break loop return nil
default: default:
resp, err := c.client.GetRecords( resp, err := c.client.GetRecords(
&kinesis.GetRecordsInput{ &kinesis.GetRecordsInput{
@ -164,8 +174,7 @@ loop:
if err != nil { if err != nil {
shardIterator, err = c.getShardIterator(shardID, lastSeqNum) shardIterator, err = c.getShardIterator(shardID, lastSeqNum)
if err != nil { if err != nil {
c.logger.Printf("get shard iterator error: %v", err) return fmt.Errorf("get shard iterator error: %v", err)
break loop
} }
continue continue
} }
@ -174,44 +183,45 @@ loop:
for _, r := range resp.Records { for _, r := range resp.Records {
select { select {
case <-ctx.Done(): case <-ctx.Done():
break loop return nil
default: default:
lastSeqNum = *r.SequenceNumber lastSeqNum = *r.SequenceNumber
c.counter.Add("records", 1) c.counter.Add("records", 1)
if ok := fn(r); !ok { if ok := fn(r); !ok {
break loop if err := c.setCheckpoint(shardID, lastSeqNum); err != nil {
return err
}
return nil
} }
} }
} }
if err := c.checkpoint.Set(c.streamName, shardID, lastSeqNum); err != nil { if err := c.setCheckpoint(shardID, lastSeqNum); err != nil {
c.logger.Printf("set checkpoint error: %v", err) return err
} }
c.logger.Println("checkpoint", shardID, len(resp.Records))
c.counter.Add("checkpoints", 1)
} }
if resp.NextShardIterator == nil || shardIterator == resp.NextShardIterator { if resp.NextShardIterator == nil || shardIterator == resp.NextShardIterator {
shardIterator, err = c.getShardIterator(shardID, lastSeqNum) shardIterator, err = c.getShardIterator(shardID, lastSeqNum)
if err != nil { if err != nil {
c.logger.Printf("get shard iterator error: %v", err) return fmt.Errorf("get shard iterator error: %v", err)
break loop
} }
} else { } else {
shardIterator = resp.NextShardIterator shardIterator = resp.NextShardIterator
} }
} }
} }
}
if lastSeqNum == "" { func (c *Consumer) setCheckpoint(shardID, lastSeqNum string) error {
return err := c.checkpoint.Set(c.streamName, shardID, lastSeqNum)
} if err != nil {
return fmt.Errorf("set checkpoint error: %v", err)
c.logger.Println("checkpointing", shardID)
if err := c.checkpoint.Set(c.streamName, shardID, lastSeqNum); err != nil {
c.logger.Printf("set checkpoint error: %v", err)
} }
c.logger.Println("checkpoint", shardID)
c.counter.Add("checkpoints", 1)
return nil
} }
func (c *Consumer) getShardIterator(shardID, lastSeqNum string) (*string, error) { func (c *Consumer) getShardIterator(shardID, lastSeqNum string) (*string, error) {

View file

@ -35,10 +35,10 @@ func main() {
var ( var (
counter = expvar.NewMap("counters") counter = expvar.NewMap("counters")
logger = log.New(os.Stdout, "consumer-example: ", log.LstdFlags) logger = log.New(os.Stdout, "", log.LstdFlags)
) )
// checkpoint // redis checkpoint
ck, err := checkpoint.New(*app) ck, err := checkpoint.New(*app)
if err != nil { if err != nil {
log.Fatalf("checkpoint error: %v", err) log.Fatalf("checkpoint error: %v", err)