Add an in-memory checkpoint to the API (#103)
* Add an in-memory checkpoint to the API * Rename memory to store * Rename test package to store
This commit is contained in:
parent
b87510458e
commit
f85f25c15e
3 changed files with 68 additions and 26 deletions
|
|
@ -11,6 +11,8 @@ import (
|
||||||
"github.com/aws/aws-sdk-go/aws/request"
|
"github.com/aws/aws-sdk-go/aws/request"
|
||||||
"github.com/aws/aws-sdk-go/service/kinesis"
|
"github.com/aws/aws-sdk-go/service/kinesis"
|
||||||
"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
|
"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
|
||||||
|
|
||||||
|
"github.com/harlow/kinesis-consumer/store/memory"
|
||||||
)
|
)
|
||||||
|
|
||||||
var records = []*kinesis.Record{
|
var records = []*kinesis.Record{
|
||||||
|
|
@ -52,7 +54,7 @@ func TestScan(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
var (
|
var (
|
||||||
cp = &fakeCheckpoint{cache: map[string]string{}}
|
cp = store.New()
|
||||||
ctr = &fakeCounter{}
|
ctr = &fakeCounter{}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -114,7 +116,7 @@ func TestScanShard(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
cp = &fakeCheckpoint{cache: map[string]string{}}
|
cp = store.New()
|
||||||
ctr = &fakeCounter{}
|
ctr = &fakeCounter{}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -219,7 +221,7 @@ func TestScanShard_SkipCheckpoint(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
var cp = &fakeCheckpoint{cache: map[string]string{}}
|
var cp = store.New()
|
||||||
|
|
||||||
c, err := New("myStreamName", WithClient(client), WithStore(cp))
|
c, err := New("myStreamName", WithClient(client), WithStore(cp))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -335,29 +337,6 @@ func (c *kinesisClientMock) GetShardIteratorWithContext(ctx aws.Context, in *kin
|
||||||
return c.getShardIteratorMock(in)
|
return c.getShardIteratorMock(in)
|
||||||
}
|
}
|
||||||
|
|
||||||
// implementation of checkpoint
|
|
||||||
type fakeCheckpoint struct {
|
|
||||||
cache map[string]string
|
|
||||||
mu sync.Mutex
|
|
||||||
}
|
|
||||||
|
|
||||||
func (fc *fakeCheckpoint) SetCheckpoint(streamName, shardID, sequenceNumber string) error {
|
|
||||||
fc.mu.Lock()
|
|
||||||
defer fc.mu.Unlock()
|
|
||||||
|
|
||||||
key := fmt.Sprintf("%s-%s", streamName, shardID)
|
|
||||||
fc.cache[key] = sequenceNumber
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (fc *fakeCheckpoint) GetCheckpoint(streamName, shardID string) (string, error) {
|
|
||||||
fc.mu.Lock()
|
|
||||||
defer fc.mu.Unlock()
|
|
||||||
|
|
||||||
key := fmt.Sprintf("%s-%s", streamName, shardID)
|
|
||||||
return fc.cache[key], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// implementation of counter
|
// implementation of counter
|
||||||
type fakeCounter struct {
|
type fakeCounter struct {
|
||||||
counter int64
|
counter int64
|
||||||
|
|
|
||||||
33
store/memory/store.go
Normal file
33
store/memory/store.go
Normal file
|
|
@ -0,0 +1,33 @@
|
||||||
|
// The memory store provides a store that can be used for testing and single-threaded applications.
|
||||||
|
// DO NOT USE this in a production application where persistence beyond a single application lifecycle is necessary
|
||||||
|
// or when there are multiple consumers.
|
||||||
|
package store
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
func New() *Store {
|
||||||
|
return &Store{}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Store struct {
|
||||||
|
sync.Map
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Store) SetCheckpoint(streamName, shardID, sequenceNumber string) error {
|
||||||
|
if sequenceNumber == "" {
|
||||||
|
return fmt.Errorf("sequence number should not be empty")
|
||||||
|
}
|
||||||
|
c.Store(streamName+":"+shardID, sequenceNumber)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Store) GetCheckpoint(streamName, shardID string) (string, error) {
|
||||||
|
val, ok := c.Load(streamName + ":" + shardID)
|
||||||
|
if !ok {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
return val.(string), nil
|
||||||
|
}
|
||||||
30
store/memory/store_test.go
Normal file
30
store/memory/store_test.go
Normal file
|
|
@ -0,0 +1,30 @@
|
||||||
|
package store
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_CheckpointLifecycle(t *testing.T) {
|
||||||
|
c := New()
|
||||||
|
|
||||||
|
// set
|
||||||
|
c.SetCheckpoint("streamName", "shardID", "testSeqNum")
|
||||||
|
|
||||||
|
// get
|
||||||
|
val, err := c.GetCheckpoint("streamName", "shardID")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("get checkpoint error: %v", err)
|
||||||
|
}
|
||||||
|
if val != "testSeqNum" {
|
||||||
|
t.Fatalf("checkpoint exists expected %s, got %s", "testSeqNum", val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_SetEmptySeqNum(t *testing.T) {
|
||||||
|
c := New()
|
||||||
|
|
||||||
|
err := c.SetCheckpoint("streamName", "shardID", "")
|
||||||
|
if err == nil || err.Error() != "sequence number should not be empty" {
|
||||||
|
t.Fatalf("should not allow empty sequence number")
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Reference in a new issue