fix rebase errors

This commit is contained in:
Harlow Ward 2019-07-28 11:09:43 -07:00
parent c690ce2822
commit 8b19674b4a
3 changed files with 12 additions and 25 deletions

View file

@ -11,13 +11,13 @@ import (
// NewAllGroup returns an intitialized AllGroup for consuming // NewAllGroup returns an intitialized AllGroup for consuming
// all shards on a stream // all shards on a stream
func NewAllGroup(ksis kinesisiface.KinesisAPI, ck Checkpoint, streamName string, logger Logger) *AllGroup { func NewAllGroup(ksis kinesisiface.KinesisAPI, db Storage, streamName string, logger Logger) *AllGroup {
return &AllGroup{ return &AllGroup{
ksis: ksis, ksis: ksis,
shards: make(map[string]*kinesis.Shard), shards: make(map[string]*kinesis.Shard),
streamName: streamName, streamName: streamName,
logger: logger, logger: logger,
checkpoint: ck, storage: db,
} }
} }
@ -27,7 +27,7 @@ type AllGroup struct {
ksis kinesisiface.KinesisAPI ksis kinesisiface.KinesisAPI
streamName string streamName string
logger Logger logger Logger
checkpoint Checkpoint storage Storage
shardMu sync.Mutex shardMu sync.Mutex
shards map[string]*kinesis.Shard shards map[string]*kinesis.Shard
@ -61,12 +61,12 @@ func (g *AllGroup) Start(ctx context.Context, shardc chan *kinesis.Shard) {
// GetCheckpoint returns the current checkpoint for provided stream // GetCheckpoint returns the current checkpoint for provided stream
func (g *AllGroup) GetCheckpoint(streamName, shardID string) (string, error) { func (g *AllGroup) GetCheckpoint(streamName, shardID string) (string, error) {
return g.checkpoint.Get(streamName, shardID) return g.storage.GetCheckpoint(streamName, shardID)
} }
// SetCheckpoint sets the current checkpoint for provided stream // SetCheckpoint sets the current checkpoint for provided stream
func (g *AllGroup) SetCheckpoint(streamName, shardID, sequenceNumber string) error { func (g *AllGroup) SetCheckpoint(streamName, shardID, sequenceNumber string) error {
return g.checkpoint.Set(streamName, shardID, sequenceNumber) return g.storage.SetCheckpoint(streamName, shardID, sequenceNumber)
} }
// findNewShards pulls the list of shards from the Kinesis API // findNewShards pulls the list of shards from the Kinesis API

View file

@ -39,7 +39,7 @@ func New(streamName string, opts ...Option) (*Consumer, error) {
opt(c) opt(c)
} }
// default client if none provided // default client
if c.client == nil { if c.client == nil {
newSession, err := session.NewSession(aws.NewConfig()) newSession, err := session.NewSession(aws.NewConfig())
if err != nil { if err != nil {
@ -48,9 +48,9 @@ func New(streamName string, opts ...Option) (*Consumer, error) {
c.client = kinesis.New(newSession) c.client = kinesis.New(newSession)
} }
// default group if none provided // default group consumes all shards
if c.group == nil { if c.group == nil {
c.group = NewAllGroup(c.client, c.checkpoint, c.streamName, c.logger) c.group = NewAllGroup(c.client, c.storage, streamName, c.logger)
} }
return c, nil return c, nil
@ -61,14 +61,10 @@ type Consumer struct {
streamName string streamName string
initialShardIteratorType string initialShardIteratorType string
client kinesisiface.KinesisAPI client kinesisiface.KinesisAPI
logger Logger
<<<<<<< HEAD
group Group
checkpoint Checkpoint
=======
storage Storage
>>>>>>> 0162c90... Introduce Storage interface
counter Counter counter Counter
group Group
logger Logger
storage Storage
} }
// ScanFunc is the type of the function called for each message read // ScanFunc is the type of the function called for each message read
@ -124,11 +120,7 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
// for each record and checkpoints the progress of scan. // for each record and checkpoints the progress of scan.
func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error { func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error {
// get last seq number from checkpoint // get last seq number from checkpoint
<<<<<<< HEAD
lastSeqNum, err := c.group.GetCheckpoint(c.streamName, shardID) lastSeqNum, err := c.group.GetCheckpoint(c.streamName, shardID)
=======
lastSeqNum, err := c.storage.GetCheckpoint(c.streamName, shardID)
>>>>>>> 0162c90... Introduce Storage interface
if err != nil { if err != nil {
return fmt.Errorf("get checkpoint error: %v", err) return fmt.Errorf("get checkpoint error: %v", err)
} }
@ -173,13 +165,8 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
return err return err
} }
<<<<<<< HEAD
if err != ErrSkipCheckpoint { if err != ErrSkipCheckpoint {
if err := c.group.SetCheckpoint(c.streamName, shardID, *r.SequenceNumber); err != nil { if err := c.group.SetCheckpoint(c.streamName, shardID, *r.SequenceNumber); err != nil {
=======
if err != SkipCheckpoint {
if err := c.storage.SetCheckpoint(c.streamName, shardID, *r.SequenceNumber); err != nil {
>>>>>>> 0162c90... Introduce Storage interface
return err return err
} }
} }

View file

@ -229,7 +229,7 @@ func TestScanShard_SkipCheckpoint(t *testing.T) {
var fn = func(r *Record) error { var fn = func(r *Record) error {
if aws.StringValue(r.SequenceNumber) == "lastSeqNum" { if aws.StringValue(r.SequenceNumber) == "lastSeqNum" {
cancel() cancel()
return SkipCheckpoint return ErrSkipCheckpoint
} }
return nil return nil