diff --git a/.gitignore b/.gitignore index fc5ded3..d379f71 100644 --- a/.gitignore +++ b/.gitignore @@ -46,3 +46,6 @@ prof.mem # VSCode files /.vscode /**/debug + +# Goland files +.idea/ \ No newline at end of file diff --git a/client.go b/client.go index 40e840e..5d9e67a 100644 --- a/client.go +++ b/client.go @@ -20,7 +20,7 @@ func WithKinesis(svc kinesisiface.KinesisAPI) ClientOption { } } -// WithStartFrmLatest will make sure the client start consuming +// WithStartFromLatest will make sure the client start consuming // events starting from the most recent event in kinesis. This // option discards the checkpoints. func WithStartFromLatest() ClientOption { @@ -30,18 +30,21 @@ func WithStartFromLatest() ClientOption { } // NewKinesisClient returns client to interface with Kinesis stream -func NewKinesisClient(opts ...ClientOption) *KinesisClient { +func NewKinesisClient(opts ...ClientOption) (*KinesisClient, error) { kc := &KinesisClient{} for _, opt := range opts { opt(kc) } - + newSession, err := session.NewSession(aws.NewConfig()) + if err != nil { + return nil, err + } if kc.svc == nil { - kc.svc = kinesis.New(session.New(aws.NewConfig())) + kc.svc = kinesis.New(newSession) } - return kc + return kc, nil } // KinesisClient acts as wrapper around Kinesis client @@ -61,7 +64,7 @@ func (c *KinesisClient) GetShardIDs(streamName string) ([]string, error) { return nil, fmt.Errorf("describe stream error: %v", err) } - ss := []string{} + var ss []string for _, shard := range resp.StreamDescription.Shards { ss = append(ss, *shard.ShardId) } diff --git a/consumer.go b/consumer.go index 75e685a..2172027 100644 --- a/consumer.go +++ b/consumer.go @@ -87,13 +87,18 @@ func New(streamName string, opts ...Option) (*Consumer, error) { return nil, fmt.Errorf("must provide stream name") } + kc, err := NewKinesisClient() + if err != nil { + return nil, err + } + // new consumer with no-op checkpoint, counter, and logger c := &Consumer{ streamName: streamName, checkpoint: &noopCheckpoint{}, counter: &noopCounter{}, logger: NewDefaultLogger(), - client: NewKinesisClient(), + client: kc, } // override defaults @@ -178,8 +183,6 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn func(*Recor // loop records for r := range recc { scanError := fn(r) - // It will be nicer if this can be reported with checkpoint error - err = scanError.Error // Skip invalid state if scanError.StopScan && scanError.SkipCheckpoint { diff --git a/consumer_test.go b/consumer_test.go index 25e8e53..aaddf5c 100644 --- a/consumer_test.go +++ b/consumer_test.go @@ -3,11 +3,11 @@ package consumer import ( "context" "fmt" - "io/ioutil" - "log" "sync" "testing" + "errors" + "github.com/aws/aws-sdk-go/aws" ) @@ -39,7 +39,7 @@ func TestScanShard(t *testing.T) { client: client, checkpoint: ckp, counter: ctr, - logger: log.New(ioutil.Discard, "", log.LstdFlags), + logger: NewDefaultLogger(), } // callback fn simply appends the record data to result string