diff --git a/allgroup.go b/allgroup.go
new file mode 100644
index 0000000..a8330e9
--- /dev/null
+++ b/allgroup.go
@@ -0,0 +1,90 @@
+package consumer
+
+import (
+	"context"
+	"sync"
+	"time"
+
+	"github.com/aws/aws-sdk-go/service/kinesis"
+	"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
+)
+
+func NewAllGroup(ksis kinesisiface.KinesisAPI, ck Checkpoint, streamName string, logger Logger) *AllGroup {
+	return &AllGroup{
+		ksis:       ksis,
+		shards:     make(map[string]*kinesis.Shard),
+		streamName: streamName,
+		logger:     logger,
+		checkpoint: ck,
+	}
+}
+
+// AllGroup caches a local list of the shards we are already processing
+// and routinely polls the stream looking for new shards to process.
+type AllGroup struct {
+	ksis       kinesisiface.KinesisAPI
+	streamName string
+	logger     Logger
+	checkpoint Checkpoint
+
+	shardMu sync.Mutex
+	shards  map[string]*kinesis.Shard
+}
+
+// Start is a blocking operation which will loop and attempt to find new
+// shards on a regular cadence.
+func (g *AllGroup) Start(ctx context.Context, shardc chan *kinesis.Shard) {
+	var ticker = time.NewTicker(30 * time.Second)
+	g.findNewShards(shardc)
+
+	// Note: while a ticker is a rather naive approach to this problem,
+	// it actually simplifies a few things, i.e. if we miss a new shard while
+	// AWS is resharding we'll pick it up max 30 seconds later.
+
+	// It might be worth refactoring this flow to allow the consumer to
+	// notify the group when a shard is closed. However, shards don't
+	// necessarily close at the same time, so we could potentially get a
+	// thundering herd of notifications from the consumer.
+
+	for {
+		select {
+		case <-ctx.Done():
+			ticker.Stop()
+			return
+		case <-ticker.C:
+			g.findNewShards(shardc)
+		}
+	}
+}
+
+func (g *AllGroup) GetCheckpoint(streamName, shardID string) (string, error) {
+	return g.checkpoint.Get(streamName, shardID)
+}
+
+func (g *AllGroup) SetCheckpoint(streamName, shardID, sequenceNumber string) error {
+	return g.checkpoint.Set(streamName, shardID, sequenceNumber)
+}
+
+// findNewShards pulls the list of shards from the Kinesis API
+// and uses a local cache to determine if we are already processing
+// a particular shard.
+func (g *AllGroup) findNewShards(shardc chan *kinesis.Shard) {
+	g.shardMu.Lock()
+	defer g.shardMu.Unlock()
+
+	g.logger.Log("[GROUP]", "fetching shards")
+
+	shards, err := listShards(g.ksis, g.streamName)
+	if err != nil {
+		g.logger.Log("[GROUP] error:", err)
+		return
+	}
+
+	for _, shard := range shards {
+		if _, ok := g.shards[*shard.ShardId]; ok {
+			continue
+		}
+		g.shards[*shard.ShardId] = shard
+		shardc <- shard
+	}
+}
diff --git a/broker.go b/broker.go
deleted file mode 100644
index ecf25a1..0000000
--- a/broker.go
+++ /dev/null
@@ -1,114 +0,0 @@
-package consumer
-
-import (
-	"context"
-	"fmt"
-	"sync"
-	"time"
-
-	"github.com/aws/aws-sdk-go/aws"
-	"github.com/aws/aws-sdk-go/service/kinesis"
-	"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
-)
-
-func newBroker(
-	client kinesisiface.KinesisAPI,
-	streamName string,
-	shardc chan *kinesis.Shard,
-	logger Logger,
-) *broker {
-	return &broker{
-		client:     client,
-		shards:     make(map[string]*kinesis.Shard),
-		streamName: streamName,
-		shardc:     shardc,
-		logger:     logger,
-	}
-}
-
-// broker caches a local list of the shards we are already processing
-// and routinely polls the stream looking for new shards to process
-type broker struct {
-	client     kinesisiface.KinesisAPI
-	streamName string
-	shardc     chan *kinesis.Shard
-	logger     Logger
-
-	shardMu sync.Mutex
-	shards  map[string]*kinesis.Shard
-}
-
-// start is a blocking operation which will loop and attempt to find new
-// shards on a regular cadence.
-func (b *broker) start(ctx context.Context) {
-	b.findNewShards()
-	ticker := time.NewTicker(30 * time.Second)
-
-	// Note: while ticker is a rather naive approach to this problem,
-	// it actually simplies a few things. i.e. If we miss a new shard while
-	// AWS is resharding we'll pick it up max 30 seconds later.
-
-	// It might be worth refactoring this flow to allow the consumer to
-	// to notify the broker when a shard is closed. However, shards don't
-	// necessarily close at the same time, so we could potentially get a
-	// thundering heard of notifications from the consumer.
-
-	for {
-		select {
-		case <-ctx.Done():
-			ticker.Stop()
-			return
-		case <-ticker.C:
-			b.findNewShards()
-		}
-	}
-}
-
-// findNewShards pulls the list of shards from the Kinesis API
-// and uses a local cache to determine if we are already processing
-// a particular shard.
-func (b *broker) findNewShards() {
-	b.shardMu.Lock()
-	defer b.shardMu.Unlock()
-
-	b.logger.Log("[BROKER]", "fetching shards")
-
-	shards, err := b.listShards()
-	if err != nil {
-		b.logger.Log("[BROKER]", err)
-		return
-	}
-
-	for _, shard := range shards {
-		if _, ok := b.shards[*shard.ShardId]; ok {
-			continue
-		}
-		b.shards[*shard.ShardId] = shard
-		b.shardc <- shard
-	}
-}
-
-// listShards pulls a list of shard IDs from the kinesis api
-func (b *broker) listShards() ([]*kinesis.Shard, error) {
-	var ss []*kinesis.Shard
-	var listShardsInput = &kinesis.ListShardsInput{
-		StreamName: aws.String(b.streamName),
-	}
-
-	for {
-		resp, err := b.client.ListShards(listShardsInput)
-		if err != nil {
-			return nil, fmt.Errorf("ListShards error: %v", err)
-		}
-		ss = append(ss, resp.Shards...)
-
-		if resp.NextToken == nil {
-			return ss, nil
-		}
-
-		listShardsInput = &kinesis.ListShardsInput{
-			NextToken:  resp.NextToken,
-			StreamName: aws.String(b.streamName),
-		}
-	}
-}
diff --git a/consumer.go b/consumer.go
index e9c583a..1ea5c5c 100644
--- a/consumer.go
+++ b/consumer.go
@@ -27,8 +27,8 @@ func New(streamName string, opts ...Option) (*Consumer, error) {
 	c := &Consumer{
 		streamName:               streamName,
 		initialShardIteratorType: kinesis.ShardIteratorTypeLatest,
-		checkpoint:               &noopCheckpoint{},
 		counter:                  &noopCounter{},
+		checkpoint:               &noopCheckpoint{},
 		logger: &noopLogger{
 			logger: log.New(ioutil.Discard, "", log.LstdFlags),
 		},
@@ -48,6 +48,11 @@ func New(streamName string, opts ...Option) (*Consumer, error) {
 		c.client = kinesis.New(newSession)
 	}
 
+	// default group if none provided
+	if c.group == nil {
+		c.group = NewAllGroup(c.client, c.checkpoint, c.streamName, c.logger)
+	}
+
 	return c, nil
 }
 
@@ -57,6 +62,7 @@ type Consumer struct {
 	initialShardIteratorType string
 	client                   kinesisiface.KinesisAPI
 	logger                   Logger
+	group                    Group
 	checkpoint               Checkpoint
 	counter                  Counter
 }
@@ -64,7 +70,6 @@ type Consumer struct {
 // ScanFunc is the type of the function called for each message read
 // from the stream. The record argument contains the original record
 // returned from the AWS Kinesis library.
-//
 // If an error is returned, scanning stops. The sole exception is when the
 // function returns the special value SkipCheckpoint.
 type ScanFunc func(*Record) error
@@ -78,18 +83,16 @@ var SkipCheckpoint = errors.New("skip checkpoint")
 // is passed through to each of the goroutines and called with each message pulled from
 // the stream.
 func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
-	var (
-		errc   = make(chan error, 1)
-		shardc = make(chan *kinesis.Shard, 1)
-		broker = newBroker(c.client, c.streamName, shardc, c.logger)
-	)
-
 	ctx, cancel := context.WithCancel(ctx)
 	defer cancel()
 
-	go broker.start(ctx)
+	var (
+		errc   = make(chan error, 1)
+		shardc = make(chan *kinesis.Shard, 1)
+	)
 
 	go func() {
+		c.group.Start(ctx, shardc)
 		<-ctx.Done()
 		close(shardc)
 	}()
@@ -110,7 +113,6 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
 	}
 
 	close(errc)
-
 	return <-errc
 }
 
@@ -118,7 +120,7 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
 // for each record and checkpoints the progress of scan.
 func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error {
 	// get last seq number from checkpoint
-	lastSeqNum, err := c.checkpoint.Get(c.streamName, shardID)
+	lastSeqNum, err := c.group.GetCheckpoint(c.streamName, shardID)
 	if err != nil {
 		return fmt.Errorf("get checkpoint error: %v", err)
 	}
@@ -129,9 +131,9 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
 		return fmt.Errorf("get shard iterator error: %v", err)
 	}
 
-	c.logger.Log("[START]\t", shardID, lastSeqNum)
+	c.logger.Log("[CONSUMER] start scan:", shardID, lastSeqNum)
 	defer func() {
-		c.logger.Log("[STOP]\t", shardID)
+		c.logger.Log("[CONSUMER] stop scan:", shardID)
 	}()
 
 	for {
@@ -164,7 +166,7 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
 			}
 
 			if err != SkipCheckpoint {
-				if err := c.checkpoint.Set(c.streamName, shardID, *r.SequenceNumber); err != nil {
+				if err := c.group.SetCheckpoint(c.streamName, shardID, *r.SequenceNumber); err != nil {
 					return err
 				}
 			}
@@ -175,7 +177,7 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
 		}
 
 		if isShardClosed(resp.NextShardIterator, shardIterator) {
-			c.logger.Log("[CLOSED]\t", shardID)
+			c.logger.Log("[CONSUMER] shard closed:", shardID)
 			return nil
 		}
diff --git a/examples/producer/README.md b/examples/producer/README.md
index a620e95..da7c13b 100644
--- a/examples/producer/README.md
+++ b/examples/producer/README.md
@@ -8,7 +8,7 @@ Export the required environment vars for connecting to the Kinesis stream:
 
 ```
 export AWS_PROFILE=
-export AWS_REGION_NAME=
+export AWS_REGION=
 ```
 
 ### Running the code
diff --git a/group.go b/group.go
new file mode 100644
index 0000000..aa08438
--- /dev/null
+++ b/group.go
@@ -0,0 +1,14 @@
+package consumer
+
+import (
+	"context"
+
+	"github.com/aws/aws-sdk-go/service/kinesis"
+)
+
+// Group is the interface used to manage which shards to process
+type Group interface {
+	Start(ctx context.Context, shardc chan *kinesis.Shard)
+	GetCheckpoint(streamName, shardID string) (string, error)
+	SetCheckpoint(streamName, shardID, sequenceNumber string) error
+}
diff --git a/kinesis.go b/kinesis.go
new file mode 100644
index 0000000..490b73b
--- /dev/null
+++ b/kinesis.go
@@ -0,0 +1,34 @@
+package consumer
+
+import (
+	"fmt"
+
+	"github.com/aws/aws-sdk-go/aws"
+	"github.com/aws/aws-sdk-go/service/kinesis"
+	"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
+)
+
+// listShards pulls the list of shards from the Kinesis API
+func listShards(ksis kinesisiface.KinesisAPI, streamName string) ([]*kinesis.Shard, error) {
+	var ss []*kinesis.Shard
+	var listShardsInput = &kinesis.ListShardsInput{
+		StreamName: aws.String(streamName),
+	}
+
+	for {
+		resp, err := ksis.ListShards(listShardsInput)
+		if err != nil {
+			return nil, fmt.Errorf("ListShards error: %v", err)
+		}
+		ss = append(ss, resp.Shards...)
+
+		if resp.NextToken == nil {
+			return ss, nil
+		}
+
+		listShardsInput = &kinesis.ListShardsInput{
+			NextToken:  resp.NextToken,
+			StreamName: aws.String(streamName),
+		}
+	}
+}
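For context, here is a minimal usage sketch of the consumer after this change. The `New`, `Scan`, `ScanFunc`, `Record`, and `SkipCheckpoint` names come from `consumer.go` above; the import path, stream name, and the `Data` field (taken from the underlying AWS Kinesis record) are assumptions for illustration:

```go
package main

import (
	"context"
	"fmt"
	"log"

	consumer "github.com/harlow/kinesis-consumer" // assumed import path
)

func main() {
	// New falls back to NewAllGroup when no custom Group is configured,
	// so shard discovery and checkpointing work out of the box.
	c, err := consumer.New("my-stream") // assumed stream name
	if err != nil {
		log.Fatal(err)
	}

	// Scan starts the group and launches one ScanShard goroutine per shard
	// delivered on the shard channel. Returning nil checkpoints the record;
	// returning consumer.SkipCheckpoint processes it without checkpointing;
	// any other error stops the scan.
	err = c.Scan(context.Background(), func(r *consumer.Record) error {
		fmt.Println(string(r.Data)) // Data field assumed from the AWS record
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
}
```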
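The new `Group` interface in `group.go` is the extension point this diff introduces: anything that can feed shards into the channel and round-trip checkpoints can drive the consumer. Below is a minimal sketch of an alternative implementation, assuming the package's `Checkpoint` interface exposes the `Get`/`Set` methods that `AllGroup` delegates to; `staticGroup` is a hypothetical name, not part of this change:

```go
package consumer

import (
	"context"

	"github.com/aws/aws-sdk-go/service/kinesis"
)

// staticGroup is an illustrative Group that serves a fixed, pre-fetched
// shard list instead of polling the stream the way AllGroup does.
type staticGroup struct {
	shards     []*kinesis.Shard
	checkpoint Checkpoint
}

// Start pushes each known shard onto the channel once, then returns.
// Scan keeps the channel open until ctx is done, so returning early is safe.
func (g *staticGroup) Start(ctx context.Context, shardc chan *kinesis.Shard) {
	for _, shard := range g.shards {
		select {
		case shardc <- shard:
		case <-ctx.Done():
			return
		}
	}
}

// GetCheckpoint delegates to the underlying Checkpoint, mirroring AllGroup.
func (g *staticGroup) GetCheckpoint(streamName, shardID string) (string, error) {
	return g.checkpoint.Get(streamName, shardID)
}

// SetCheckpoint delegates to the underlying Checkpoint, mirroring AllGroup.
func (g *staticGroup) SetCheckpoint(streamName, shardID, sequenceNumber string) error {
	return g.checkpoint.Set(streamName, shardID, sequenceNumber)
}
```

Because `New` only constructs the default `AllGroup` when `c.group` is nil, a custom implementation like this slots in through whatever option the package provides for setting the group.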