203 lines
6 KiB
Go
203 lines
6 KiB
Go
|
|
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||
|
|
// SPDX-License-Identifier: Apache-2.0
|
||
|
|
package deaggregator_test
|
||
|
|
|
||
|
|
import (
|
||
|
|
"crypto/md5"
|
||
|
|
"fmt"
|
||
|
|
"math/rand"
|
||
|
|
"testing"
|
||
|
|
"time"
|
||
|
|
|
||
|
|
"github.com/aws/aws-sdk-go-v2/service/kinesis/types"
|
||
|
|
"github.com/golang/protobuf/proto"
|
||
|
|
"github.com/stretchr/testify/assert"
|
||
|
|
|
||
|
|
rec "github.com/awslabs/kinesis-aggregation/go/records"
|
||
|
|
deagg "github.com/harlow/kinesis-consumer/internal/deaggregator"
|
||
|
|
)
|
||
|
|
|
||
|
|
// Generate an aggregate record in the correct AWS-specified format
|
||
|
|
// https://github.com/awslabs/amazon-kinesis-producer/blob/master/aggregation-format.md
|
||
|
|
func generateAggregateRecord(numRecords int) []byte {
|
||
|
|
|
||
|
|
aggr := &rec.AggregatedRecord{}
|
||
|
|
// Start with the magic header
|
||
|
|
aggRecord := []byte("\xf3\x89\x9a\xc2")
|
||
|
|
partKeyTable := make([]string, 0)
|
||
|
|
|
||
|
|
// Create proto record with numRecords length
|
||
|
|
for i := 0; i < numRecords; i++ {
|
||
|
|
var partKey uint64
|
||
|
|
var hashKey uint64
|
||
|
|
partKey = uint64(i)
|
||
|
|
hashKey = uint64(i) * uint64(10)
|
||
|
|
r := &rec.Record{
|
||
|
|
PartitionKeyIndex: &partKey,
|
||
|
|
ExplicitHashKeyIndex: &hashKey,
|
||
|
|
Data: []byte("Some test data string"),
|
||
|
|
Tags: make([]*rec.Tag, 0),
|
||
|
|
}
|
||
|
|
|
||
|
|
aggr.Records = append(aggr.Records, r)
|
||
|
|
partKeyVal := "test" + fmt.Sprint(i)
|
||
|
|
partKeyTable = append(partKeyTable, partKeyVal)
|
||
|
|
}
|
||
|
|
|
||
|
|
aggr.PartitionKeyTable = partKeyTable
|
||
|
|
// Marshal to protobuf record, create md5 sum from proto record
|
||
|
|
// and append both to aggRecord with magic header
|
||
|
|
data, _ := proto.Marshal(aggr)
|
||
|
|
md5Hash := md5.Sum(data)
|
||
|
|
aggRecord = append(aggRecord, data...)
|
||
|
|
aggRecord = append(aggRecord, md5Hash[:]...)
|
||
|
|
return aggRecord
|
||
|
|
}
|
||
|
|
|
||
|
|
// Generate a generic kinesis.Record using whatever []byte
|
||
|
|
// is passed in as the data (can be normal []byte or proto record)
|
||
|
|
func generateKinesisRecord(data []byte) *types.Record {
|
||
|
|
currentTime := time.Now()
|
||
|
|
encryptionType := types.EncryptionTypeNone
|
||
|
|
partitionKey := "1234"
|
||
|
|
sequenceNumber := "21269319989900637946712965403778482371"
|
||
|
|
return &types.Record{
|
||
|
|
ApproximateArrivalTimestamp: ¤tTime,
|
||
|
|
Data: data,
|
||
|
|
EncryptionType: encryptionType,
|
||
|
|
PartitionKey: &partitionKey,
|
||
|
|
SequenceNumber: &sequenceNumber,
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// This tests to make sure that the data is at least larger than the length
|
||
|
|
// of the magic header to do some array slicing with index out of bounds
|
||
|
|
func TestSmallLengthReturnsCorrectNumberOfDeaggregatedRecords(t *testing.T) {
|
||
|
|
var err error
|
||
|
|
var kr *types.Record
|
||
|
|
|
||
|
|
krs := make([]*types.Record, 0, 1)
|
||
|
|
|
||
|
|
smallByte := []byte("No")
|
||
|
|
kr = generateKinesisRecord(smallByte)
|
||
|
|
krs = append(krs, kr)
|
||
|
|
dars, err := deagg.DeaggregateRecords(krs)
|
||
|
|
if err != nil {
|
||
|
|
panic(err)
|
||
|
|
}
|
||
|
|
|
||
|
|
// Small byte test, since this is not a deaggregated record, should return 1
|
||
|
|
// record in the array.
|
||
|
|
assert.Equal(t, 1, len(dars), "Small Byte test should return length of 1.")
|
||
|
|
}
|
||
|
|
|
||
|
|
// This function tests to make sure that the data starts with the correct magic header
|
||
|
|
// according to KPL aggregate documentation.
|
||
|
|
func TestNonMatchingMagicHeaderReturnsSingleRecord(t *testing.T) {
|
||
|
|
var err error
|
||
|
|
var kr *types.Record
|
||
|
|
|
||
|
|
krs := make([]*types.Record, 0, 1)
|
||
|
|
|
||
|
|
min := 1
|
||
|
|
max := 10
|
||
|
|
n := rand.Intn(max-min) + min
|
||
|
|
aggData := generateAggregateRecord(n)
|
||
|
|
mismatchAggData := aggData[1:]
|
||
|
|
kr = generateKinesisRecord(mismatchAggData)
|
||
|
|
|
||
|
|
krs = append(krs, kr)
|
||
|
|
|
||
|
|
dars, err := deagg.DeaggregateRecords(krs)
|
||
|
|
if err != nil {
|
||
|
|
panic(err)
|
||
|
|
}
|
||
|
|
|
||
|
|
// A byte record with a magic header that does not match 0xF3 0x89 0x9A 0xC2
|
||
|
|
// should return a single record.
|
||
|
|
assert.Equal(t, 1, len(dars), "Mismatch magic header test should return length of 1.")
|
||
|
|
}
|
||
|
|
|
||
|
|
// This function tests that the DeaggregateRecords function returns the correct number of
|
||
|
|
// deaggregated records from a single aggregated record.
|
||
|
|
func TestVariableLengthRecordsReturnsCorrectNumberOfDeaggregatedRecords(t *testing.T) {
|
||
|
|
var err error
|
||
|
|
var kr *types.Record
|
||
|
|
|
||
|
|
krs := make([]*types.Record, 0, 1)
|
||
|
|
|
||
|
|
min := 1
|
||
|
|
max := 10
|
||
|
|
n := rand.Intn(max-min) + min
|
||
|
|
aggData := generateAggregateRecord(n)
|
||
|
|
kr = generateKinesisRecord(aggData)
|
||
|
|
krs = append(krs, kr)
|
||
|
|
|
||
|
|
dars, err := deagg.DeaggregateRecords(krs)
|
||
|
|
if err != nil {
|
||
|
|
panic(err)
|
||
|
|
}
|
||
|
|
|
||
|
|
// Variable Length Aggregate Record test has aggregaterd records and should return
|
||
|
|
// n length.
|
||
|
|
assertMsg := fmt.Sprintf("Variable Length Aggregate Record should return length %v.", len(dars))
|
||
|
|
assert.Equal(t, n, len(dars), assertMsg)
|
||
|
|
}
|
||
|
|
|
||
|
|
// This function tests the length of the message after magic file header. If length is less than
|
||
|
|
// the digest size (16 bytes), it is not an aggregated record.
|
||
|
|
func TestRecordAfterMagicHeaderWithLengthLessThanDigestSizeReturnsSingleRecord(t *testing.T) {
|
||
|
|
var err error
|
||
|
|
var kr *types.Record
|
||
|
|
|
||
|
|
krs := make([]*types.Record, 0, 1)
|
||
|
|
|
||
|
|
min := 1
|
||
|
|
max := 10
|
||
|
|
n := rand.Intn(max-min) + min
|
||
|
|
aggData := generateAggregateRecord(n)
|
||
|
|
// Change size of proto message to 15
|
||
|
|
reducedAggData := aggData[:19]
|
||
|
|
kr = generateKinesisRecord(reducedAggData)
|
||
|
|
|
||
|
|
krs = append(krs, kr)
|
||
|
|
|
||
|
|
dars, err := deagg.DeaggregateRecords(krs)
|
||
|
|
if err != nil {
|
||
|
|
panic(err)
|
||
|
|
}
|
||
|
|
|
||
|
|
// A byte record with length less than 16 after the magic header should return
|
||
|
|
// a single record from DeaggregateRecords
|
||
|
|
assert.Equal(t, 1, len(dars), "Digest size test should return length of 1.")
|
||
|
|
}
|
||
|
|
|
||
|
|
// This function tests the MD5 Sum at the end of the record by comparing MD5 sum
|
||
|
|
// at end of proto record with MD5 Sum of Proto message. If they do not match,
|
||
|
|
// it is not an aggregated record.
|
||
|
|
func TestRecordWithMismatchMd5SumReturnsSingleRecord(t *testing.T) {
|
||
|
|
var err error
|
||
|
|
var kr *types.Record
|
||
|
|
|
||
|
|
krs := make([]*types.Record, 0, 1)
|
||
|
|
|
||
|
|
min := 1
|
||
|
|
max := 10
|
||
|
|
n := rand.Intn(max-min) + min
|
||
|
|
aggData := generateAggregateRecord(n)
|
||
|
|
// Remove last byte from array to mismatch the MD5 sums
|
||
|
|
mismatchAggData := aggData[:len(aggData)-1]
|
||
|
|
kr = generateKinesisRecord(mismatchAggData)
|
||
|
|
|
||
|
|
krs = append(krs, kr)
|
||
|
|
|
||
|
|
dars, err := deagg.DeaggregateRecords(krs)
|
||
|
|
if err != nil {
|
||
|
|
panic(err)
|
||
|
|
}
|
||
|
|
|
||
|
|
// A byte record with an MD5 sum that does not match with the md5.Sum(record)
|
||
|
|
// will be marked as a non-aggregate record and return a single record
|
||
|
|
assert.Equal(t, 1, len(dars), "Mismatch md5 sum test should return length of 1.")
|
||
|
|
}
|