Return the shard scan errors to top-level caller
This commit is contained in:
parent
b0245d688b
commit
86f1df782e
2 changed files with 45 additions and 35 deletions
76
consumer.go
76
consumer.go
|
|
@ -3,8 +3,8 @@ package consumer
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
|
|
@ -72,7 +72,7 @@ func New(stream string, opts ...Option) (*Consumer, error) {
|
|||
streamName: stream,
|
||||
checkpoint: &noopCheckpoint{},
|
||||
counter: &noopCounter{},
|
||||
logger: log.New(os.Stderr, "kinesis-consumer: ", log.LstdFlags),
|
||||
logger: log.New(ioutil.Discard, "", log.LstdFlags),
|
||||
}
|
||||
|
||||
// set options
|
||||
|
|
@ -112,48 +112,58 @@ func (c *Consumer) Scan(ctx context.Context, fn func(*kinesis.Record) bool) erro
|
|||
},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("describe stream error: %v", err)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var (
|
||||
wg sync.WaitGroup
|
||||
errc = make(chan error, 1)
|
||||
)
|
||||
wg.Add(len(resp.StreamDescription.Shards))
|
||||
|
||||
// launch goroutine to process each of the shards
|
||||
for _, shard := range resp.StreamDescription.Shards {
|
||||
go func(shardID string) {
|
||||
defer wg.Done()
|
||||
c.ScanShard(ctx, shardID, fn)
|
||||
err := c.ScanShard(ctx, shardID, fn)
|
||||
if err != nil {
|
||||
select {
|
||||
case errc <- fmt.Errorf("shard %s error: %v", shardID, err):
|
||||
// first error to occur
|
||||
default:
|
||||
// error has already occured
|
||||
}
|
||||
}
|
||||
c.logger.Println("exiting", shardID)
|
||||
cancel()
|
||||
}(*shard.ShardId)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
return nil
|
||||
close(errc)
|
||||
return <-errc
|
||||
}
|
||||
|
||||
// ScanShard loops over records on a specific shard, calls the callback func
|
||||
// for each record and checkpoints after each page is processed.
|
||||
// Note: returning `false` from the callback func will end the scan.
|
||||
func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn func(*kinesis.Record) bool) {
|
||||
func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn func(*kinesis.Record) bool) error {
|
||||
c.logger.Println("scanning", shardID)
|
||||
|
||||
lastSeqNum, err := c.checkpoint.Get(c.streamName, shardID)
|
||||
if err != nil {
|
||||
c.logger.Printf("get checkpoint error: %v", err)
|
||||
return
|
||||
return fmt.Errorf("get checkpoint error: %v", err)
|
||||
}
|
||||
|
||||
shardIterator, err := c.getShardIterator(shardID, lastSeqNum)
|
||||
if err != nil {
|
||||
c.logger.Printf("get shard iterator error: %v", err)
|
||||
return
|
||||
return fmt.Errorf("get shard iterator error: %v", err)
|
||||
}
|
||||
|
||||
c.logger.Println("scanning", shardID)
|
||||
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
break loop
|
||||
return nil
|
||||
default:
|
||||
resp, err := c.client.GetRecords(
|
||||
&kinesis.GetRecordsInput{
|
||||
|
|
@ -164,8 +174,7 @@ loop:
|
|||
if err != nil {
|
||||
shardIterator, err = c.getShardIterator(shardID, lastSeqNum)
|
||||
if err != nil {
|
||||
c.logger.Printf("get shard iterator error: %v", err)
|
||||
break loop
|
||||
return fmt.Errorf("get shard iterator error: %v", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
|
@ -174,44 +183,45 @@ loop:
|
|||
for _, r := range resp.Records {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
break loop
|
||||
return nil
|
||||
default:
|
||||
lastSeqNum = *r.SequenceNumber
|
||||
c.counter.Add("records", 1)
|
||||
|
||||
if ok := fn(r); !ok {
|
||||
break loop
|
||||
if err := c.setCheckpoint(shardID, lastSeqNum); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := c.checkpoint.Set(c.streamName, shardID, lastSeqNum); err != nil {
|
||||
c.logger.Printf("set checkpoint error: %v", err)
|
||||
if err := c.setCheckpoint(shardID, lastSeqNum); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.logger.Println("checkpoint", shardID, len(resp.Records))
|
||||
c.counter.Add("checkpoints", 1)
|
||||
}
|
||||
|
||||
if resp.NextShardIterator == nil || shardIterator == resp.NextShardIterator {
|
||||
shardIterator, err = c.getShardIterator(shardID, lastSeqNum)
|
||||
if err != nil {
|
||||
c.logger.Printf("get shard iterator error: %v", err)
|
||||
break loop
|
||||
return fmt.Errorf("get shard iterator error: %v", err)
|
||||
}
|
||||
} else {
|
||||
shardIterator = resp.NextShardIterator
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if lastSeqNum == "" {
|
||||
return
|
||||
}
|
||||
|
||||
c.logger.Println("checkpointing", shardID)
|
||||
if err := c.checkpoint.Set(c.streamName, shardID, lastSeqNum); err != nil {
|
||||
c.logger.Printf("set checkpoint error: %v", err)
|
||||
func (c *Consumer) setCheckpoint(shardID, lastSeqNum string) error {
|
||||
err := c.checkpoint.Set(c.streamName, shardID, lastSeqNum)
|
||||
if err != nil {
|
||||
return fmt.Errorf("set checkpoint error: %v", err)
|
||||
}
|
||||
c.logger.Println("checkpoint", shardID)
|
||||
c.counter.Add("checkpoints", 1)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Consumer) getShardIterator(shardID, lastSeqNum string) (*string, error) {
|
||||
|
|
|
|||
|
|
@ -35,10 +35,10 @@ func main() {
|
|||
|
||||
var (
|
||||
counter = expvar.NewMap("counters")
|
||||
logger = log.New(os.Stdout, "consumer-example: ", log.LstdFlags)
|
||||
logger = log.New(os.Stdout, "", log.LstdFlags)
|
||||
)
|
||||
|
||||
// checkpoint
|
||||
// redis checkpoint
|
||||
ck, err := checkpoint.New(*app)
|
||||
if err != nil {
|
||||
log.Fatalf("checkpoint error: %v", err)
|
||||
|
|
|
|||
Loading…
Reference in a new issue