diff --git a/README.md b/README.md index 854a007..f6ba07d 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,14 @@ func main() { ) flag.Parse() + // create new consumer c := connector.NewConsumer(*app, *stream) + + // override default values + c.Set("maxBatchCount", 200) + c.Set("pollInterval", "3s") + + // start consuming records from the queues c.Start(connector.HandlerFunc(func(b connector.Buffer) { fmt.Println(b.GetRecords()) // process the records diff --git a/buffer.go b/buffer.go index 4a0b689..819d863 100644 --- a/buffer.go +++ b/buffer.go @@ -9,7 +9,7 @@ type Buffer struct { firstSequenceNumber string lastSequenceNumber string - MaxBufferSize int + MaxBatchCount int } // AddRecord adds a record to the buffer. @@ -24,7 +24,7 @@ func (b *Buffer) AddRecord(r *kinesis.Record) { // ShouldFlush determines if the buffer has reached its target size. func (b *Buffer) ShouldFlush() bool { - return len(b.records) >= b.MaxBufferSize + return len(b.records) >= b.MaxBatchCount } // Flush empties the buffer and resets the sequence counter. diff --git a/buffer_test.go b/buffer_test.go index d11da71..a6d40d0 100644 --- a/buffer_test.go +++ b/buffer_test.go @@ -34,7 +34,7 @@ func Test_LastSeq(t *testing.T) { } func Test_ShouldFlush(t *testing.T) { - b := Buffer{MaxBufferSize: 2} + b := Buffer{MaxBatchCount: 2} s1, s2 := "1", "2" r1 := &kinesis.Record{SequenceNumber: &s1} r2 := &kinesis.Record{SequenceNumber: &s2} diff --git a/consumer.go b/consumer.go index 08a8b67..8348c26 100644 --- a/consumer.go +++ b/consumer.go @@ -10,8 +10,13 @@ import ( "github.com/aws/aws-sdk-go/service/kinesis" ) -const maxBufferSize = 400 +var ( + pollInterval = 1 * time.Second + maxBatchCount = 1000 +) +// NewConsumer creates a new kinesis connection and returns a +// new consumer initialized with app and stream name func NewConsumer(appName, streamName string) *Consumer { svc := kinesis.New(session.New()) @@ -28,6 +33,25 @@ type Consumer struct { svc *kinesis.Kinesis } +// Set `option` to `value` +func (c *Consumer) Set(option string, value interface{}) { + var err error + + switch option { + case "maxBatchCount": + maxBatchCount = value.(int) + case "pollInterval": + pollInterval, err = time.ParseDuration(value.(string)) + if err != nil { + logger.Log("fatal", "ParseDuration", "msg", "unable to parse pollInterval value") + os.Exit(1) + } + default: + logger.Log("fatal", "Set", "msg", "unknown option") + os.Exit(1) + } +} + func (c *Consumer) Start(handler Handler) { params := &kinesis.DescribeStreamInput{ StreamName: aws.String(c.streamName), @@ -69,8 +93,8 @@ func (c *Consumer) handlerLoop(shardID string, handler Handler) { } } + b := &Buffer{MaxBatchCount: maxBatchCount} shardIterator := resp.ShardIterator - b := &Buffer{MaxBufferSize: maxBufferSize} errCount := 0 for { @@ -110,7 +134,8 @@ func (c *Consumer) handlerLoop(shardID string, handler Handler) { logger.Log("fatal", "nextShardIterator", "msg", err.Error()) os.Exit(1) } else { - time.Sleep(1 * time.Second) + logger.Log("info", "sleeping", "msg", "no records to process") + time.Sleep(pollInterval) } shardIterator = resp.NextShardIterator diff --git a/consumer_test.go b/consumer_test.go new file mode 100644 index 0000000..84ac83c --- /dev/null +++ b/consumer_test.go @@ -0,0 +1,17 @@ +package connector + +import ( + "testing" + + "github.com/bmizerany/assert" +) + +func Test_Set(t *testing.T) { + defaultMaxBatchCount := 1000 + assert.Equal(t, maxBatchCount, defaultMaxBatchCount) + + c := NewConsumer("app", "stream") + c.Set("maxBatchCount", 100) + + assert.Equal(t, maxBatchCount, 100) +} diff --git a/examples/firehose/main.go b/examples/firehose/main.go index 90d06fd..b16cbd9 100644 --- a/examples/firehose/main.go +++ b/examples/firehose/main.go @@ -32,6 +32,8 @@ func main() { svc := firehose.New(session.New()) c := connector.NewConsumer(*app, *stream) + c.Set("maxBatchCount", 400) + c.Set("pollInterval", "3s") c.Start(connector.HandlerFunc(func(b connector.Buffer) { params := &firehose.PutRecordBatchInput{ DeliveryStreamName: aws.String(*delivery),