From ba951ff0dab7407a8abed2a9fdb15dbe105b1727 Mon Sep 17 00:00:00 2001 From: Xavi Ramirez Date: Thu, 3 Aug 2017 21:22:52 +0000 Subject: [PATCH] Refactor to fix dead locks and race conditions. --- batchconsumer/writer.go | 66 ++++++--- batchconsumer/writer_test.go | 35 ++--- cmd/consumer/main.go | 17 +-- kcl/kcl.go | 256 +++++++++++++++++++---------------- 4 files changed, 211 insertions(+), 163 deletions(-) diff --git a/batchconsumer/writer.go b/batchconsumer/writer.go index da07f65..bb8e2e6 100644 --- a/batchconsumer/writer.go +++ b/batchconsumer/writer.go @@ -32,7 +32,7 @@ type batchedWriter struct { checkpointMsg chan kcl.SequencePair checkpointShutdown chan struct{} checkpointTag chan string - lastProcessedPair chan kcl.SequencePair + lastIgnoredPair chan kcl.SequencePair batchMsg chan tagMsgPair shutdown chan struct{} @@ -58,11 +58,11 @@ func (b *batchedWriter) Initialize(shardID string, checkpointer kcl.Checkpointer b.checkpointShutdown = make(chan struct{}) b.startCheckpointListener(checkpointer, b.checkpointMsg, b.checkpointShutdown) - b.checkpointTag = make(chan string) + b.checkpointTag = make(chan string, 100) // Buffered to workaround b.batchMsg = make(chan tagMsgPair) b.shutdown = make(chan struct{}) - b.lastProcessedPair = make(chan kcl.SequencePair) - b.startMessageHandler(b.batchMsg, b.checkpointTag, b.lastProcessedPair, b.shutdown) + b.lastIgnoredPair = make(chan kcl.SequencePair) + b.startMessageHandler(b.batchMsg, b.checkpointTag, b.lastIgnoredPair, b.shutdown) return nil } @@ -128,32 +128,48 @@ func (b *batchedWriter) createBatcher(tag string) batcher.Batcher { // startMessageDistributer starts a go-routine that routes messages to batches. It's in uses a // go routine to avoid racey conditions. func (b *batchedWriter) startMessageHandler( - batchMsg <-chan tagMsgPair, checkpointTag <-chan string, lastPair <-chan kcl.SequencePair, + batchMsg <-chan tagMsgPair, checkpointTag <-chan string, lastIgnored <-chan kcl.SequencePair, shutdown <-chan struct{}, ) { - go func() { - var lastProcessedPair kcl.SequencePair - batchers := map[string]batcher.Batcher{} - areBatchersEmpty := true + getBatcher := make(chan string) + rtnBatcher := make(chan batcher.Batcher) + shutdownAdder := make(chan struct{}) + go func() { for { select { case tmp := <-batchMsg: - batcher, ok := batchers[tmp.tag] - if !ok { - batcher = b.createBatcher(tmp.tag) - batchers[tmp.tag] = batcher - } - + getBatcher <- tmp.tag + batcher := <-rtnBatcher err := batcher.AddMessage(tmp.msg, tmp.pair) if err != nil { b.log.ErrorD("add-message", kv.M{ "err": err.Error(), "msg": string(tmp.msg), "tag": tmp.tag, }) } + case <-shutdownAdder: + } + } + }() + + go func() { + var lastIgnoredPair kcl.SequencePair + batchers := map[string]batcher.Batcher{} + areBatchersEmpty := true + + for { + select { + case tag := <-getBatcher: + batcher, ok := batchers[tag] + if !ok { + batcher = b.createBatcher(tag) + batchers[tag] = batcher + } + areBatchersEmpty = false + rtnBatcher <- batcher case tag := <-checkpointTag: - smallest := lastProcessedPair + smallest := lastIgnoredPair isAllEmpty := true for name, batch := range batchers { @@ -166,7 +182,8 @@ func (b *batchedWriter) startMessageHandler( continue } - if pair.IsLessThan(smallest) { + // Check for empty because it's possible that no messages have been ignored + if smallest.IsEmpty() || pair.IsLessThan(smallest) { smallest = pair } @@ -177,17 +194,18 @@ func (b *batchedWriter) startMessageHandler( b.checkpointMsg <- smallest } areBatchersEmpty = isAllEmpty - case pair := <-lastPair: - if areBatchersEmpty { + case pair := <-lastIgnored: + if areBatchersEmpty && !pair.IsEmpty() { b.checkpointMsg <- pair } - lastProcessedPair = pair + lastIgnoredPair = pair case <-shutdown: for _, batch := range batchers { batch.Flush() } - b.checkpointMsg <- lastProcessedPair + b.checkpointMsg <- b.lastProcessedSeq b.checkpointShutdown <- struct{}{} + areBatchersEmpty = true } } @@ -234,6 +252,7 @@ func (b *batchedWriter) ProcessRecords(records []kcl.Record) error { if err != nil { return err } + wasPairIgnored := true for _, rawmsg := range messages { msg, tags, err := b.sender.ProcessMessage(rawmsg) @@ -260,11 +279,14 @@ func (b *batchedWriter) ProcessRecords(records []kcl.Record) error { // sequence number amount all the batch (let's call it A). We then checkpoint at // the A-1 sequence number. b.batchMsg <- tagMsgPair{tag, msg, prevPair} + wasPairIgnored = false } } prevPair = pair - b.lastProcessedPair <- pair + if wasPairIgnored { + b.lastIgnoredPair <- pair + } } b.lastProcessedSeq = pair diff --git a/batchconsumer/writer_test.go b/batchconsumer/writer_test.go index 6622937..d87d1ce 100644 --- a/batchconsumer/writer_test.go +++ b/batchconsumer/writer_test.go @@ -86,30 +86,28 @@ type mockCheckpointer struct { shutdown chan struct{} } -func NewMockCheckpointer(maxSeq string, timeout time.Duration) *mockCheckpointer { +func NewMockCheckpointer(timeout time.Duration) *mockCheckpointer { mcp := &mockCheckpointer{ checkpoint: make(chan string), done: make(chan struct{}, 1), timeout: make(chan struct{}, 1), shutdown: make(chan struct{}), } - mcp.startWaiter(maxSeq, timeout) + mcp.startWaiter(timeout) return mcp } -func (m *mockCheckpointer) startWaiter(maxSeq string, timeout time.Duration) { +func (m *mockCheckpointer) startWaiter(timeout time.Duration) { go func() { for { select { case seq := <-m.checkpoint: m.recievedSequences = append(m.recievedSequences, seq) - if seq == maxSeq { - m.done <- struct{}{} - } case <-time.NewTimer(timeout).C: m.timeout <- struct{}{} case <-m.shutdown: + m.done <- struct{}{} return } } @@ -126,15 +124,10 @@ func (m *mockCheckpointer) wait() error { func (m *mockCheckpointer) Shutdown() { m.shutdown <- struct{}{} } -func (m *mockCheckpointer) Checkpoint(sequenceNumber *string, subSequenceNumber *int) error { - m.checkpoint <- *sequenceNumber +func (m *mockCheckpointer) Checkpoint(pair kcl.SequencePair, retry int) error { + m.checkpoint <- pair.Sequence.String() return nil } -func (m *mockCheckpointer) CheckpointWithRetry( - sequenceNumber *string, subSequenceNumber *int, retryCount int, -) error { - return m.Checkpoint(sequenceNumber, subSequenceNumber) -} func encode(str string) string { return base64.StdEncoding.EncodeToString([]byte(str)) @@ -148,7 +141,7 @@ func TestProcessRecordsIgnoredMessages(t *testing.T) { BatchInterval: 10 * time.Millisecond, CheckpointFreq: 20 * time.Millisecond, }) - mockcheckpointer := NewMockCheckpointer("4", 5*time.Second) + mockcheckpointer := NewMockCheckpointer(5 * time.Second) wrt := NewBatchedWriter(mockconfig, ignoringSender{}, mocklog) wrt.Initialize("test-shard", mockcheckpointer) @@ -161,8 +154,13 @@ func TestProcessRecordsIgnoredMessages(t *testing.T) { }) assert.NoError(err) + err = wrt.Shutdown("TERMINATE") + assert.NoError(err) + err = mockcheckpointer.wait() assert.NoError(err) + + assert.Contains(mockcheckpointer.recievedSequences, "4") } func TestProcessRecordsMutliBatchBasic(t *testing.T) { @@ -173,7 +171,7 @@ func TestProcessRecordsMutliBatchBasic(t *testing.T) { BatchInterval: 100 * time.Millisecond, CheckpointFreq: 200 * time.Millisecond, }) - mockcheckpointer := NewMockCheckpointer("8", 5*time.Second) + mockcheckpointer := NewMockCheckpointer(5 * time.Second) mocksender := NewMsgAsTagSender() wrt := NewBatchedWriter(mockconfig, mocksender, mocklog) @@ -233,7 +231,7 @@ func TestProcessRecordsMutliBatchWithIgnores(t *testing.T) { BatchInterval: 100 * time.Millisecond, CheckpointFreq: 200 * time.Millisecond, }) - mockcheckpointer := NewMockCheckpointer("26", 5*time.Second) + mockcheckpointer := NewMockCheckpointer(5 * time.Second) mocksender := NewMsgAsTagSender() wrt := NewBatchedWriter(mockconfig, mocksender, mocklog) @@ -312,7 +310,7 @@ func TestStaggeredCheckpionting(t *testing.T) { BatchInterval: 100 * time.Millisecond, CheckpointFreq: 200 * time.Nanosecond, }) - mockcheckpointer := NewMockCheckpointer("9", 5*time.Second) + mockcheckpointer := NewMockCheckpointer(5 * time.Second) mocksender := NewMsgAsTagSender() wrt := NewBatchedWriter(mockconfig, mocksender, mocklog) @@ -352,6 +350,7 @@ func TestStaggeredCheckpionting(t *testing.T) { assert.NotContains(mockcheckpointer.recievedSequences, "6") assert.NotContains(mockcheckpointer.recievedSequences, "7") assert.NotContains(mockcheckpointer.recievedSequences, "8") + assert.Contains(mockcheckpointer.recievedSequences, "9") assert.Contains(mocksender.batches, "tag1") assert.Equal(2, len(mocksender.batches["tag1"])) // One batch @@ -365,8 +364,10 @@ func TestStaggeredCheckpionting(t *testing.T) { assert.Equal(2, len(mocksender.batches["tag3"][0])) // with three items assert.Equal("tag3", string(mocksender.batches["tag3"][0][0])) assert.Equal("tag3", string(mocksender.batches["tag3"][0][1])) + assert.Equal(2, len(mocksender.batches["tag3"][1])) assert.Equal("tag3", string(mocksender.batches["tag3"][1][0])) assert.Equal("tag3", string(mocksender.batches["tag3"][1][1])) + assert.Equal(2, len(mocksender.batches["tag3"][2])) assert.Equal("tag3", string(mocksender.batches["tag3"][2][0])) assert.Equal("tag3", string(mocksender.batches["tag3"][2][1])) } diff --git a/cmd/consumer/main.go b/cmd/consumer/main.go index 02281a9..8e51ee9 100644 --- a/cmd/consumer/main.go +++ b/cmd/consumer/main.go @@ -13,8 +13,7 @@ type sampleRecordProcessor struct { checkpointer kcl.Checkpointer checkpointRetries int checkpointFreq time.Duration - largestSeq *big.Int - largestSubSeq int + largestPair kcl.SequencePair lastCheckpoint time.Time } @@ -31,9 +30,8 @@ func (srp *sampleRecordProcessor) Initialize(shardID string, checkpointer kcl.Ch return nil } -func (srp *sampleRecordProcessor) shouldUpdateSequence(sequenceNumber *big.Int, subSequenceNumber int) bool { - return srp.largestSeq == nil || sequenceNumber.Cmp(srp.largestSeq) == 1 || - (sequenceNumber.Cmp(srp.largestSeq) == 0 && subSequenceNumber > srp.largestSubSeq) +func (srp *sampleRecordProcessor) shouldUpdateSequence(pair kcl.SequencePair) bool { + return srp.largestPair.IsLessThan(pair) } func (srp *sampleRecordProcessor) ProcessRecords(records []kcl.Record) error { @@ -43,14 +41,13 @@ func (srp *sampleRecordProcessor) ProcessRecords(records []kcl.Record) error { fmt.Fprintf(os.Stderr, "could not parse sequence number '%s'\n", record.SequenceNumber) continue } - if srp.shouldUpdateSequence(seqNumber, record.SubSequenceNumber) { - srp.largestSeq = seqNumber - srp.largestSubSeq = record.SubSequenceNumber + pair := kcl.SequencePair{seqNumber, record.SubSequenceNumber} + if srp.shouldUpdateSequence(pair) { + srp.largestPair = pair } } if time.Now().Sub(srp.lastCheckpoint) > srp.checkpointFreq { - largestSeq := srp.largestSeq.String() - srp.checkpointer.CheckpointWithRetry(&largestSeq, &srp.largestSubSeq, srp.checkpointRetries) + srp.checkpointer.Checkpoint(srp.largestPair, srp.checkpointRetries) srp.lastCheckpoint = time.Now() } return nil diff --git a/kcl/kcl.go b/kcl/kcl.go index 950a92c..76f6c13 100644 --- a/kcl/kcl.go +++ b/kcl/kcl.go @@ -2,12 +2,10 @@ package kcl import ( "bufio" - "bytes" "encoding/json" "fmt" "io" "os" - "sync" "time" ) @@ -18,8 +16,7 @@ type RecordProcessor interface { } type Checkpointer interface { - Checkpoint(sequenceNumber *string, subSequenceNumber *int) error - CheckpointWithRetry(sequenceNumber *string, subSequenceNumber *int, retryCount int) error + Checkpoint(pair SequencePair, retryCount int) error Shutdown() } @@ -31,93 +28,6 @@ func (ce CheckpointError) Error() string { return ce.e } -type checkpointer struct { - mux sync.Mutex - - ioHandler ioHandler -} - -func (c *checkpointer) getAction() (interface{}, error) { - line, err := c.ioHandler.readLine() - if err != nil { - return nil, err - } - action, err := c.ioHandler.loadAction(line.String()) - if err != nil { - return nil, err - } - return action, nil -} - -func (c *checkpointer) Checkpoint(sequenceNumber *string, subSequenceNumber *int) error { - c.mux.Lock() - defer c.mux.Unlock() - - c.ioHandler.writeAction(ActionCheckpoint{ - Action: "checkpoint", - SequenceNumber: sequenceNumber, - SubSequenceNumber: subSequenceNumber, - }) - line, err := c.ioHandler.readLine() - if err != nil { - return err - } - actionI, err := c.ioHandler.loadAction(line.String()) - if err != nil { - return err - } - action, ok := actionI.(ActionCheckpoint) - if !ok { - return fmt.Errorf("expected checkpoint response, got '%s'", line.String()) - } - if action.Error != nil && *action.Error != "" { - return CheckpointError{ - e: *action.Error, - } - } - return nil -} - -// CheckpointWithRetry tries to save a checkPoint up to `retryCount` + 1 times. -// `retryCount` should be >= 0 -func (c *checkpointer) CheckpointWithRetry( - sequenceNumber *string, subSequenceNumber *int, retryCount int, -) error { - sleepDuration := 5 * time.Second - - for n := 0; n <= retryCount; n++ { - err := c.Checkpoint(sequenceNumber, subSequenceNumber) - if err == nil { - return nil - } - - if cperr, ok := err.(CheckpointError); ok { - switch cperr.Error() { - case "ShutdownException": - return fmt.Errorf("Encountered shutdown exception, skipping checkpoint") - case "ThrottlingException": - fmt.Fprintf(os.Stderr, "Was throttled while checkpointing, will attempt again in %s\n", sleepDuration) - case "InvalidStateException": - fmt.Fprintf(os.Stderr, "MultiLangDaemon reported an invalid state while checkpointing\n") - default: - fmt.Fprintf(os.Stderr, "Encountered an error while checkpointing: %s", err) - } - } - - if n == retryCount { - return fmt.Errorf("Failed to checkpoint after %d attempts, giving up.", retryCount) - } - - time.Sleep(sleepDuration) - } - - return nil -} - -func (c *checkpointer) Shutdown() { - c.CheckpointWithRetry(nil, nil, 5) -} - type ioHandler struct { inputFile io.Reader outputFile io.Writer @@ -134,13 +44,13 @@ func (i ioHandler) writeError(message string) { fmt.Fprintf(i.errorFile, "%s\n", message) } -func (i ioHandler) readLine() (*bytes.Buffer, error) { +func (i ioHandler) readLine() (string, error) { bio := bufio.NewReader(i.inputFile) line, err := bio.ReadString('\n') if err != nil { - return nil, err + return "", err } - return bytes.NewBufferString(line), nil + return line, nil } type ActionInitialize struct { @@ -197,6 +107,8 @@ func (i ioHandler) loadAction(line string) (interface{}, error) { return nil, err } return actionProcessRecords, nil + case "shutdownRequested": + fallthrough case "shutdown": var actionShutdown ActionShutdown if err := json.Unmarshal(lineBytes, &actionShutdown); err != nil { @@ -223,25 +135,37 @@ func (i ioHandler) writeAction(action interface{}) error { return nil } -func New(inputFile io.Reader, outputFile, errorFile io.Writer, recordProcessor RecordProcessor) *KCLProcess { +func New( + inputFile io.Reader, outputFile, errorFile io.Writer, recordProcessor RecordProcessor, +) *KCLProcess { i := ioHandler{ inputFile: inputFile, outputFile: outputFile, errorFile: errorFile, } return &KCLProcess{ - ioHandler: i, - checkpointer: &checkpointer{ - ioHandler: i, - }, + ioHandler: i, recordProcessor: recordProcessor, + + next: make(chan struct{}), + out: make(chan string), + outErr: make(chan error), + + checkpoint: make(chan SequencePair), + checkpointErr: make(chan error), } } type KCLProcess struct { ioHandler ioHandler - checkpointer Checkpointer recordProcessor RecordProcessor + + next chan struct{} + out chan string + outErr chan error + + checkpoint chan SequencePair + checkpointErr chan error } func (kclp *KCLProcess) reportDone(responseFor string) error { @@ -257,13 +181,13 @@ func (kclp *KCLProcess) reportDone(responseFor string) error { func (kclp *KCLProcess) performAction(a interface{}) (string, error) { switch action := a.(type) { case ActionInitialize: - return action.Action, kclp.recordProcessor.Initialize(action.ShardID, kclp.checkpointer) + return action.Action, kclp.recordProcessor.Initialize(action.ShardID, kclp) case ActionProcessRecords: return action.Action, kclp.recordProcessor.ProcessRecords(action.Records) case ActionShutdown: return action.Action, kclp.recordProcessor.Shutdown(action.Reason) default: - return "", fmt.Errorf("unknown action to dispatch: %s", action) + return "", fmt.Errorf("unknown action to dispatch: %+#v", action) } } @@ -280,20 +204,124 @@ func (kclp *KCLProcess) handleLine(line string) error { return kclp.reportDone(responseFor) } -func (kclp *KCLProcess) Run() { - for { - line, err := kclp.ioHandler.readLine() - if err != nil { - kclp.ioHandler.writeError("Read line error: " + err.Error()) - return - } else if line == nil { - kclp.ioHandler.writeError("Empty read line recieved") - return +func (kclp *KCLProcess) Checkpoint(pair SequencePair, retryCount int) error { + sleepDuration := 5 * time.Second + + for n := 0; n <= retryCount; n++ { + kclp.checkpoint <- pair + err := <-kclp.checkpointErr + if err == nil { + return nil } - err = kclp.handleLine(line.String()) - if err != nil { - kclp.ioHandler.writeError("Handle line error: " + err.Error()) + if cperr, ok := err.(CheckpointError); ok { + switch cperr.Error() { + case "ShutdownException": + return fmt.Errorf("Encountered shutdown exception, skipping checkpoint") + case "ThrottlingException": + fmt.Fprintf(os.Stderr, "Checkpointing throttling, pause for %s\n", sleepDuration) + case "InvalidStateException": + fmt.Fprintf(os.Stderr, "MultiLangDaemon invalid state while checkpointing\n") + default: + fmt.Fprintf(os.Stderr, "Encountered an error while checkpointing: %s", err) + } + } + + if n == retryCount { + return fmt.Errorf("Failed to checkpoint after %d attempts, giving up.", retryCount) + } + + time.Sleep(sleepDuration) + } + + return nil +} + +func (kclp *KCLProcess) Shutdown() { + kclp.Checkpoint(SequencePair{}, 5) +} + +func (kclp *KCLProcess) processCheckpoint(pair SequencePair) error { + var seq *string + var subSeq *int + if !pair.IsEmpty() { // an empty pair is a signal to shutdown + tmp := pair.Sequence.String() + seq = &tmp + subSeq = &pair.SubSequence + } + kclp.ioHandler.writeAction(ActionCheckpoint{ + Action: "checkpoint", + SequenceNumber: seq, + SubSequenceNumber: subSeq, + }) + line, err := kclp.ioHandler.readLine() + if err != nil { + return err + } + actionI, err := kclp.ioHandler.loadAction(line) + if err != nil { + return err + } + action, ok := actionI.(ActionCheckpoint) + if !ok { + return fmt.Errorf("expected checkpoint response, got '%s'", line) + } + if action.Error != nil && *action.Error != "" { + return CheckpointError{e: *action.Error} + } + return nil +} + +func (kclp *KCLProcess) startLineProcessor( + next chan struct{}, out chan string, outErr chan error, + checkpoint chan SequencePair, checkpointErr chan error, +) { + go func() { + for { + select { + case <-next: + line, err := kclp.ioHandler.readLine() + if err != nil { + outErr <- err + } else { + out <- line + } + case pair := <-checkpoint: + err := kclp.processCheckpoint(pair) + checkpointErr <- err + } + } + }() +} + +func (kclp *KCLProcess) processNextLine() error { + kclp.next <- struct{}{} // We're ready for a new line + + var err error + var line string + + select { + case err = <-kclp.outErr: + case line = <-kclp.out: + if line == "" { + err = fmt.Errorf("Empty read line recieved") + } else { + err = kclp.handleLine(line) + } + } + + return err +} + +func (kclp *KCLProcess) Run() { + kclp.startLineProcessor(kclp.next, kclp.out, kclp.outErr, kclp.checkpoint, kclp.checkpointErr) + for { + err := kclp.processNextLine() + if err == io.EOF { + kclp.ioHandler.writeError("IO stream closed") + return + } else if err != nil { + kclp.ioHandler.writeError(fmt.Sprintf("ERR Handle line: %+#v", err)) return } }