diff --git a/internal/family.go b/internal/family.go index 45fb4df..e1f48cc 100644 --- a/internal/family.go +++ b/internal/family.go @@ -35,6 +35,7 @@ type families struct { ReservoirItems family VarOptItems family ReservoirUnion family + REQ family } var FamilyEnum = &families{ @@ -86,4 +87,8 @@ var FamilyEnum = &families{ Id: 12, MaxPreLongs: 1, }, + REQ: family{ + Id: 17, + MaxPreLongs: 2, + }, } diff --git a/req/compactor.go b/req/compactor.go index df789dd..43f2bd9 100644 --- a/req/compactor.go +++ b/req/compactor.go @@ -21,6 +21,7 @@ import ( "encoding/binary" "errors" "fmt" + "io" "math" "math/bits" "math/rand/v2" @@ -85,32 +86,6 @@ func copyCompactor(other *compactor) *compactor { } } -func reconstructCompactor( - lgWeight byte, - hra bool, - state int64, - sectionSizeFlt float32, - numSections byte, - delta int, - sorted bool, - items []float32, - count int, -) *compactor { - return &compactor{ - lgWeight: lgWeight, - isHighRankAccuracyMode: hra, - sectionSizeFlt: sectionSizeFlt, - numSections: numSections, - state: state, - coin: false, - delta: delta, - sorted: sorted, - items: items, - count: count, - sectionSize: nearestEven(sectionSizeFlt), - } -} - func numberOfTrailingOnes(v int64) int { return bits.TrailingZeros64(uint64(^v)) } @@ -247,7 +222,7 @@ func nearestEven(val float32) int { return int(math.Round(float64(val)/2.0)) << 1 } -func (c *compactor) Marshal() ([]byte, error) { +func (c *compactor) MarshalBinary() ([]byte, error) { size := c.SerializationBytes() arr := make([]byte, size) offset := 0 @@ -561,3 +536,190 @@ func (c *compactor) ensureCapacity(newCapacity int) { c.items = out } } + +type compactorDecodingResult struct { + compactor *compactor + bufferEndIndex int + minItem float32 + maxItem float32 + n int64 +} + +func decodeCompactor( + buf []byte, index int, isLevel0Sorted, isHighRankAccuracyMode bool, +) (compactorDecodingResult, error) { // the result's bufferEndIndex is the buffer index just past the decoded compactor. 
+ if err := validateBuffer(buf, index+8); err != nil { + return compactorDecodingResult{}, err + } + state := binary.LittleEndian.Uint64(buf[index : index+8]) + index += 8 + + if err := validateBuffer(buf, index+4); err != nil { + return compactorDecodingResult{}, err + } + sectionSizeFloat := math.Float32frombits(binary.LittleEndian.Uint32(buf[index : index+4])) + sectionSize := math.Round(float64(sectionSizeFloat)) + index += 4 + + if err := validateBuffer(buf, index+1); err != nil { + return compactorDecodingResult{}, err + } + lgWeight := buf[index] + index++ + + if err := validateBuffer(buf, index+1); err != nil { + return compactorDecodingResult{}, err + } + numSections := buf[index] + index++ + + index += 2 // pad + + if err := validateBuffer(buf, index+4); err != nil { + return compactorDecodingResult{}, err + } + count := binary.LittleEndian.Uint32(buf[index : index+4]) + index += 4 + + var ( + minItem = float32(math.MaxFloat32) + maxItem = float32(-math.MaxFloat32) + ) + items := make([]float32, 0, count) + for i := uint32(0); i < count; i++ { + if err := validateBuffer(buf, index+4); err != nil { + return compactorDecodingResult{}, err + } + + item := math.Float32frombits(binary.LittleEndian.Uint32(buf[index : index+4])) + items = append(items, item) + + minItem = min(minItem, item) + maxItem = max(maxItem, item) + + index += 4 + } + + delta := 2 * int(sectionSize) * int(numSections) + nomCap := 2 * delta + capacity := max(int(count), nomCap) + + if isHighRankAccuracyMode { + newItems := make([]float32, capacity) + copy(newItems[capacity-int(count):], items) + items = newItems + } + + return compactorDecodingResult{ + compactor: &compactor{ + lgWeight: lgWeight, + items: items, + count: int(count), + delta: delta, + sorted: isLevel0Sorted, + isHighRankAccuracyMode: isHighRankAccuracyMode, + sectionSizeFlt: sectionSizeFloat, + numSections: numSections, + state: int64(state), + coin: false, + sectionSize: nearestEven(sectionSizeFloat), + }, + 
bufferEndIndex: index, + minItem: minItem, + maxItem: maxItem, + n: int64(count), + }, nil +} + +type compactorDecoder struct { + isLevel0Sorted bool + isHighRankAccuracyMode bool +} + +func newCompactorDecoder(isLevel0Sorted, isHighRankAccuracyMode bool) compactorDecoder { + return compactorDecoder{ + isLevel0Sorted: isLevel0Sorted, + isHighRankAccuracyMode: isHighRankAccuracyMode, + } +} + +func (d *compactorDecoder) Decode(r io.Reader) (compactorDecodingResult, error) { + var state uint64 + if err := binary.Read(r, binary.LittleEndian, &state); err != nil { + return compactorDecodingResult{}, err + } + + var sectionSizeFltRaw uint32 + if err := binary.Read(r, binary.LittleEndian, &sectionSizeFltRaw); err != nil { + return compactorDecodingResult{}, err + } + sectionSizeFlt := math.Float32frombits(sectionSizeFltRaw) + sectionSize := math.Round(float64(sectionSizeFlt)) + + var lgWeight byte + if err := binary.Read(r, binary.LittleEndian, &lgWeight); err != nil { + return compactorDecodingResult{}, err + } + + var numSections byte + if err := binary.Read(r, binary.LittleEndian, &numSections); err != nil { + return compactorDecodingResult{}, err + } + + var pad uint16 + if err := binary.Read(r, binary.LittleEndian, &pad); err != nil { + return compactorDecodingResult{}, err + } + + var count uint32 + if err := binary.Read(r, binary.LittleEndian, &count); err != nil { + return compactorDecodingResult{}, err + } + + var ( + minItem = float32(math.MaxFloat32) + maxItem = float32(-math.MaxFloat32) + ) + items := make([]float32, 0, count) + for i := uint32(0); i < count; i++ { + var itemRaw uint32 + if err := binary.Read(r, binary.LittleEndian, &itemRaw); err != nil { + return compactorDecodingResult{}, err + } + + item := math.Float32frombits(itemRaw) + items = append(items, item) + + minItem = min(minItem, item) + maxItem = max(maxItem, item) + } + + delta := 2 * int(sectionSize) * int(numSections) + nomCap := 2 * delta + capacity := max(int(count), nomCap) + + if 
d.isHighRankAccuracyMode { + newItems := make([]float32, capacity) + copy(newItems[capacity-int(count):], items) + items = newItems + } + + return compactorDecodingResult{ + compactor: &compactor{ + lgWeight: lgWeight, + items: items, + count: int(count), + delta: delta, + sorted: d.isLevel0Sorted, + isHighRankAccuracyMode: d.isHighRankAccuracyMode, + sectionSizeFlt: sectionSizeFlt, + numSections: numSections, + state: int64(state), + coin: false, + sectionSize: nearestEven(sectionSizeFlt), + }, + minItem: minItem, + maxItem: maxItem, + n: int64(count), + }, nil +} diff --git a/req/compactor_test.go b/req/compactor_test.go index 87e379f..3b38dfb 100644 --- a/req/compactor_test.go +++ b/req/compactor_test.go @@ -18,6 +18,8 @@ package req import ( + "bytes" + "encoding/binary" "fmt" "math" "testing" @@ -419,3 +421,203 @@ func TestCompactorGetters(t *testing.T) { assert.True(t, c.IsHighRankAccuracyMode()) assert.Equal(t, int64(0), c.State()) } + +// TestCompactorItemsSerDe tests marshalItems/decode round-trip with raw array +// position checks, ported from Java ReqFloatBufferTest.checkSerDe. +func TestCompactorItemsSerDe(t *testing.T) { + t.Run("HRA", func(t *testing.T) { + runCompactorItemsSerDe(t, true) + }) + t.Run("LRA", func(t *testing.T) { + runCompactorItemsSerDe(t, false) + }) +} + +func runCompactorItemsSerDe(t *testing.T, hra bool) { + t.Helper() + c := newCompactor(0, hra, minK) + initialCap := c.Capacity() + + // Append more items than initial capacity to trigger growth. + numItems := initialCap + 1 + for i := 0; i < numItems; i++ { + c.Append(float32(i)) + } + assert.Greater(t, c.Capacity(), initialCap) + + capacity := c.Capacity() + count := c.Count() + sorted := c.sorted + + // Verify raw item positions before serialization. 
+ if hra { + assert.Equal(t, float32(numItems-1), c.items[capacity-count]) + assert.Equal(t, float32(0), c.items[capacity-1]) + } else { + assert.Equal(t, float32(0), c.items[0]) + assert.Equal(t, float32(numItems-1), c.items[count-1]) + } + + // Verify marshalItems byte output matches Item(i) order. + itemBytes := c.marshalItems() + assert.Equal(t, count*4, len(itemBytes)) + for i := 0; i < count; i++ { + bits := binary.LittleEndian.Uint32(itemBytes[i*4:]) + got := math.Float32frombits(bits) + assert.Equal(t, c.Item(i), got, "serialized item mismatch at offset %d", i) + } + + // Full round-trip via MarshalBinary + decodeCompactor. + fullBytes, err := c.MarshalBinary() + assert.NoError(t, err) + result, err := decodeCompactor(fullBytes, 0, sorted, hra) + assert.NoError(t, err) + c2 := result.compactor + + assert.Equal(t, count, c2.Count()) + assert.Equal(t, sorted, c2.sorted) + assert.Equal(t, hra, c2.isHighRankAccuracyMode) + + // Verify raw positions in deserialized compactor. + if hra { + cap2 := c2.Capacity() + assert.Equal(t, float32(numItems-1), c2.items[cap2-c2.count]) + assert.Equal(t, float32(0), c2.items[cap2-1]) + } else { + assert.Equal(t, float32(0), c2.items[0]) + assert.Equal(t, float32(numItems-1), c2.items[c2.count-1]) + } + + // Verify all items match by logical offset. 
+ for i := 0; i < count; i++ { + assert.Equal(t, c.Item(i), c2.Item(i), "item mismatch at offset %d", i) + } +} + +func TestCompactorSerializationDeserialization(t *testing.T) { + t.Run("LRA", func(t *testing.T) { + runCompactorSerializationDeserialization(t, 12, false) + }) + t.Run("HRA", func(t *testing.T) { + runCompactorSerializationDeserialization(t, 12, true) + }) +} + +func runCompactorSerializationDeserialization(t *testing.T, k int, hra bool) { + t.Helper() + c1 := newCompactor(0, hra, k) + nomCap := nomCapMul * initNumberOfSections * k + expectedCap := 2 * nomCap + expectedDelta := nomCap + + for i := 1; i <= nomCap; i++ { + c1.Append(float32(i)) + } + + sectionSizeFlt := c1.SectionSizeFlt() + sectionSize := c1.SectionSize() + numSections := c1.NumSections() + state := c1.State() + lgWt := c1.lgWeight + sorted := c1.sorted + + // serialize + c1ser, err := c1.MarshalBinary() + assert.NoError(t, err) + + // deserialize via buffer + result, err := decodeCompactor(c1ser, 0, sorted, hra) + assert.NoError(t, err) + c2 := result.compactor + + assert.Equal(t, float32(1), result.minItem) + assert.Equal(t, float32(nomCap), result.maxItem) + assert.Equal(t, int64(nomCap), result.n) + assert.Equal(t, sectionSizeFlt, c2.SectionSizeFlt()) + assert.Equal(t, sectionSize, c2.SectionSize()) + assert.Equal(t, numSections, c2.NumSections()) + assert.Equal(t, state, c2.State()) + assert.Equal(t, lgWt, c2.lgWeight) + assert.Equal(t, hra, c2.IsHighRankAccuracyMode()) + if hra { + assert.Equal(t, expectedCap, c2.Capacity()) + } else { + // LRA decoder keeps items at count-sized capacity (not expanded to nomCap). 
+ assert.Equal(t, nomCap, c2.Capacity()) + } + assert.Equal(t, expectedDelta, c2.delta) + + for i := 0; i < nomCap; i++ { + assert.Equal(t, c1.Item(i), c2.Item(i), "item mismatch at offset %d", i) + } + + // deserialize via stream + decoder := newCompactorDecoder(sorted, hra) + result2, err := decoder.Decode(bytes.NewReader(c1ser)) + assert.NoError(t, err) + c3 := result2.compactor + + assert.Equal(t, float32(1), result2.minItem) + assert.Equal(t, float32(nomCap), result2.maxItem) + assert.Equal(t, int64(nomCap), result2.n) + assert.Equal(t, sectionSizeFlt, c3.SectionSizeFlt()) + assert.Equal(t, sectionSize, c3.SectionSize()) + for i := 0; i < nomCap; i++ { + assert.Equal(t, c1.Item(i), c3.Item(i), "stream: item mismatch at offset %d", i) + } +} + +func TestCompactorSerializationDeserializationWithNegativeValues(t *testing.T) { + t.Run("LRA", func(t *testing.T) { + runCompactorSerializationDeserializationNegative(t, 12, false) + }) + t.Run("HRA", func(t *testing.T) { + runCompactorSerializationDeserializationNegative(t, 12, true) + }) +} + +func runCompactorSerializationDeserializationNegative(t *testing.T, k int, hra bool) { + t.Helper() + c1 := newCompactor(0, hra, k) + nomCap := nomCapMul * initNumberOfSections * k + + for i := 1; i <= nomCap; i++ { + c1.Append(float32(-i)) + } + + c1ser, err := c1.MarshalBinary() + assert.NoError(t, err) + + result, err := decodeCompactor(c1ser, 0, c1.sorted, hra) + assert.NoError(t, err) + assert.Equal(t, float32(-nomCap), result.minItem) + assert.Equal(t, float32(-1), result.maxItem) +} + +func TestCompactorSerializationDeserializationWithMixedValues(t *testing.T) { + t.Run("LRA", func(t *testing.T) { + runCompactorSerializationDeserializationMixed(t, 12, false) + }) + t.Run("HRA", func(t *testing.T) { + runCompactorSerializationDeserializationMixed(t, 12, true) + }) +} + +func runCompactorSerializationDeserializationMixed(t *testing.T, k int, hra bool) { + t.Helper() + c1 := newCompactor(0, hra, k) + nomCap := nomCapMul * 
initNumberOfSections * k + half := nomCap / 2 + + for i := 0; i < nomCap; i++ { + c1.Append(float32(i - half)) + } + + c1ser, err := c1.MarshalBinary() + assert.NoError(t, err) + + result, err := decodeCompactor(c1ser, 0, c1.sorted, hra) + assert.NoError(t, err) + assert.Equal(t, float32(-half), result.minItem) + assert.Equal(t, float32(half-1), result.maxItem) +} diff --git a/req/decoder.go b/req/decoder.go new file mode 100644 index 0000000..ca94b32 --- /dev/null +++ b/req/decoder.go @@ -0,0 +1,384 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package req + +import ( + "encoding/binary" + "fmt" + "io" + "math" + + "github.com/apache/datasketches-go/internal" +) + +// Decoder is responsible for decoding sketches from binary format. +type Decoder struct{} + +// NewDecoder creates a new instance of Decoder. +func NewDecoder() Decoder { + return Decoder{} +} + +// Decode decodes a sketch from the provided reader. 
+func (d *Decoder) Decode(r io.Reader) (*Sketch, error) { + var preambleInt byte + if err := binary.Read(r, binary.LittleEndian, &preambleInt); err != nil { + return nil, err + } + + var serVer byte + if err := binary.Read(r, binary.LittleEndian, &serVer); err != nil { + return nil, err + } + if serVer != serialVersion { + return nil, fmt.Errorf("unsupported serialization version: %d", serVer) + } + + var familyID byte + if err := binary.Read(r, binary.LittleEndian, &familyID); err != nil { + return nil, err + } + if int(familyID) != internal.FamilyEnum.REQ.Id { + return nil, fmt.Errorf("invalid family id: %d", familyID) + } + + var flags byte + if err := binary.Read(r, binary.LittleEndian, &flags); err != nil { + return nil, err + } + isEmpty := (flags & 4) > 0 + isHighRankAccuracyMode := (flags & 8) > 0 + isRawItemsSketch := (flags & 16) > 0 + isLevel0Sorted := (flags & 32) > 0 + + var k uint16 + if err := binary.Read(r, binary.LittleEndian, &k); err != nil { + return nil, err + } + + var numCompactors byte + if err := binary.Read(r, binary.LittleEndian, &numCompactors); err != nil { + return nil, err + } + + var numRawItems byte + if err := binary.Read(r, binary.LittleEndian, &numRawItems); err != nil { + return nil, err + } + + format := inferEncodingFormat(isEmpty, isRawItemsSketch, int(numCompactors)) + switch format { + case encodingFormatEmpty: + if preambleInt != 2 { + return nil, fmt.Errorf("invalid preamble: %d", preambleInt) + } + return NewSketch(WithK(int(k)), WithHighRankAccuracyMode(isHighRankAccuracyMode)) + case encodingFormatRawItems: + if preambleInt != 2 { + return nil, fmt.Errorf("invalid preamble: %d", preambleInt) + } + + sk, err := NewSketch(WithK(int(k)), WithHighRankAccuracyMode(isHighRankAccuracyMode)) + if err != nil { + return nil, err + } + + for i := byte(0); i < numRawItems; i++ { + var rawItem uint32 + if err := binary.Read(r, binary.LittleEndian, &rawItem); err != nil { + return nil, err + } + + if err := 
sk.Update(math.Float32frombits(rawItem)); err != nil { + return nil, err + } + } + return sk, nil + case encodingFormatExact: + if preambleInt != 2 { + return nil, fmt.Errorf("invalid preamble: %d", preambleInt) + } + + decoder := newCompactorDecoder(isLevel0Sorted, isHighRankAccuracyMode) + result, err := decoder.Decode(r) + if err != nil { + return nil, err + } + + sk := &Sketch{ + n: result.n, + compactors: []*compactor{result.compactor}, + minItem: result.minItem, + maxItem: result.maxItem, + k: int(k), + isHighRankAccuracyMode: isHighRankAccuracyMode, + } + if err := sk.validateK(); err != nil { + return nil, err + } + sk.maxNomSize = sk.computeMaxNomSize() + sk.numRetained = sk.computeRetainedItems() + return sk, nil + default: // Estimation. + if preambleInt != 4 { + return nil, fmt.Errorf("invalid preamble: %d", preambleInt) + } + + var n uint64 + if err := binary.Read(r, binary.LittleEndian, &n); err != nil { + return nil, err + } + + var minItemRaw uint32 + if err := binary.Read(r, binary.LittleEndian, &minItemRaw); err != nil { + return nil, err + } + minItem := math.Float32frombits(minItemRaw) + + var maxItemRaw uint32 + if err := binary.Read(r, binary.LittleEndian, &maxItemRaw); err != nil { + return nil, err + } + maxItem := math.Float32frombits(maxItemRaw) + + compactors := make([]*compactor, 0, int(numCompactors)) + for i := 0; i < int(numCompactors); i++ { + if i == 0 { + decoder := newCompactorDecoder(isLevel0Sorted, isHighRankAccuracyMode) + result, err := decoder.Decode(r) + if err != nil { + return nil, err + } + + compactors = append(compactors, result.compactor) + } else { + decoder := newCompactorDecoder(true, isHighRankAccuracyMode) + result, err := decoder.Decode(r) + if err != nil { + return nil, err + } + + compactors = append(compactors, result.compactor) + } + } + + sk := &Sketch{ + k: int(k), + isHighRankAccuracyMode: isHighRankAccuracyMode, + n: int64(n), + minItem: minItem, + maxItem: maxItem, + compactors: compactors, + } + if err 
:= sk.validateK(); err != nil { + return nil, err + } + sk.maxNomSize = sk.computeMaxNomSize() + sk.numRetained = sk.computeRetainedItems() + return sk, nil + } +} + +// Decode decodes a sketch from the provided buffer. +// If the buffer is too short, returns io.ErrUnexpectedEOF. +func Decode(buf []byte) (*Sketch, error) { + index := 0 + if err := validateBuffer(buf, index+1); err != nil { + return nil, err + } + preambleInts := buf[index] + index++ + + if err := validateBuffer(buf, index+1); err != nil { + return nil, err + } + serVer := buf[index] + index++ + if serVer != serialVersion { + return nil, fmt.Errorf("unsupported serialization version: %d", serVer) + } + + if err := validateBuffer(buf, index+1); err != nil { + return nil, err + } + familyID := buf[index] + index++ + if int(familyID) != internal.FamilyEnum.REQ.Id { + return nil, fmt.Errorf("invalid family id: %d", familyID) + } + + // flags. + if err := validateBuffer(buf, index+1); err != nil { + return nil, err + } + flags := buf[index] + index++ + isEmpty := (flags & 4) > 0 + isHighRankAccuracyMode := (flags & 8) > 0 + isRawItemsSketch := (flags & 16) > 0 + isLevel0Sorted := (flags & 32) > 0 + + if err := validateBuffer(buf, index+2); err != nil { + return nil, err + } + k := binary.LittleEndian.Uint16(buf[index : index+2]) + index += 2 + + if err := validateBuffer(buf, index+1); err != nil { + return nil, err + } + numCompactors := buf[index] + index++ + + if err := validateBuffer(buf, index+1); err != nil { + return nil, err + } + numRawItems := buf[index] + index++ + + format := inferEncodingFormat(isEmpty, isRawItemsSketch, int(numCompactors)) + switch format { + case encodingFormatEmpty: + if preambleInts != 2 { + return nil, fmt.Errorf("invalid preamble ints for empty encoding format: %d", preambleInts) + } + return NewSketch(WithK(int(k)), WithHighRankAccuracyMode(isHighRankAccuracyMode)) + case encodingFormatRawItems: + if preambleInts != 2 { + return nil, fmt.Errorf("invalid preamble ints 
for raw items encoding: %d", preambleInts) + } + + sk, err := NewSketch(WithK(int(k)), WithHighRankAccuracyMode(isHighRankAccuracyMode)) + if err != nil { + return nil, err + } + + for i := 0; i < int(numRawItems); i++ { + if err := validateBuffer(buf, index+4); err != nil { + return nil, err + } + + item := math.Float32frombits(binary.LittleEndian.Uint32(buf[index : index+4])) + + if err := sk.Update(item); err != nil { + return nil, err + } + + index += 4 + } + return sk, nil + case encodingFormatExact: + if preambleInts != 2 { + return nil, fmt.Errorf("invalid preamble ints for exact encoding: %d", preambleInts) + } + + result, err := decodeCompactor(buf, index, isLevel0Sorted, isHighRankAccuracyMode) + if err != nil { + return nil, err + } + + sk := &Sketch{ + n: result.n, + compactors: []*compactor{result.compactor}, + minItem: result.minItem, + maxItem: result.maxItem, + k: int(k), + isHighRankAccuracyMode: isHighRankAccuracyMode, + } + if err := sk.validateK(); err != nil { + return nil, err + } + sk.maxNomSize = sk.computeMaxNomSize() + sk.numRetained = sk.computeRetainedItems() + return sk, nil + default: // Estimation. 
+ // The estimation encoding always carries 4 preamble ints; reject anything else, as Decoder.Decode does. + if preambleInts != 4 { + return nil, fmt.Errorf("invalid preamble ints for estimation encoding: %d", preambleInts) + } + + if err := validateBuffer(buf, index+8); err != nil { + return nil, err + } + n := binary.LittleEndian.Uint64(buf[index : index+8]) + index += 8 + + if err := validateBuffer(buf, index+4); err != nil { + return nil, err + } + minItem := math.Float32frombits(binary.LittleEndian.Uint32(buf[index : index+4])) + index += 4 + + if err := validateBuffer(buf, index+4); err != nil { + return nil, err + } + maxItem := math.Float32frombits(binary.LittleEndian.Uint32(buf[index : index+4])) + index += 4 + + compactors := make([]*compactor, 0, int(numCompactors)) + for i := 0; i < int(numCompactors); i++ { + if i == 0 { + res, err := decodeCompactor(buf, index, isLevel0Sorted, isHighRankAccuracyMode) + if err != nil { + return nil, err + } + + compactors = append(compactors, res.compactor) + index = res.bufferEndIndex + } else { + res, err := decodeCompactor(buf, index, true, isHighRankAccuracyMode) + if err != nil { + return nil, err + } + + compactors = append(compactors, res.compactor) + index = res.bufferEndIndex + } + } + + sk := &Sketch{ + k: int(k), + isHighRankAccuracyMode: isHighRankAccuracyMode, + n: int64(n), + minItem: minItem, + maxItem: maxItem, + compactors: compactors, + } + if err := sk.validateK(); err != nil { + return nil, err + } + sk.maxNomSize = sk.computeMaxNomSize() + sk.numRetained = sk.computeRetainedItems() + return sk, nil + } +} + +// validateBuffer returns io.ErrUnexpectedEOF when buf holds fewer than endIndex bytes, and nil otherwise. +func validateBuffer(buf []byte, endIndex int) error { + if len(buf) < endIndex { + return io.ErrUnexpectedEOF + } + return nil +} + +// inferEncodingFormat maps the header flags and compactor count to an encoding format: with at most one compactor it is empty, raw items, or exact (checked in that order); with more it is estimation. +func inferEncodingFormat(isEmpty, isRawItemsSketch bool, numCompactors int) encodingFormat { + if numCompactors <= 1 { + if isEmpty { + return encodingFormatEmpty + } + if isRawItemsSketch { + return encodingFormatRawItems + } + return encodingFormatExact + } + return encodingFormatEstimation +} diff --git a/req/encoder.go b/req/encoder.go new file mode 100644 index 0000000..5431640 --- /dev/null +++ b/req/encoder.go @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation 
(ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package req + +import ( + "encoding/binary" + "io" + "math" + + "github.com/apache/datasketches-go/internal" +) + +type encodingFormat int + +const ( + encodingFormatEmpty encodingFormat = iota + encodingFormatRawItems encodingFormat = iota + encodingFormatExact encodingFormat = iota + encodingFormatEstimation encodingFormat = iota +) + +const ( + serialVersion = 1 +) + +// Encoder encodes a REQ sketch to bytes. +type Encoder struct { + w io.Writer +} + +// NewEncoder creates a new encoder. +func NewEncoder(w io.Writer) Encoder { + return Encoder{w: w} +} + +// Encode encodes a REQ sketch to bytes. 
+func (enc *Encoder) Encode(sketch *Sketch) error { + format := sketch.computeEncodingFormat() + + preambleInts := byte(2) + if format == encodingFormatEstimation { + preambleInts = byte(4) + } + if err := binary.Write(enc.w, binary.LittleEndian, preambleInts); err != nil { + return err + } + + if err := binary.Write(enc.w, binary.LittleEndian, byte(serialVersion)); err != nil { + return err + } + + if err := binary.Write(enc.w, binary.LittleEndian, byte(internal.FamilyEnum.REQ.Id)); err != nil { + return err + } + + flags := sketch.encodingFlags() + if err := binary.Write(enc.w, binary.LittleEndian, byte(flags)); err != nil { + return err + } + + if err := binary.Write(enc.w, binary.LittleEndian, uint16(sketch.K())); err != nil { + return err + } + + numCompactors := byte(0) + if !sketch.IsEmpty() { + numCompactors = byte(sketch.numLevels()) + } + if err := binary.Write(enc.w, binary.LittleEndian, numCompactors); err != nil { + return err + } + + numRawItems := byte(0) + if sketch.N() <= minK { + numRawItems = byte(sketch.N()) + } + if err := binary.Write(enc.w, binary.LittleEndian, numRawItems); err != nil { + return err + } + + switch format { + case encodingFormatEmpty: + return nil + case encodingFormatRawItems: + c0 := sketch.compactors[0] + for i := 0; i < int(numRawItems); i++ { + if err := binary.Write(enc.w, binary.LittleEndian, math.Float32bits(c0.Item(i))); err != nil { + return err + } + } + return nil + case encodingFormatExact: + c0 := sketch.compactors[0] + b, err := c0.MarshalBinary() + if err != nil { + return err + } + + if _, err := enc.w.Write(b); err != nil { + return err + } + return nil + default: // Estimation. 
+ if err := binary.Write(enc.w, binary.LittleEndian, uint64(sketch.N())); err != nil { + return err + } + + if err := binary.Write(enc.w, binary.LittleEndian, math.Float32bits(sketch.minItem)); err != nil { + return err + } + + if err := binary.Write(enc.w, binary.LittleEndian, math.Float32bits(sketch.maxItem)); err != nil { + return err + } + + for i := 0; i < int(numCompactors); i++ { + c := sketch.compactors[i] + b, err := c.MarshalBinary() + if err != nil { + return err + } + + if _, err := enc.w.Write(b); err != nil { + return err + } + } + return nil + } +} diff --git a/req/sketch.go b/req/sketch.go index 8d5c866..d7e2828 100644 --- a/req/sketch.go +++ b/req/sketch.go @@ -18,12 +18,14 @@ package req import ( + "encoding/binary" "errors" "fmt" "math" "strings" quantilecommon "github.com/apache/datasketches-go/common/quantiles" + "github.com/apache/datasketches-go/internal" ) const ( @@ -725,6 +727,148 @@ func (s *Sketch) compress() error { return nil } +// SerializedSizeBytes returns the current number of bytes this Sketch would require if serialized. +func (s *Sketch) SerializedSizeBytes() int { + format := s.computeEncodingFormat() + return s.computeSerializedSizeBytes(format) +} + +func (s *Sketch) computeSerializedSizeBytes(format encodingFormat) int { + switch format { + case encodingFormatEmpty: + return 8 + case encodingFormatRawItems: + return s.compactors[0].Count()*4 + 8 + case encodingFormatExact: + return s.compactors[0].SerializationBytes() + 8 + default: // estimation. + sum := 0 + for _, comp := range s.compactors { + sum += comp.SerializationBytes() + } + return sum + 24 + } +} + +func (s *Sketch) computeEncodingFormat() encodingFormat { + if s.IsEmpty() { + return encodingFormatEmpty + } + if s.N() <= minK { + return encodingFormatRawItems + } + if s.numLevels() == 1 { + return encodingFormatExact + } + return encodingFormatEstimation +} + +// MarshalBinary serializes the sketch into a binary format. 
+func (s *Sketch) MarshalBinary() ([]byte, error) { + format := s.computeEncodingFormat() + totalBytes := s.computeSerializedSizeBytes(format) + + var ( + buf = make([]byte, totalBytes) + index = 0 + ) + + preambleInts := 2 + if format == encodingFormatEstimation { + preambleInts = 4 + } + buf[index] = byte(preambleInts) + index++ + + buf[index] = byte(serialVersion) + index++ + + buf[index] = byte(internal.FamilyEnum.REQ.Id) + index++ + + flags := s.encodingFlags() + buf[index] = byte(flags) + index++ + + binary.LittleEndian.PutUint16(buf[index:], uint16(s.K())) + index += 2 + + numCompactors := 0 + if !s.IsEmpty() { + numCompactors = s.numLevels() + } + buf[index] = byte(numCompactors) + index++ + + numRawItems := 0 + if s.N() <= minK { + numRawItems = int(s.N()) + } + buf[index] = byte(numRawItems) + index++ + + switch format { + case encodingFormatEmpty: + return buf, nil + case encodingFormatRawItems: + c0 := s.compactors[0] + for i := 0; i < numRawItems; i++ { + binary.LittleEndian.PutUint32(buf[index:], math.Float32bits(c0.Item(i))) + index += 4 + } + return buf, nil + case encodingFormatExact: + c0 := s.compactors[0] + b, err := c0.MarshalBinary() + if err != nil { + return nil, err + } + copy(buf[index:], b) + return buf, nil + default: // Estimation. + binary.LittleEndian.PutUint64(buf[index:], uint64(s.N())) + index += 8 + + binary.LittleEndian.PutUint32(buf[index:], math.Float32bits(s.minItem)) + index += 4 + + binary.LittleEndian.PutUint32(buf[index:], math.Float32bits(s.maxItem)) + index += 4 + + for i := 0; i < numCompactors; i++ { + c := s.compactors[i] + b, err := c.MarshalBinary() + if err != nil { + return nil, err + } + copy(buf[index:], b) + index += len(b) + } + + return buf, nil + } +} + +func (s *Sketch) encodingFlags() int { + flags := 0 + if s.IsEmpty() { + flags = 4 + } + + if s.IsHighRankAccuracyMode() { + flags |= 8 + } + + if s.N() <= minK { // raw items sketch. 
+ flags |= 16 + } + + if s.compactors[0].sorted { + flags |= 32 + } + return flags +} + func computeRankLowerBound( k int, levels int, diff --git a/req/sketch_serialization_test.go b/req/sketch_serialization_test.go new file mode 100644 index 0000000..5478bb2 --- /dev/null +++ b/req/sketch_serialization_test.go @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package req + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSketchSerializationDeserialization(t *testing.T) { + k := 12 + exact := (nomCapMul * initNumberOfSections * k) - 1 + + testCases := []struct { + name string + hra bool + count int + }{ + {"empty LRA", false, 0}, + {"empty HRA", true, 0}, + {"rawItems LRA", false, 4}, + {"rawItems HRA", true, 4}, + {"exact LRA", false, exact}, + {"exact HRA", true, exact}, + {"estimation LRA", false, 2 * exact}, + {"estimation HRA", true, 2 * exact}, + } + for _, tc := range testCases { + t.Run(tc.name+" MarshalBinary+Decode", func(t *testing.T) { + sk1 := newSketchForTest(t, k, tc.hra, tc.count) + + sk1Arr, err := sk1.MarshalBinary() + assert.NoError(t, err) + + sk2, err := Decode(sk1Arr) + assert.NoError(t, err) + + assertSketchesEqual(t, sk1, sk2) + }) + t.Run(tc.name+" Encoder+Decoder", func(t *testing.T) { + sk1 := newSketchForTest(t, k, tc.hra, tc.count) + + var buf bytes.Buffer + enc := NewEncoder(&buf) + assert.NoError(t, enc.Encode(sk1)) + + dec := NewDecoder() + sk2, err := dec.Decode(bytes.NewReader(buf.Bytes())) + assert.NoError(t, err) + + assertSketchesEqual(t, sk1, sk2) + }) + } +} + +func newSketchForTest(t *testing.T, k int, hra bool, count int) *Sketch { + t.Helper() + sk, err := NewSketch(WithK(k), WithHighRankAccuracyMode(hra)) + assert.NoError(t, err) + for i := 1; i <= count; i++ { + assert.NoError(t, sk.Update(float32(i))) + } + return sk +} + +func assertSketchesEqual(t *testing.T, sk1, sk2 *Sketch) { + t.Helper() + assert.Equal(t, sk1.NumRetained(), sk2.NumRetained()) + assert.Equal(t, sk1.IsEmpty(), sk2.IsEmpty()) + if !sk1.IsEmpty() { + min1, err := sk1.MinItem() + assert.NoError(t, err) + min2, err := sk2.MinItem() + assert.NoError(t, err) + assert.Equal(t, min1, min2) + + max1, err := sk1.MaxItem() + assert.NoError(t, err) + max2, err := sk2.MaxItem() + assert.NoError(t, err) + assert.Equal(t, max1, max2) + } + assert.Equal(t, sk1.N(), sk2.N()) + 
assert.Equal(t, sk1.IsHighRankAccuracyMode(), sk2.IsHighRankAccuracyMode()) + assert.Equal(t, sk1.K(), sk2.K()) + assert.Equal(t, sk1.maxNomSize, sk2.maxNomSize) + assert.Equal(t, sk1.numLevels(), sk2.numLevels()) + assert.Equal(t, sk1.SerializedSizeBytes(), sk2.SerializedSizeBytes()) +} diff --git a/req/sketch_test.go b/req/sketch_test.go index 5bc111e..ff8db82 100644 --- a/req/sketch_test.go +++ b/req/sketch_test.go @@ -905,6 +905,52 @@ func TestSketchReset(t *testing.T) { }) } +func TestSketchSerializedSizeBytes(t *testing.T) { + t.Run("empty", func(t *testing.T) { + sk, err := NewSketch() + assert.NoError(t, err) + assert.Equal(t, 8, sk.SerializedSizeBytes()) + }) + + t.Run("raw items", func(t *testing.T) { + sk, err := NewSketch() + assert.NoError(t, err) + for i := 1; i <= 4; i++ { + assert.NoError(t, sk.Update(float32(i))) + } + assert.Equal(t, int64(4), sk.N()) + assert.Equal(t, 4*4+8, sk.SerializedSizeBytes()) + + b, err := sk.MarshalBinary() + assert.NoError(t, err) + assert.Equal(t, sk.SerializedSizeBytes(), len(b)) + }) + + t.Run("exact", func(t *testing.T) { + sk, err := NewSketch(WithK(20)) + assert.NoError(t, err) + for i := 1; i <= 10; i++ { + assert.NoError(t, sk.Update(float32(i))) + } + assert.False(t, sk.IsEstimationMode()) + assert.Greater(t, sk.SerializedSizeBytes(), 8) + + b, err := sk.MarshalBinary() + assert.NoError(t, err) + assert.Equal(t, sk.SerializedSizeBytes(), len(b)) + }) + + t.Run("estimation", func(t *testing.T) { + sk := loadSketch(t, 20, 1, 200, true, true) + assert.True(t, sk.IsEstimationMode()) + assert.Greater(t, sk.SerializedSizeBytes(), 24) + + b, err := sk.MarshalBinary() + assert.NoError(t, err) + assert.Equal(t, sk.SerializedSizeBytes(), len(b)) + }) +} + func TestSketchMinMaxItem(t *testing.T) { t.Run("sequential values", func(t *testing.T) { sk := loadSketch(t, 12, 1, 100, true, true)