Adjust the scan shard tests

This commit is contained in:
Harlow Ward 2019-04-08 11:16:09 -07:00
parent 7e72723168
commit dbab92f317
3 changed files with 30 additions and 54 deletions

View file

@ -33,11 +33,10 @@ type broker struct {
shards map[string]*kinesis.Shard
}
func (b *broker) shardLoop(ctx context.Context) {
func (b *broker) pollShards(ctx context.Context) {
b.fetchShards()
// add ticker, and cancellation
// also add signal to re-pull?
// TODO: also add signal to re-poll
go func() {
for {

View file

@ -88,7 +88,7 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
defer cancel()
go func() {
broker.shardLoop(ctx)
broker.pollShards(ctx)
<-ctx.Done()
close(shardc)

View file

@ -96,29 +96,6 @@ func TestScan(t *testing.T) {
}
}
func TestScan_NoShardsAvailable(t *testing.T) {
client := &kinesisClientMock{
listShardsMock: func(input *kinesis.ListShardsInput) (*kinesis.ListShardsOutput, error) {
return &kinesis.ListShardsOutput{
Shards: make([]*kinesis.Shard, 0),
}, nil
},
}
var fn = func(r *Record) error {
return nil
}
c, err := New("myStreamName", WithClient(client))
if err != nil {
t.Fatalf("new consumer error: %v", err)
}
if err := c.Scan(context.Background(), fn); err == nil {
t.Errorf("scan shard error expected not nil. got %v", err)
}
}
func TestScanShard(t *testing.T) {
var client = &kinesisClientMock{
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
@ -164,7 +141,7 @@ func TestScanShard(t *testing.T) {
return nil
}
if err := c.Scan(ctx, fn); err != nil {
if err := c.ScanShard(ctx, "myShard", fn); err != nil {
t.Errorf("scan returned unexpected error %v", err)
}
@ -269,35 +246,35 @@ func TestScanShard_SkipCheckpoint(t *testing.T) {
}
}
// func TestScanShard_ShardIsClosed(t *testing.T) {
// var client = &kinesisClientMock{
// getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
// return &kinesis.GetShardIteratorOutput{
// ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
// }, nil
// },
// getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
// return &kinesis.GetRecordsOutput{
// NextShardIterator: nil,
// Records: make([]*Record, 0),
// }, nil
// },
// }
func TestScanShard_ShardIsClosed(t *testing.T) {
var client = &kinesisClientMock{
getShardIteratorMock: func(input *kinesis.GetShardIteratorInput) (*kinesis.GetShardIteratorOutput, error) {
return &kinesis.GetShardIteratorOutput{
ShardIterator: aws.String("49578481031144599192696750682534686652010819674221576194"),
}, nil
},
getRecordsMock: func(input *kinesis.GetRecordsInput) (*kinesis.GetRecordsOutput, error) {
return &kinesis.GetRecordsOutput{
NextShardIterator: nil,
Records: make([]*Record, 0),
}, nil
},
}
// c, err := New("myStreamName", WithClient(client))
// if err != nil {
// t.Fatalf("new consumer error: %v", err)
// }
c, err := New("myStreamName", WithClient(client))
if err != nil {
t.Fatalf("new consumer error: %v", err)
}
// var fn = func(r *Record) error {
// return nil
// }
var fn = func(r *Record) error {
return nil
}
// err = c.ScanShard(context.Background(), "myShard", fn)
// if err != nil {
// t.Fatalf("scan shard error: %v", err)
// }
// }
err = c.ScanShard(context.Background(), "myShard", fn)
if err != nil {
t.Fatalf("scan shard error: %v", err)
}
}
type kinesisClientMock struct {
kinesisiface.KinesisAPI