diff --git a/consumer.go b/consumer.go index 1928161..6e584b6 100644 --- a/consumer.go +++ b/consumer.go @@ -54,30 +54,10 @@ func (c *Consumer) handlerLoop(shardID string, handler Handler) { MaxRecordCount: c.BufferSize, shardID: shardID, } - - params := &kinesis.GetShardIteratorInput{ - ShardId: aws.String(shardID), - StreamName: aws.String(c.StreamName), - } - - if c.Checkpoint.CheckpointExists(shardID) { - params.ShardIteratorType = aws.String("AFTER_SEQUENCE_NUMBER") - params.StartingSequenceNumber = aws.String(c.Checkpoint.SequenceNumber()) - } else { - params.ShardIteratorType = aws.String("TRIM_HORIZON") - } - - resp, err := c.svc.GetShardIterator(params) - if err != nil { - c.Logger.WithError(err).Error("GetShardIterator") - os.Exit(1) - } - - shardIterator := resp.ShardIterator - ctx := c.Logger.WithFields(log.Fields{ "shard": shardID, }) + shardIterator := c.getShardIterator(shardID) ctx.Info("processing") @@ -89,7 +69,9 @@ func (c *Consumer) handlerLoop(shardID string, handler Handler) { ) if err != nil { - log.Fatalf("Error GetRecords %v", err) + ctx.WithError(err).Error("GetRecords") + shardIterator = c.getShardIterator(shardID) + continue } if len(resp.Records) > 0 { @@ -103,11 +85,57 @@ func (c *Consumer) handlerLoop(shardID string, handler Handler) { buf.Flush() } } - } else if resp.NextShardIterator == aws.String("") || shardIterator == resp.NextShardIterator { - c.Logger.Error("NextShardIterator") - os.Exit(1) } - shardIterator = resp.NextShardIterator + if resp.NextShardIterator == nil || shardIterator == resp.NextShardIterator { + shardIterator = c.getShardIterator(shardID) + } else { + shardIterator = resp.NextShardIterator + } } } + +func (c *Consumer) getShardIterator(shardID string) *string { + params := &kinesis.GetShardIteratorInput{ + ShardId: aws.String(shardID), + StreamName: aws.String(c.StreamName), + } + + if c.Checkpoint.CheckpointExists(shardID) { + params.ShardIteratorType = aws.String(string(ShardIteratorAfterSequenceNumber)) + params.StartingSequenceNumber = aws.String(c.Checkpoint.SequenceNumber()) + } else { + params.ShardIteratorType = aws.String(string(c.ShardIteratorType)) + } + + resp, err := c.svc.GetShardIterator(params) + + if err != nil { + c.Logger.WithError(err).Error("GetShardIterator") + os.Exit(1) + } + + return resp.ShardIterator +} + +func (c *Consumer) getShardIterator(shardID string) *string { + params := &kinesis.GetShardIteratorInput{ + ShardId: aws.String(shardID), + StreamName: aws.String(c.StreamName), + } + + if c.Checkpoint.CheckpointExists(shardID) { + params.ShardIteratorType = aws.String(string(ShardIteratorAfterSequenceNumber)) + params.StartingSequenceNumber = aws.String(c.Checkpoint.SequenceNumber()) + } else { + params.ShardIteratorType = aws.String(string(c.ShardIteratorType)) + } + + resp, err := c.svc.GetShardIterator(params) + if err != nil { + c.Logger.WithError(err).Error("GetShardIterator") + os.Exit(1) + } + + return resp.ShardIterator +}