adjust channel ownership for group

This commit is contained in:
Harlow Ward 2019-06-02 17:46:28 -07:00
parent 4fd29c54ff
commit 0328cba5c9
4 changed files with 18 additions and 24 deletions

View file

@ -33,11 +33,8 @@ type AllGroup struct {
// start is a blocking operation which will loop and attempt to find new // start is a blocking operation which will loop and attempt to find new
// shards on a regular cadence. // shards on a regular cadence.
func (g *AllGroup) Start(ctx context.Context) chan *kinesis.Shard { func (g *AllGroup) Start(ctx context.Context, shardc chan *kinesis.Shard) {
var ( var ticker = time.NewTicker(30 * time.Second)
shardc = make(chan *kinesis.Shard, 1)
ticker = time.NewTicker(30 * time.Second)
)
g.findNewShards(shardc) g.findNewShards(shardc)
// Note: while ticker is a rather naive approach to this problem, // Note: while ticker is a rather naive approach to this problem,
@ -49,7 +46,6 @@ func (g *AllGroup) Start(ctx context.Context) chan *kinesis.Shard {
// necessarily close at the same time, so we could potentially get a // necessarily close at the same time, so we could potentially get a
// thundering heard of notifications from the consumer. // thundering heard of notifications from the consumer.
go func() {
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
@ -59,9 +55,6 @@ func (g *AllGroup) Start(ctx context.Context) chan *kinesis.Shard {
g.findNewShards(shardc) g.findNewShards(shardc)
} }
} }
}()
return shardc
} }
func (g *AllGroup) GetCheckpoint(streamName, shardID string) (string, error) { func (g *AllGroup) GetCheckpoint(streamName, shardID string) (string, error) {
@ -83,7 +76,7 @@ func (g *AllGroup) findNewShards(shardc chan *kinesis.Shard) {
shards, err := listShards(g.ksis, g.streamName) shards, err := listShards(g.ksis, g.streamName)
if err != nil { if err != nil {
g.logger.Log("[GROUP]", err) g.logger.Log("[GROUP] error:", err)
return return
} }

View file

@ -88,10 +88,11 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
var ( var (
errc = make(chan error, 1) errc = make(chan error, 1)
shardc = c.group.Start(ctx) shardc = make(chan *kinesis.Shard, 1)
) )
go func() { go func() {
c.group.Start(ctx, shardc)
<-ctx.Done() <-ctx.Done()
close(shardc) close(shardc)
}() }()
@ -130,9 +131,9 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
return fmt.Errorf("get shard iterator error: %v", err) return fmt.Errorf("get shard iterator error: %v", err)
} }
c.logger.Log("[START]\t", shardID, lastSeqNum) c.logger.Log("[CONSUMER] start scan:", shardID, lastSeqNum)
defer func() { defer func() {
c.logger.Log("[STOP]\t", shardID) c.logger.Log("[CONSUMER] stop scan:", shardID)
}() }()
for { for {
@ -176,7 +177,7 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
} }
if isShardClosed(resp.NextShardIterator, shardIterator) { if isShardClosed(resp.NextShardIterator, shardIterator) {
c.logger.Log("[CLOSED]\t", shardID) c.logger.Log("[CONSUMER] shard closed:", shardID)
return nil return nil
} }

View file

@ -8,7 +8,7 @@ Export the required environment vars for connecting to the Kinesis stream:
``` ```
export AWS_PROFILE= export AWS_PROFILE=
export AWS_REGION_NAME= export AWS_REGION=
``` ```
### Running the code ### Running the code

View file

@ -8,7 +8,7 @@ import (
// Group interface used to manage which shard to process // Group interface used to manage which shard to process
type Group interface { type Group interface {
Start(ctx context.Context) chan *kinesis.Shard Start(ctx context.Context, shardc chan *kinesis.Shard)
GetCheckpoint(streamName, shardID string) (string, error) GetCheckpoint(streamName, shardID string) (string, error)
SetCheckpoint(streamName, shardID, sequenceNumber string) error SetCheckpoint(streamName, shardID, sequenceNumber string) error
} }