diff --git a/README.md b/README.md index 03f2078..0c686b9 100644 --- a/README.md +++ b/README.md @@ -31,19 +31,23 @@ func main() { c, err := consumer.New(*app, *stream) if err != nil { - log.Fatalf("new consumer error: %v", err) + log.Fatalf("consumer error: %v", err) } - c.Scan(context.TODO(), func(r *kinesis.Record) bool { + err = c.Scan(context.TODO(), func(r *kinesis.Record) bool { fmt.Println(string(r.Data)) return true // continue scanning }) + if err != nil { + log.Fatalf("scan error: %v", err) + } + + // Note: If you need to aggregate based on a specific shard the `ScanShard` + // method should be leverged instead. } ``` -Note: If you need to aggregate based on a specific shard the `ScanShard` method should be leverged instead. - ### Configuration The consumer requires the following config: diff --git a/consumer.go b/consumer.go index 1dcae1f..03f2deb 100644 --- a/consumer.go +++ b/consumer.go @@ -3,7 +3,6 @@ package consumer import ( "context" "fmt" - "os" "sync" "github.com/apex/log" @@ -100,26 +99,25 @@ type Consumer struct { } // Scan scans each of the shards of the stream, calls the callback -// func with each of the kinesis records -func (c *Consumer) Scan(ctx context.Context, fn func(*kinesis.Record) bool) { +// func with each of the kinesis records. +func (c *Consumer) Scan(ctx context.Context, fn func(*kinesis.Record) bool) error { ctx, cancel := context.WithCancel(ctx) defer cancel() + // grab the stream details resp, err := c.svc.DescribeStream( &kinesis.DescribeStreamInput{ StreamName: aws.String(c.streamName), }, ) - if err != nil { - c.logger.WithError(err).Error("DescribeStream") - os.Exit(1) + return err } var wg sync.WaitGroup wg.Add(len(resp.StreamDescription.Shards)) - // scan each of the shards + // launch goroutine to process each of the shards for _, shard := range resp.StreamDescription.Shards { go func(shardID string) { defer wg.Done() @@ -129,10 +127,12 @@ func (c *Consumer) Scan(ctx context.Context, fn func(*kinesis.Record) bool) { } wg.Wait() + return nil } -// ScanShard loops over records on a kinesis shard, call the callback func -// for each record and checkpoints after each page is processed +// ScanShard loops over records on a specific shard, calls the callback func +// for each record and checkpoints after each page is processed. +// Note: returning `false` from the callback func will end the scan. func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn func(*kinesis.Record) bool) { var ( logger = c.logger.WithFields(log.Fields{"shard": shardID}) diff --git a/examples/consumer/main.go b/examples/consumer/main.go index 9c54423..8c52270 100644 --- a/examples/consumer/main.go +++ b/examples/consumer/main.go @@ -24,11 +24,14 @@ func main() { c, err := consumer.New(*app, *stream) if err != nil { - log.Fatalf("new consumer error: %v", err) + log.Fatalf("consumer error: %v", err) } - c.Scan(context.TODO(), func(r *kinesis.Record) bool { + err = c.Scan(context.TODO(), func(r *kinesis.Record) bool { fmt.Println(string(r.Data)) return true // continue scanning }) + if err != nil { + log.Fatalf("scan error: %v", err) + } }