diff --git a/cmd/consumer/main.go b/cmd/consumer/main.go index 4602c13..860a9ca 100644 --- a/cmd/consumer/main.go +++ b/cmd/consumer/main.go @@ -10,6 +10,7 @@ import ( ) type SampleRecordProcessor struct { + checkpointer *kcl.Checkpointer sleepDuration time.Duration checkpointRetries int checkpointFreq time.Duration @@ -26,14 +27,15 @@ func New() *SampleRecordProcessor { } } -func (srp *SampleRecordProcessor) Initialize(shardID string) error { +func (srp *SampleRecordProcessor) Initialize(shardID string, checkpointer *kcl.Checkpointer) error { srp.lastCheckpoint = time.Now() + srp.checkpointer = checkpointer return nil } -func (srp *SampleRecordProcessor) checkpoint(checkpointer kcl.Checkpointer, sequenceNumber *string, subSequenceNumber *int) { - for n := -1; n < srp.checkpointRetries; n++ { - err := checkpointer.Checkpoint(sequenceNumber, subSequenceNumber) +func (srp *SampleRecordProcessor) checkpoint(sequenceNumber *string, subSequenceNumber *int) { + for n := 0; n < srp.checkpointRetries+1; n++ { + err := srp.checkpointer.Checkpoint(sequenceNumber, subSequenceNumber) if err == nil { return } @@ -66,7 +68,7 @@ func (srp *SampleRecordProcessor) shouldUpdateSequence(sequenceNumber *big.Int, (sequenceNumber.Cmp(srp.largestSeq) == 0 && subSequenceNumber > srp.largestSubSeq) } -func (srp *SampleRecordProcessor) ProcessRecords(records []kcl.Record, checkpointer kcl.Checkpointer) error { +func (srp *SampleRecordProcessor) ProcessRecords(records []kcl.Record) error { for _, record := range records { seqNumber := new(big.Int) if _, ok := seqNumber.SetString(record.SequenceNumber, 10); !ok { @@ -80,16 +82,16 @@ func (srp *SampleRecordProcessor) ProcessRecords(records []kcl.Record, checkpoin } if time.Now().Sub(srp.lastCheckpoint) > srp.checkpointFreq { largestSeq := srp.largestSeq.String() - srp.checkpoint(checkpointer, &largestSeq, &srp.largestSubSeq) + srp.checkpoint(&largestSeq, &srp.largestSubSeq) srp.lastCheckpoint = time.Now() } return nil } -func (srp *SampleRecordProcessor) Shutdown(checkpointer kcl.Checkpointer, reason string) error { +func (srp *SampleRecordProcessor) Shutdown(reason string) error { if reason == "TERMINATE" { fmt.Fprintf(os.Stderr, "Was told to terminate, will attempt to checkpoint.\n") - srp.checkpoint(checkpointer, nil, nil) + srp.checkpoint(nil, nil) } else { fmt.Fprintf(os.Stderr, "Shutting down due to failover. Will not checkpoint.\n") } diff --git a/kcl/kcl.go b/kcl/kcl.go index a9f4d9c..1933a68 100644 --- a/kcl/kcl.go +++ b/kcl/kcl.go @@ -9,16 +9,24 @@ import ( ) type RecordProcessor interface { - Initialize(shardID string) error - ProcessRecords(records []Record, checkpointer Checkpointer) error - Shutdown(checkpointer Checkpointer, reason string) error + Initialize(shardID string, checkpointer *Checkpointer) error + ProcessRecords(records []Record) error + Shutdown(reason string) error +} + +type CheckpointError struct { + e string +} + +func (ce CheckpointError) Error() string { + return ce.e } type Checkpointer struct { ioHandler ioHandler } -func (c Checkpointer) getAction() (interface{}, error) { +func (c *Checkpointer) getAction() (interface{}, error) { line, err := c.ioHandler.readLine() if err != nil { return nil, err @@ -30,15 +38,7 @@ func (c Checkpointer) getAction() (interface{}, error) { return action, nil } -type CheckpointError struct { - e string -} - -func (ce CheckpointError) Error() string { - return ce.e -} - -func (c Checkpointer) Checkpoint(sequenceNumber *string, subSequenceNumber *int) error { +func (c *Checkpointer) Checkpoint(sequenceNumber *string, subSequenceNumber *int) error { c.ioHandler.writeAction(ActionCheckpoint{ Action: "checkpoint", SequenceNumber: sequenceNumber, @@ -178,7 +178,7 @@ func New(inputFile io.Reader, outputFile, errorFile io.Writer, recordProcessor R } return &KCLProcess{ ioHandler: i, - checkpointer: Checkpointer{ + checkpointer: &Checkpointer{ ioHandler: i, }, recordProcessor: recordProcessor, @@ -187,7 +187,7 @@ func New(inputFile io.Reader, outputFile, errorFile io.Writer, recordProcessor R type KCLProcess struct { ioHandler ioHandler - checkpointer Checkpointer + checkpointer *Checkpointer recordProcessor RecordProcessor } @@ -204,11 +204,11 @@ func (kclp *KCLProcess) reportDone(responseFor string) error { func (kclp *KCLProcess) performAction(a interface{}) (string, error) { switch action := a.(type) { case ActionInitialize: - return action.Action, kclp.recordProcessor.Initialize(action.ShardID) + return action.Action, kclp.recordProcessor.Initialize(action.ShardID, kclp.checkpointer) case ActionProcessRecords: - return action.Action, kclp.recordProcessor.ProcessRecords(action.Records, kclp.checkpointer) + return action.Action, kclp.recordProcessor.ProcessRecords(action.Records) case ActionShutdown: - return action.Action, kclp.recordProcessor.Shutdown(kclp.checkpointer, action.Reason) + return action.Action, kclp.recordProcessor.Shutdown(action.Reason) default: return "", fmt.Errorf("unknown action to dispatch: %s", action) }