Return the shard scan errors to top-level caller
This commit is contained in:
parent
b0245d688b
commit
86f1df782e
2 changed files with 45 additions and 35 deletions
76
consumer.go
76
consumer.go
|
|
@ -3,8 +3,8 @@ package consumer
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
"log"
|
"log"
|
||||||
"os"
|
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/aws/aws-sdk-go/aws"
|
"github.com/aws/aws-sdk-go/aws"
|
||||||
|
|
@ -72,7 +72,7 @@ func New(stream string, opts ...Option) (*Consumer, error) {
|
||||||
streamName: stream,
|
streamName: stream,
|
||||||
checkpoint: &noopCheckpoint{},
|
checkpoint: &noopCheckpoint{},
|
||||||
counter: &noopCounter{},
|
counter: &noopCounter{},
|
||||||
logger: log.New(os.Stderr, "kinesis-consumer: ", log.LstdFlags),
|
logger: log.New(ioutil.Discard, "", log.LstdFlags),
|
||||||
}
|
}
|
||||||
|
|
||||||
// set options
|
// set options
|
||||||
|
|
@ -112,48 +112,58 @@ func (c *Consumer) Scan(ctx context.Context, fn func(*kinesis.Record) bool) erro
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("describe stream error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
var (
|
||||||
|
wg sync.WaitGroup
|
||||||
|
errc = make(chan error, 1)
|
||||||
|
)
|
||||||
wg.Add(len(resp.StreamDescription.Shards))
|
wg.Add(len(resp.StreamDescription.Shards))
|
||||||
|
|
||||||
// launch goroutine to process each of the shards
|
// launch goroutine to process each of the shards
|
||||||
for _, shard := range resp.StreamDescription.Shards {
|
for _, shard := range resp.StreamDescription.Shards {
|
||||||
go func(shardID string) {
|
go func(shardID string) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
c.ScanShard(ctx, shardID, fn)
|
err := c.ScanShard(ctx, shardID, fn)
|
||||||
|
if err != nil {
|
||||||
|
select {
|
||||||
|
case errc <- fmt.Errorf("shard %s error: %v", shardID, err):
|
||||||
|
// first error to occur
|
||||||
|
default:
|
||||||
|
// error has already occured
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.logger.Println("exiting", shardID)
|
||||||
cancel()
|
cancel()
|
||||||
}(*shard.ShardId)
|
}(*shard.ShardId)
|
||||||
}
|
}
|
||||||
|
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
return nil
|
close(errc)
|
||||||
|
return <-errc
|
||||||
}
|
}
|
||||||
|
|
||||||
// ScanShard loops over records on a specific shard, calls the callback func
|
// ScanShard loops over records on a specific shard, calls the callback func
|
||||||
// for each record and checkpoints after each page is processed.
|
// for each record and checkpoints after each page is processed.
|
||||||
// Note: returning `false` from the callback func will end the scan.
|
// Note: returning `false` from the callback func will end the scan.
|
||||||
func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn func(*kinesis.Record) bool) {
|
func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn func(*kinesis.Record) bool) error {
|
||||||
|
c.logger.Println("scanning", shardID)
|
||||||
|
|
||||||
lastSeqNum, err := c.checkpoint.Get(c.streamName, shardID)
|
lastSeqNum, err := c.checkpoint.Get(c.streamName, shardID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.logger.Printf("get checkpoint error: %v", err)
|
return fmt.Errorf("get checkpoint error: %v", err)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
shardIterator, err := c.getShardIterator(shardID, lastSeqNum)
|
shardIterator, err := c.getShardIterator(shardID, lastSeqNum)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.logger.Printf("get shard iterator error: %v", err)
|
return fmt.Errorf("get shard iterator error: %v", err)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
c.logger.Println("scanning", shardID)
|
|
||||||
|
|
||||||
loop:
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
break loop
|
return nil
|
||||||
default:
|
default:
|
||||||
resp, err := c.client.GetRecords(
|
resp, err := c.client.GetRecords(
|
||||||
&kinesis.GetRecordsInput{
|
&kinesis.GetRecordsInput{
|
||||||
|
|
@ -164,8 +174,7 @@ loop:
|
||||||
if err != nil {
|
if err != nil {
|
||||||
shardIterator, err = c.getShardIterator(shardID, lastSeqNum)
|
shardIterator, err = c.getShardIterator(shardID, lastSeqNum)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.logger.Printf("get shard iterator error: %v", err)
|
return fmt.Errorf("get shard iterator error: %v", err)
|
||||||
break loop
|
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
@ -174,44 +183,45 @@ loop:
|
||||||
for _, r := range resp.Records {
|
for _, r := range resp.Records {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
break loop
|
return nil
|
||||||
default:
|
default:
|
||||||
lastSeqNum = *r.SequenceNumber
|
lastSeqNum = *r.SequenceNumber
|
||||||
c.counter.Add("records", 1)
|
c.counter.Add("records", 1)
|
||||||
|
|
||||||
if ok := fn(r); !ok {
|
if ok := fn(r); !ok {
|
||||||
break loop
|
if err := c.setCheckpoint(shardID, lastSeqNum); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := c.checkpoint.Set(c.streamName, shardID, lastSeqNum); err != nil {
|
if err := c.setCheckpoint(shardID, lastSeqNum); err != nil {
|
||||||
c.logger.Printf("set checkpoint error: %v", err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
c.logger.Println("checkpoint", shardID, len(resp.Records))
|
|
||||||
c.counter.Add("checkpoints", 1)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.NextShardIterator == nil || shardIterator == resp.NextShardIterator {
|
if resp.NextShardIterator == nil || shardIterator == resp.NextShardIterator {
|
||||||
shardIterator, err = c.getShardIterator(shardID, lastSeqNum)
|
shardIterator, err = c.getShardIterator(shardID, lastSeqNum)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.logger.Printf("get shard iterator error: %v", err)
|
return fmt.Errorf("get shard iterator error: %v", err)
|
||||||
break loop
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
shardIterator = resp.NextShardIterator
|
shardIterator = resp.NextShardIterator
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if lastSeqNum == "" {
|
func (c *Consumer) setCheckpoint(shardID, lastSeqNum string) error {
|
||||||
return
|
err := c.checkpoint.Set(c.streamName, shardID, lastSeqNum)
|
||||||
}
|
if err != nil {
|
||||||
|
return fmt.Errorf("set checkpoint error: %v", err)
|
||||||
c.logger.Println("checkpointing", shardID)
|
|
||||||
if err := c.checkpoint.Set(c.streamName, shardID, lastSeqNum); err != nil {
|
|
||||||
c.logger.Printf("set checkpoint error: %v", err)
|
|
||||||
}
|
}
|
||||||
|
c.logger.Println("checkpoint", shardID)
|
||||||
|
c.counter.Add("checkpoints", 1)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Consumer) getShardIterator(shardID, lastSeqNum string) (*string, error) {
|
func (c *Consumer) getShardIterator(shardID, lastSeqNum string) (*string, error) {
|
||||||
|
|
|
||||||
|
|
@ -35,10 +35,10 @@ func main() {
|
||||||
|
|
||||||
var (
|
var (
|
||||||
counter = expvar.NewMap("counters")
|
counter = expvar.NewMap("counters")
|
||||||
logger = log.New(os.Stdout, "consumer-example: ", log.LstdFlags)
|
logger = log.New(os.Stdout, "", log.LstdFlags)
|
||||||
)
|
)
|
||||||
|
|
||||||
// checkpoint
|
// redis checkpoint
|
||||||
ck, err := checkpoint.New(*app)
|
ck, err := checkpoint.New(*app)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("checkpoint error: %v", err)
|
log.Fatalf("checkpoint error: %v", err)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue