diff --git a/.gitignore b/.gitignore index f2a12a4..4df48e8 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ +vendor + # osx / sshfs ._* .DS_Store diff --git a/cmd/consumer/main.go b/cmd/consumer/main.go index 4602c13..22f8ccc 100644 --- a/cmd/consumer/main.go +++ b/cmd/consumer/main.go @@ -10,7 +10,7 @@ import ( ) type SampleRecordProcessor struct { - sleepDuration time.Duration + checkpointer *kcl.Checkpointer checkpointRetries int checkpointFreq time.Duration largestSeq *big.Int @@ -20,53 +20,23 @@ type SampleRecordProcessor struct { func New() *SampleRecordProcessor { return &SampleRecordProcessor{ - sleepDuration: 5 * time.Second, checkpointRetries: 5, checkpointFreq: 60 * time.Second, } } -func (srp *SampleRecordProcessor) Initialize(shardID string) error { +func (srp *SampleRecordProcessor) Initialize(shardID string, checkpointer *kcl.Checkpointer) error { srp.lastCheckpoint = time.Now() + srp.checkpointer = checkpointer return nil } -func (srp *SampleRecordProcessor) checkpoint(checkpointer kcl.Checkpointer, sequenceNumber *string, subSequenceNumber *int) { - for n := -1; n < srp.checkpointRetries; n++ { - err := checkpointer.Checkpoint(sequenceNumber, subSequenceNumber) - if err == nil { - return - } - - if cperr, ok := err.(kcl.CheckpointError); ok { - switch cperr.Error() { - case "ShutdownException": - fmt.Fprintf(os.Stderr, "Encountered shutdown exception, skipping checkpoint\n") - return - case "ThrottlingException": - fmt.Fprintf(os.Stderr, "Was throttled while checkpointing, will attempt again in %s", srp.sleepDuration) - case "InvalidStateException": - fmt.Fprintf(os.Stderr, "MultiLangDaemon reported an invalid state while checkpointing\n") - default: - fmt.Fprintf(os.Stderr, "Encountered an error while checkpointing: %s", err) - } - } - - if n == srp.checkpointRetries { - fmt.Fprintf(os.Stderr, "Failed to checkpoint after %d attempts, giving up.\n", srp.checkpointRetries) - return - } - - time.Sleep(srp.sleepDuration) - } -} - func (srp 
*SampleRecordProcessor) shouldUpdateSequence(sequenceNumber *big.Int, subSequenceNumber int) bool { return srp.largestSeq == nil || sequenceNumber.Cmp(srp.largestSeq) == 1 || (sequenceNumber.Cmp(srp.largestSeq) == 0 && subSequenceNumber > srp.largestSubSeq) } -func (srp *SampleRecordProcessor) ProcessRecords(records []kcl.Record, checkpointer kcl.Checkpointer) error { +func (srp *SampleRecordProcessor) ProcessRecords(records []kcl.Record) error { for _, record := range records { seqNumber := new(big.Int) if _, ok := seqNumber.SetString(record.SequenceNumber, 10); !ok { @@ -80,16 +50,16 @@ func (srp *SampleRecordProcessor) ProcessRecords(records []kcl.Record, checkpoin } if time.Now().Sub(srp.lastCheckpoint) > srp.checkpointFreq { largestSeq := srp.largestSeq.String() - srp.checkpoint(checkpointer, &largestSeq, &srp.largestSubSeq) + srp.checkpointer.CheckpointWithRetry(&largestSeq, &srp.largestSubSeq, srp.checkpointRetries) srp.lastCheckpoint = time.Now() } return nil } -func (srp *SampleRecordProcessor) Shutdown(checkpointer kcl.Checkpointer, reason string) error { +func (srp *SampleRecordProcessor) Shutdown(reason string) error { if reason == "TERMINATE" { fmt.Fprintf(os.Stderr, "Was told to terminate, will attempt to checkpoint.\n") - srp.checkpoint(checkpointer, nil, nil) + srp.checkpointer.Shutdown() } else { fmt.Fprintf(os.Stderr, "Shutting down due to failover. 
Will not checkpoint.\n") } diff --git a/kcl/kcl.go b/kcl/kcl.go index a9f4d9c..1ccd333 100644 --- a/kcl/kcl.go +++ b/kcl/kcl.go @@ -6,19 +6,32 @@ import ( "encoding/json" "fmt" "io" + "os" + "sync" + "time" ) type RecordProcessor interface { - Initialize(shardID string) error - ProcessRecords(records []Record, checkpointer Checkpointer) error - Shutdown(checkpointer Checkpointer, reason string) error + Initialize(shardID string, checkpointer *Checkpointer) error + ProcessRecords(records []Record) error + Shutdown(reason string) error +} + +type CheckpointError struct { + e string +} + +func (ce CheckpointError) Error() string { + return ce.e } type Checkpointer struct { + mux sync.Mutex + ioHandler ioHandler } -func (c Checkpointer) getAction() (interface{}, error) { +func (c *Checkpointer) getAction() (interface{}, error) { line, err := c.ioHandler.readLine() if err != nil { return nil, err @@ -30,15 +43,10 @@ func (c Checkpointer) getAction() (interface{}, error) { return action, nil } -type CheckpointError struct { - e string -} +func (c *Checkpointer) Checkpoint(sequenceNumber *string, subSequenceNumber *int) error { + c.mux.Lock() + defer c.mux.Unlock() -func (ce CheckpointError) Error() string { - return ce.e -} - -func (c Checkpointer) Checkpoint(sequenceNumber *string, subSequenceNumber *int) error { c.ioHandler.writeAction(ActionCheckpoint{ Action: "checkpoint", SequenceNumber: sequenceNumber, @@ -62,7 +70,46 @@ func (c Checkpointer) Checkpoint(sequenceNumber *string, subSequenceNumber *int) } } return nil +} +// CheckpointWithRetry tries to save a checkPoint up to `retryCount` + 1 times. 
+// `retryCount` should be >= 0; a negative value makes no attempt and returns nil. +func (c *Checkpointer) CheckpointWithRetry( + sequenceNumber *string, subSequenceNumber *int, retryCount int, +) error { + sleepDuration := 5 * time.Second + + for n := 0; n <= retryCount; n++ { + err := c.Checkpoint(sequenceNumber, subSequenceNumber) + if err == nil { + return nil + } + + if cperr, ok := err.(CheckpointError); ok { + switch cperr.Error() { + case "ShutdownException": + return fmt.Errorf("Encountered shutdown exception, skipping checkpoint") + case "ThrottlingException": + fmt.Fprintf(os.Stderr, "Was throttled while checkpointing, will attempt again in %s\n", sleepDuration) + case "InvalidStateException": + fmt.Fprintf(os.Stderr, "MultiLangDaemon reported an invalid state while checkpointing\n") + default: + fmt.Fprintf(os.Stderr, "Encountered an error while checkpointing: %s\n", err) + } + } + + if n == retryCount { + return fmt.Errorf("Failed to checkpoint after %d attempts, giving up: %v", retryCount+1, err) + } + + time.Sleep(sleepDuration) + } + + return nil +} + +func (c *Checkpointer) Shutdown() { + c.CheckpointWithRetry(nil, nil, 5) } type ioHandler struct { @@ -178,7 +225,7 @@ func New(inputFile io.Reader, outputFile, errorFile io.Writer, recordProcessor R } return &KCLProcess{ ioHandler: i, - checkpointer: Checkpointer{ + checkpointer: &Checkpointer{ ioHandler: i, }, recordProcessor: recordProcessor, @@ -187,7 +234,7 @@ func New(inputFile io.Reader, outputFile, errorFile io.Writer, recordProcessor R type KCLProcess struct { ioHandler ioHandler - checkpointer Checkpointer + checkpointer *Checkpointer recordProcessor RecordProcessor } @@ -204,11 +251,11 @@ func (kclp *KCLProcess) reportDone(responseFor string) error { func (kclp *KCLProcess) performAction(a interface{}) (string, error) { switch action := a.(type) { case ActionInitialize: - return action.Action, kclp.recordProcessor.Initialize(action.ShardID) + return action.Action, kclp.recordProcessor.Initialize(action.ShardID, kclp.checkpointer) case 
ActionProcessRecords: - return action.Action, kclp.recordProcessor.ProcessRecords(action.Records, kclp.checkpointer) + return action.Action, kclp.recordProcessor.ProcessRecords(action.Records) case ActionShutdown: - return action.Action, kclp.recordProcessor.Shutdown(kclp.checkpointer, action.Reason) + return action.Action, kclp.recordProcessor.Shutdown(action.Reason) default: return "", fmt.Errorf("unknown action to dispatch: %s", action) }