From f85f25c15e52a29cfa4a9499837dbbc25e9e7683 Mon Sep 17 00:00:00 2001 From: Andrew Shannon Brown Date: Sun, 8 Sep 2019 13:13:04 -0700 Subject: [PATCH] Add an in-memory checkpoint to the API (#103) * Add an in-memory checkpoint to the API * Rename memory to store * Rename test package to store --- consumer_test.go | 31 +++++-------------------------- store/memory/store.go | 33 +++++++++++++++++++++++++++++++++ store/memory/store_test.go | 30 ++++++++++++++++++++++++++++++ 3 files changed, 68 insertions(+), 26 deletions(-) create mode 100644 store/memory/store.go create mode 100644 store/memory/store_test.go diff --git a/consumer_test.go b/consumer_test.go index 2ec2ec4..f1f9cca 100644 --- a/consumer_test.go +++ b/consumer_test.go @@ -11,6 +11,8 @@ import ( "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/service/kinesis" "github.com/aws/aws-sdk-go/service/kinesis/kinesisiface" + + "github.com/harlow/kinesis-consumer/store/memory" ) var records = []*kinesis.Record{ @@ -52,7 +54,7 @@ func TestScan(t *testing.T) { }, } var ( - cp = &fakeCheckpoint{cache: map[string]string{}} + cp = store.New() ctr = &fakeCounter{} ) @@ -114,7 +116,7 @@ func TestScanShard(t *testing.T) { } var ( - cp = &fakeCheckpoint{cache: map[string]string{}} + cp = store.New() ctr = &fakeCounter{} ) @@ -219,7 +221,7 @@ func TestScanShard_SkipCheckpoint(t *testing.T) { }, } - var cp = &fakeCheckpoint{cache: map[string]string{}} + var cp = store.New() c, err := New("myStreamName", WithClient(client), WithStore(cp)) if err != nil { @@ -335,29 +337,6 @@ func (c *kinesisClientMock) GetShardIteratorWithContext(ctx aws.Context, in *kin return c.getShardIteratorMock(in) } -// implementation of checkpoint -type fakeCheckpoint struct { - cache map[string]string - mu sync.Mutex -} - -func (fc *fakeCheckpoint) SetCheckpoint(streamName, shardID, sequenceNumber string) error { - fc.mu.Lock() - defer fc.mu.Unlock() - - key := fmt.Sprintf("%s-%s", streamName, shardID) - fc.cache[key] = sequenceNumber - return nil -} - -func (fc *fakeCheckpoint) GetCheckpoint(streamName, shardID string) (string, error) { - fc.mu.Lock() - defer fc.mu.Unlock() - - key := fmt.Sprintf("%s-%s", streamName, shardID) - return fc.cache[key], nil -} - // implementation of counter type fakeCounter struct { counter int64 diff --git a/store/memory/store.go b/store/memory/store.go new file mode 100644 index 0000000..e111ec4 --- /dev/null +++ b/store/memory/store.go @@ -0,0 +1,33 @@ +// The memory store provides a store that can be used for testing and single-threaded applications. +// DO NOT USE this in a production application where persistence beyond a single application lifecycle is necessary +// or when there are multiple consumers. +package store + +import ( + "fmt" + "sync" +) + +func New() *Store { + return &Store{} +} + +type Store struct { + sync.Map +} + +func (c *Store) SetCheckpoint(streamName, shardID, sequenceNumber string) error { + if sequenceNumber == "" { + return fmt.Errorf("sequence number should not be empty") + } + c.Store(streamName+":"+shardID, sequenceNumber) + return nil +} + +func (c *Store) GetCheckpoint(streamName, shardID string) (string, error) { + val, ok := c.Load(streamName + ":" + shardID) + if !ok { + return "", nil + } + return val.(string), nil +} diff --git a/store/memory/store_test.go b/store/memory/store_test.go new file mode 100644 index 0000000..6b05bc7 --- /dev/null +++ b/store/memory/store_test.go @@ -0,0 +1,30 @@ +package store + +import ( + "testing" +) + +func Test_CheckpointLifecycle(t *testing.T) { + c := New() + + // set + c.SetCheckpoint("streamName", "shardID", "testSeqNum") + + // get + val, err := c.GetCheckpoint("streamName", "shardID") + if err != nil { + t.Fatalf("get checkpoint error: %v", err) + } + if val != "testSeqNum" { + t.Fatalf("checkpoint exists expected %s, got %s", "testSeqNum", val) + } +} + +func Test_SetEmptySeqNum(t *testing.T) { + c := New() + + err := c.SetCheckpoint("streamName", "shardID", "") + if err == nil || err.Error() != "sequence number should not be empty" { + t.Fatalf("should not allow empty sequence number") + } +}