diff --git a/batchconsumer/batchermanager.go b/batchconsumer/batchermanager.go index bb1a3fe..b46d37a 100644 --- a/batchconsumer/batchermanager.go +++ b/batchconsumer/batchermanager.go @@ -114,7 +114,7 @@ func (b *batcherManager) sendBatch(batcher *batcher, tag string) { } func (b *batcherManager) sendCheckpoint( - tag string, lastIgnoredPair kcl.SequencePair, batchers map[string]*batcher, + tag string, lastIgnoredPair, lastProcessedPair kcl.SequencePair, batchers map[string]*batcher, ) { smallest := lastIgnoredPair @@ -133,9 +133,10 @@ func (b *batcherManager) sendCheckpoint( } } - if !smallest.IsNil() { - b.chkpntManager.Checkpoint(smallest) + if smallest.IsNil() { // This can occur when all messages in a stream go into one batch + smallest = lastProcessedPair } + b.chkpntManager.Checkpoint(smallest) } // startMessageDistributer starts a go-routine that routes messages to batches. It's in uses a @@ -164,7 +165,7 @@ func (b *batcherManager) startMessageHandler( for tag, batcher := range batchers { if b.batchInterval <= time.Now().Sub(batcher.LastUpdated) { b.sendBatch(batcher, tag) - b.sendCheckpoint(tag, lastIgnoredPair, batchers) + b.sendCheckpoint(tag, lastIgnoredPair, lastProcessedPair, batchers) batcher.Clear() } } @@ -179,7 +180,7 @@ func (b *batcherManager) startMessageHandler( err := batcher.AddMessage(tmp.msg, tmp.pair) if err == ErrBatchFull { b.sendBatch(batcher, tmp.tag) - b.sendCheckpoint(tmp.tag, lastIgnoredPair, batchers) + b.sendCheckpoint(tmp.tag, lastIgnoredPair, lastProcessedPair, batchers) batcher.AddMessage(tmp.msg, tmp.pair) } else if err != nil { diff --git a/batchconsumer/writer_test.go b/batchconsumer/writer_test.go index d62c467..94323f5 100644 --- a/batchconsumer/writer_test.go +++ b/batchconsumer/writer_test.go @@ -162,6 +162,53 @@ func TestProcessRecordsIgnoredMessages(t *testing.T) { assert.Contains(mockcheckpointer.recievedSequences, "4") } +func TestProcessRecordsSingleBatchBasic(t *testing.T) { + assert := assert.New(t) + + mocklog := logger.New("testing") + mockconfig := withDefaults(Config{ + BatchCount: 2, + CheckpointFreq: 1, // Don't throttle checks points + }) + mockcheckpointer := NewMockCheckpointer(5 * time.Second) + mocksender := NewMsgAsTagSender() + + wrt := NewBatchedWriter(mockconfig, mocksender, mocklog) + wrt.Initialize("test-shard", mockcheckpointer) + + err := wrt.ProcessRecords([]kcl.Record{ + kcl.Record{SequenceNumber: "1", Data: encode("tag1")}, + kcl.Record{SequenceNumber: "2", Data: encode("tag1")}, + kcl.Record{SequenceNumber: "3", Data: encode("tag1")}, + kcl.Record{SequenceNumber: "4", Data: encode("tag1")}, + }) + assert.NoError(err) + err = wrt.ProcessRecords([]kcl.Record{ + kcl.Record{SequenceNumber: "5", Data: encode("tag1")}, + kcl.Record{SequenceNumber: "6", Data: encode("tag1")}, + kcl.Record{SequenceNumber: "7", Data: encode("tag1")}, + kcl.Record{SequenceNumber: "8", Data: encode("tag1")}, + }) + assert.NoError(err) + + err = wrt.Shutdown("TERMINATE") + assert.NoError(err) + + err = mockcheckpointer.wait() + assert.NoError(err) + + mocksender.Shutdown() + + assert.Contains(mocksender.batches, "tag1") + assert.Equal(4, len(mocksender.batches["tag1"])) + + assert.Contains(mockcheckpointer.recievedSequences, "2") + assert.Contains(mockcheckpointer.recievedSequences, "4") + assert.Contains(mockcheckpointer.recievedSequences, "6") + assert.Contains(mockcheckpointer.recievedSequences, "8") + +} + func TestProcessRecordsMutliBatchBasic(t *testing.T) { assert := assert.New(t)