diff --git a/allgroup.go b/allgroup.go index f046f1d..7fa489e 100644 --- a/allgroup.go +++ b/allgroup.go @@ -11,13 +11,13 @@ import ( // NewAllGroup returns an intitialized AllGroup for consuming // all shards on a stream -func NewAllGroup(ksis kinesisiface.KinesisAPI, ck Checkpoint, streamName string, logger Logger) *AllGroup { +func NewAllGroup(ksis kinesisiface.KinesisAPI, db Storage, streamName string, logger Logger) *AllGroup { return &AllGroup{ ksis: ksis, shards: make(map[string]*kinesis.Shard), streamName: streamName, logger: logger, - checkpoint: ck, + storage: db, } } @@ -27,7 +27,7 @@ type AllGroup struct { ksis kinesisiface.KinesisAPI streamName string logger Logger - checkpoint Checkpoint + storage Storage shardMu sync.Mutex shards map[string]*kinesis.Shard @@ -61,12 +61,12 @@ func (g *AllGroup) Start(ctx context.Context, shardc chan *kinesis.Shard) { // GetCheckpoint returns the current checkpoint for provided stream func (g *AllGroup) GetCheckpoint(streamName, shardID string) (string, error) { - return g.checkpoint.Get(streamName, shardID) + return g.storage.GetCheckpoint(streamName, shardID) } // SetCheckpoint sets the current checkpoint for provided stream func (g *AllGroup) SetCheckpoint(streamName, shardID, sequenceNumber string) error { - return g.checkpoint.Set(streamName, shardID, sequenceNumber) + return g.storage.SetCheckpoint(streamName, shardID, sequenceNumber) } // findNewShards pulls the list of shards from the Kinesis API diff --git a/consumer.go b/consumer.go index 9ce6aa6..bc59ac7 100644 --- a/consumer.go +++ b/consumer.go @@ -39,7 +39,7 @@ func New(streamName string, opts ...Option) (*Consumer, error) { opt(c) } - // default client if none provided + // default client if c.client == nil { newSession, err := session.NewSession(aws.NewConfig()) if err != nil { @@ -48,9 +48,9 @@ func New(streamName string, opts ...Option) (*Consumer, error) { c.client = kinesis.New(newSession) } - // default group if none provided + // default group consumes all shards if c.group == nil { - c.group = NewAllGroup(c.client, c.checkpoint, c.streamName, c.logger) + c.group = NewAllGroup(c.client, c.storage, streamName, c.logger) } return c, nil @@ -61,14 +61,10 @@ type Consumer struct { streamName string initialShardIteratorType string client kinesisiface.KinesisAPI - logger Logger -<<<<<<< HEAD - group Group - checkpoint Checkpoint -======= - storage Storage ->>>>>>> 0162c90... Introduce Storage interface counter Counter + group Group + logger Logger + storage Storage } // ScanFunc is the type of the function called for each message read @@ -124,11 +120,7 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error { // for each record and checkpoints the progress of scan. func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error { // get last seq number from checkpoint -<<<<<<< HEAD lastSeqNum, err := c.group.GetCheckpoint(c.streamName, shardID) -======= - lastSeqNum, err := c.storage.GetCheckpoint(c.streamName, shardID) ->>>>>>> 0162c90... Introduce Storage interface if err != nil { return fmt.Errorf("get checkpoint error: %v", err) } @@ -173,13 +165,8 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e return err } -<<<<<<< HEAD if err != ErrSkipCheckpoint { if err := c.group.SetCheckpoint(c.streamName, shardID, *r.SequenceNumber); err != nil { -======= - if err != SkipCheckpoint { - if err := c.storage.SetCheckpoint(c.streamName, shardID, *r.SequenceNumber); err != nil { ->>>>>>> 0162c90... Introduce Storage interface return err } } diff --git a/consumer_test.go b/consumer_test.go index 7956456..3151cec 100644 --- a/consumer_test.go +++ b/consumer_test.go @@ -229,7 +229,7 @@ func TestScanShard_SkipCheckpoint(t *testing.T) { var fn = func(r *Record) error { if aws.StringValue(r.SequenceNumber) == "lastSeqNum" { cancel() - return SkipCheckpoint + return ErrSkipCheckpoint } return nil