diff --git a/allgroup.go b/allgroup.go index 1c6e5e4..a8330e9 100644 --- a/allgroup.go +++ b/allgroup.go @@ -33,11 +33,8 @@ type AllGroup struct { // start is a blocking operation which will loop and attempt to find new // shards on a regular cadence. -func (g *AllGroup) Start(ctx context.Context) chan *kinesis.Shard { - var ( - shardc = make(chan *kinesis.Shard, 1) - ticker = time.NewTicker(30 * time.Second) - ) +func (g *AllGroup) Start(ctx context.Context, shardc chan *kinesis.Shard) { + var ticker = time.NewTicker(30 * time.Second) g.findNewShards(shardc) // Note: while ticker is a rather naive approach to this problem, @@ -49,19 +46,15 @@ func (g *AllGroup) Start(ctx context.Context) chan *kinesis.Shard { // necessarily close at the same time, so we could potentially get a // thundering heard of notifications from the consumer. - go func() { - for { - select { - case <-ctx.Done(): - ticker.Stop() - return - case <-ticker.C: - g.findNewShards(shardc) - } + for { + select { + case <-ctx.Done(): + ticker.Stop() + return + case <-ticker.C: + g.findNewShards(shardc) } - }() - - return shardc + } } func (g *AllGroup) GetCheckpoint(streamName, shardID string) (string, error) { @@ -83,7 +76,7 @@ func (g *AllGroup) findNewShards(shardc chan *kinesis.Shard) { shards, err := listShards(g.ksis, g.streamName) if err != nil { - g.logger.Log("[GROUP]", err) + g.logger.Log("[GROUP] error:", err) return } diff --git a/consumer.go b/consumer.go index 8596525..1ea5c5c 100644 --- a/consumer.go +++ b/consumer.go @@ -88,10 +88,11 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error { var ( errc = make(chan error, 1) - shardc = c.group.Start(ctx) + shardc = make(chan *kinesis.Shard, 1) ) go func() { + c.group.Start(ctx, shardc) <-ctx.Done() close(shardc) }() @@ -130,9 +131,9 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e return fmt.Errorf("get shard iterator error: %v", err) } - c.logger.Log("[START]\t", shardID, lastSeqNum) + c.logger.Log("[CONSUMER] start scan:", shardID, lastSeqNum) defer func() { - c.logger.Log("[STOP]\t", shardID) + c.logger.Log("[CONSUMER] stop scan:", shardID) }() for { @@ -176,7 +177,7 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e } if isShardClosed(resp.NextShardIterator, shardIterator) { - c.logger.Log("[CLOSED]\t", shardID) + c.logger.Log("[CONSUMER] shard closed:", shardID) return nil } diff --git a/examples/producer/README.md b/examples/producer/README.md index a620e95..da7c13b 100644 --- a/examples/producer/README.md +++ b/examples/producer/README.md @@ -8,7 +8,7 @@ Export the required environment vars for connecting to the Kinesis stream: ``` export AWS_PROFILE= -export AWS_REGION_NAME= +export AWS_REGION= ``` ### Running the code diff --git a/group.go b/group.go index 25e3fad..aa08438 100644 --- a/group.go +++ b/group.go @@ -8,7 +8,7 @@ import ( // Group interface used to manage which shard to process type Group interface { - Start(ctx context.Context) chan *kinesis.Shard + Start(ctx context.Context, shardc chan *kinesis.Shard) GetCheckpoint(streamName, shardID string) (string, error) SetCheckpoint(streamName, shardID, sequenceNumber string) error }