diff --git a/consumer.go b/consumer.go index 2e0e688..2480710 100644 --- a/consumer.go +++ b/consumer.go @@ -3,8 +3,8 @@ package consumer import ( "context" "fmt" + "io/ioutil" "log" - "os" "sync" "github.com/aws/aws-sdk-go/aws" @@ -72,7 +72,7 @@ func New(stream string, opts ...Option) (*Consumer, error) { streamName: stream, checkpoint: &noopCheckpoint{}, counter: &noopCounter{}, - logger: log.New(os.Stderr, "kinesis-consumer: ", log.LstdFlags), + logger: log.New(ioutil.Discard, "", log.LstdFlags), } // set options @@ -112,48 +112,58 @@ func (c *Consumer) Scan(ctx context.Context, fn func(*kinesis.Record) bool) erro }, ) if err != nil { - return err + return fmt.Errorf("describe stream error: %v", err) } - var wg sync.WaitGroup + var ( + wg sync.WaitGroup + errc = make(chan error, 1) + ) wg.Add(len(resp.StreamDescription.Shards)) // launch goroutine to process each of the shards for _, shard := range resp.StreamDescription.Shards { go func(shardID string) { defer wg.Done() - c.ScanShard(ctx, shardID, fn) + err := c.ScanShard(ctx, shardID, fn) + if err != nil { + select { + case errc <- fmt.Errorf("shard %s error: %v", shardID, err): + // first error to occur + default: + // error has already occured + } + } + c.logger.Println("exiting", shardID) cancel() }(*shard.ShardId) } wg.Wait() - return nil + close(errc) + return <-errc } // ScanShard loops over records on a specific shard, calls the callback func // for each record and checkpoints after each page is processed. // Note: returning `false` from the callback func will end the scan. -func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn func(*kinesis.Record) bool) { +func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn func(*kinesis.Record) bool) error { + c.logger.Println("scanning", shardID) + lastSeqNum, err := c.checkpoint.Get(c.streamName, shardID) if err != nil { - c.logger.Printf("get checkpoint error: %v", err) - return + return fmt.Errorf("get checkpoint error: %v", err) } shardIterator, err := c.getShardIterator(shardID, lastSeqNum) if err != nil { - c.logger.Printf("get shard iterator error: %v", err) - return + return fmt.Errorf("get shard iterator error: %v", err) } - c.logger.Println("scanning", shardID) - -loop: for { select { case <-ctx.Done(): - break loop + return nil default: resp, err := c.client.GetRecords( &kinesis.GetRecordsInput{ @@ -164,8 +174,7 @@ loop: if err != nil { shardIterator, err = c.getShardIterator(shardID, lastSeqNum) if err != nil { - c.logger.Printf("get shard iterator error: %v", err) - break loop + return fmt.Errorf("get shard iterator error: %v", err) } continue } @@ -174,44 +183,45 @@ loop: for _, r := range resp.Records { select { case <-ctx.Done(): - break loop + return nil default: lastSeqNum = *r.SequenceNumber c.counter.Add("records", 1) + if ok := fn(r); !ok { - break loop + if err := c.setCheckpoint(shardID, lastSeqNum); err != nil { + return err + } + return nil } } } - if err := c.checkpoint.Set(c.streamName, shardID, lastSeqNum); err != nil { - c.logger.Printf("set checkpoint error: %v", err) + if err := c.setCheckpoint(shardID, lastSeqNum); err != nil { + return err } - - c.logger.Println("checkpoint", shardID, len(resp.Records)) - c.counter.Add("checkpoints", 1) } if resp.NextShardIterator == nil || shardIterator == resp.NextShardIterator { shardIterator, err = c.getShardIterator(shardID, lastSeqNum) if err != nil { - c.logger.Printf("get shard iterator error: %v", err) - break loop + return fmt.Errorf("get shard iterator error: %v", err) } } else { shardIterator = resp.NextShardIterator } } } +} - if lastSeqNum == "" { - return - } - - c.logger.Println("checkpointing", shardID) - if err := c.checkpoint.Set(c.streamName, shardID, lastSeqNum); err != nil { - c.logger.Printf("set checkpoint error: %v", err) +func (c *Consumer) setCheckpoint(shardID, lastSeqNum string) error { + err := c.checkpoint.Set(c.streamName, shardID, lastSeqNum) + if err != nil { + return fmt.Errorf("set checkpoint error: %v", err) } + c.logger.Println("checkpoint", shardID) + c.counter.Add("checkpoints", 1) + return nil } func (c *Consumer) getShardIterator(shardID, lastSeqNum string) (*string, error) { diff --git a/examples/consumer/main.go b/examples/consumer/main.go index 1c9ac47..528c5d2 100644 --- a/examples/consumer/main.go +++ b/examples/consumer/main.go @@ -35,10 +35,10 @@ func main() { var ( counter = expvar.NewMap("counters") - logger = log.New(os.Stdout, "consumer-example: ", log.LstdFlags) + logger = log.New(os.Stdout, "", log.LstdFlags) ) - // checkpoint + // redis checkpoint ck, err := checkpoint.New(*app) if err != nil { log.Fatalf("checkpoint error: %v", err)