worryg0d commented on code in PR #1946:
URL:
https://github.com/apache/cassandra-gocql-driver/pull/1946#discussion_r3288013489
##########
segment_codec.go:
##########
@@ -0,0 +1,277 @@
+// segment_codec.go
+
+package gocql
+
+import (
+ "encoding/binary"
+ "fmt"
+ "io"
+)
+
+const (
+ maxSegmentPayloadSize = 1<<17 - 1
+
+ compressedHeaderSize = 5 + crc24Size
+ uncompressedHeaderSize = 3 + crc24Size
+
+ crc24Size = 3
+ crc32Size = 4
+)
+
+// segmentHeader represents the header information of a segment.
+type segmentHeader struct {
+ // payload length is the length of the segment payload
+ payloadLength int
+ // uncompressedPayloadLength is the length of the uncompressed payload (only for compressed segments)
+ uncompressedPayloadLength int
+ // indicates whether the segment contains only completed frames
+ isSelfContained bool
+}
+
+func (segment *segmentHeader) String() string {
+ return fmt.Sprintf("segmentHeader(len=%d, uncompressedLen=%d,
isSelfContained=%v)",
+ segment.payloadLength,
+ segment.uncompressedPayloadLength,
+ segment.isSelfContained)
+}
+
+type segmentCodec struct {
+ compressor Compressor
+ compressed bool
+}
+
+func newSegmentCodec(compressor Compressor) segmentCodec {
+ return segmentCodec{
+ compressed: compressor != nil,
+ compressor: compressor,
+ }
+}
+
+func (sc *segmentCodec) encode(payload []byte, isSelfContained bool) ([]byte,
error) {
+ if len(payload) > maxSegmentPayloadSize {
+ return nil, fmt.Errorf("gocql: payload length (%d) exceeds
maximum segment size of %d", len(payload), maxSegmentPayloadSize)
+ }
+
+ if sc.compressed {
+ return sc.encodeCompressedSegment(payload, isSelfContained)
+ }
+ return sc.encodeUncompressedSegment(payload, isSelfContained)
+}
+
+func (sc *segmentCodec) encodeCompressedSegment(payload []byte,
isSelfContained bool) ([]byte, error) {
+ uncompressedLen := len(payload)
+
+ compressed, err := sc.compressor.AppendCompressed(nil, payload)
+ if err != nil {
+ return nil, err
+ }
+
+ compressedLen := len(compressed)
+
+ // If compression is not worth it, we should send uncompressed data
+ // following the next logic:
+ if uncompressedLen < compressedLen {
+ compressed = payload
+ compressedLen = uncompressedLen
+ uncompressedLen = 0
+ }
+
+ segmentBuf := make([]byte, compressedHeaderSize+compressedLen+crc32Size)
+
+ sc.encodeCompressedSegmentHeader(compressedLen, uncompressedLen,
isSelfContained, segmentBuf)
+ sc.encodePayloadAndChecksum(compressed,
segmentBuf[compressedHeaderSize:])
+
+ return segmentBuf, nil
+}
+
+// encodeCompressedSegmentHeader encodes the compressed segment header into the provided destination slice.
+// It assumes that dest has enough space to hold the header.
+func (sc *segmentCodec) encodeCompressedSegmentHeader(compressedLen,
uncompressedLen int, isSelfContained bool, dest []byte) {
+ combined := uint64(compressedLen) | uint64(uncompressedLen)<<17
+ if isSelfContained {
+ combined |= 1 << 34
+ }
+
+ binary.LittleEndian.PutUint64(dest[:], combined)
+
+ headerCRC24 := Crc24(dest[:5])
+ dest[5] = byte(headerCRC24)
+ dest[6] = byte(headerCRC24 >> 8)
+ dest[7] = byte(headerCRC24 >> 16)
+}
+
+func (sc *segmentCodec) encodeUncompressedSegment(payload []byte,
isSelfContained bool) ([]byte, error) {
+ payloadLen := len(payload)
+
+ segmentBuf := make([]byte, uncompressedHeaderSize+payloadLen+crc32Size)
+
+ sc.encodeUncompressedSegmentHeader(payloadLen, isSelfContained,
segmentBuf)
+ sc.encodePayloadAndChecksum(payload,
segmentBuf[uncompressedHeaderSize:])
+
+ return segmentBuf, nil
+}
+
+// encodeUncompressedSegmentHeader encodes the uncompressed segment header into the provided destination slice.
+// It assumes that dest has enough space to hold the header.
+func (sc *segmentCodec) encodeUncompressedSegmentHeader(payloadLen int,
isSelfContained bool, dest []byte) {
+ headerInt := uint32(payloadLen)
+ if isSelfContained {
+ headerInt |= 1 << 17
+ }
+
+ dest[0] = byte(headerInt)
+ dest[1] = byte(headerInt >> 8)
+ dest[2] = byte(headerInt >> 16)
+
+ crc := Crc24(dest[:3])
+ dest[3] = byte(crc)
+ dest[4] = byte(crc >> 8)
+ dest[5] = byte(crc >> 16)
+}
+
+// encodePayloadAndChecksum encodes the payload and its CRC32 checksum into the provided destination slice.
+// It assumes that dest has enough space to hold the payload and checksum.
+// Starting from dest[0], it writes the payload followed by its CRC32 checksum.
+func (sc *segmentCodec) encodePayloadAndChecksum(payload []byte, dest []byte) {
+ payloadCRC32 := Crc32(payload)
+ copy(dest, payload)
+ binary.LittleEndian.PutUint32(dest[len(payload):], payloadCRC32)
+}
+
+func (sc *segmentCodec) decode(r io.Reader) ([]byte, bool, error) {
+ if sc.compressed {
+ return sc.decodeCompressedSegment(r)
+ }
+ return sc.decodeUncompressedSegment(r)
+}
+
+func (sc *segmentCodec) decodeCompressedSegment(r io.Reader) ([]byte, bool,
error) {
+ header, err := sc.decodeCompressedSegmentHeader(r)
+ if err != nil {
+ return nil, false, fmt.Errorf("gocql: failed to read compressed
segment header, err: %w", err)
+ }
+
+ compressedPayload, err := sc.decodePayload(r, header)
+ if err != nil {
+ return nil, false, fmt.Errorf("gocql: failed to read compressed
segment payload, err: %w", err)
+ }
+
+ var uncompressedPayload []byte
+ if header.uncompressedPayloadLength > 0 {
+ uncompressedPayload, err =
sc.compressor.AppendDecompressed(nil, compressedPayload,
uint32(header.uncompressedPayloadLength))
+ if err != nil {
+ return nil, false, err
+ }
+ // Verify that the decompressed length matches the expected length
+ if uint32(len(uncompressedPayload)) !=
uint32(header.uncompressedPayloadLength) {
+ return nil, false, fmt.Errorf("gocql: length mismatch
after payload decompressing, got %d, expected %d", len(uncompressedPayload),
header.uncompressedPayloadLength)
+ }
+ } else {
+ // in case when the segment was not compressed because compression was not worth it
+ uncompressedPayload = compressedPayload
+ }
+
+ return uncompressedPayload, header.isSelfContained, nil
+}
+
+func (sc *segmentCodec) decodeUncompressedSegment(r io.Reader) ([]byte, bool,
error) {
+ header, err := sc.decodeUncompressedSegmentHeader(r)
+ if err != nil {
+ return nil, false, fmt.Errorf("gocql: failed to read
uncompressed segment header, err: %w", err)
+ }
+
+ payload, err := sc.decodePayload(r, header)
+ if err != nil {
+ return nil, false, fmt.Errorf("gocql: failed to read
uncompressed segment payload, err: %w", err)
+ }
+
+ return payload, header.isSelfContained, nil
+}
+
+// verifySegmentHeaderChecksum verifies the CRC24 checksum of the segment header.
+func (sc *segmentCodec) verifySegmentHeaderChecksum(data []byte, expected
uint32) error {
+ computed := Crc24(data)
+ if computed != expected {
+ return fmt.Errorf("gocql: crc24 mismatch in segment header:
expected %d, got %d", expected, computed)
+ }
+ return nil
+}
+
+// verifySegmentPayloadChecksum verifies the CRC32 checksum of the segment payload.
+func (sc *segmentCodec) verifySegmentPayloadChecksum(data []byte, expected
uint32) error {
+ computed := Crc32(data)
+ if computed != expected {
+ return fmt.Errorf("gocql: payload crc32 mismatch in segment
payload: expected %d, got %d", expected, computed)
+ }
+ return nil
+}
+
+// decodeCompressedSegmentHeader reads and verifies the header of a compressed segment from the given reader.
+func (sc *segmentCodec) decodeCompressedSegmentHeader(r io.Reader)
(*segmentHeader, error) {
+ var headerBuf [8]byte // TODO: potentially optimize allocation, could be stored in segmentCodec and reused if the codec is a specific for each Conn
Review Comment:
Added reusable buffers for headers and payload checksums
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]