This is an automated email from the ASF dual-hosted git repository. He-Pin pushed a commit to branch optimize-bytestring-two-part in repository https://gitbox.apache.org/repos/asf/pekko.git
commit 5c915d04249a4313cbe76d12d53463d1578de8aa Author: He-Pin <[email protected]> AuthorDate: Mon Apr 27 18:58:34 2026 +0800 feat: optimize two-part ByteString concatenation Motivation: Small ByteString concatenation currently promotes two fragments into ByteStrings backed by a Vector, which adds allocation and hurts hot paths that append a header and payload before reading or copying. Modification: Add an internal two-fragment ByteString representation for simple non-empty fragments, flatten it when appending beyond two fragments, and add focused JMH coverage for append, read, copy, and cross-boundary header reads. Result: JMH shows appendTwo allocation dropping from 160 B/op on main to 32 B/op on this branch, with improved read and copy throughput while ByteStringSpec and MiMa pass. --- .../scala/org/apache/pekko/util/ByteString.scala | 426 ++++++++++++++++++++- .../pekko/util/ByteString_append_Benchmark.scala | 48 +++ 2 files changed, 466 insertions(+), 8 deletions(-) diff --git a/actor/src/main/scala/org/apache/pekko/util/ByteString.scala b/actor/src/main/scala/org/apache/pekko/util/ByteString.scala index 0526d2855d..268ced723b 100644 --- a/actor/src/main/scala/org/apache/pekko/util/ByteString.scala +++ b/actor/src/main/scala/org/apache/pekko/util/ByteString.scala @@ -187,6 +187,8 @@ object ByteString { final class ByteString1C private (private val bytes: Array[Byte]) extends CompactByteString { def apply(idx: Int): Byte = bytes(idx) + private[pekko] override def byteAtUnchecked(offset: Int): Byte = bytes(offset) + override def length: Int = bytes.length // Avoid `iterator` in performance sensitive code, call ops directly on ByteString instead @@ -217,7 +219,13 @@ object ByteString { override def ++(that: ByteString): ByteString = { if (that.isEmpty) this else if (this.isEmpty) that - else toByteString1 ++ that + else + that match { + case b: ByteString1C => ByteString2.fromNonEmptySimpleFragments(this, b) + case b: ByteString1 => ByteString2.fromNonEmptySimpleFragments(this, b) + case bs: ByteString2 => ByteStrings(this, bs) + case bs: ByteStrings => ByteStrings(toByteString1, bs) + } } override def take(n: Int): ByteString = @@ -495,6 +503,8 @@ object ByteString { def apply(idx: Int): Byte = bytes(checkRangeConvert(idx)) + private[pekko] override def byteAtUnchecked(offset: Int): Byte = bytes(startIndex + offset) + // Avoid `iterator` in performance sensitive code, call ops directly on ByteString instead override def iterator: ByteIterator.ByteArrayIterator = ByteIterator.ByteArrayIterator(bytes, startIndex, startIndex + length) @@ -607,11 +617,12 @@ object ByteString { else if (this.isEmpty) that else that match { - case b: ByteString1C => ByteStrings(this, b.toByteString1) + case b: ByteString1C => ByteString2.fromNonEmptySimpleFragments(this, b) case b: ByteString1 => if ((bytes eq b.bytes) && (startIndex + length == b.startIndex)) new ByteString1(bytes, startIndex, length + b.length) - else ByteStrings(this, b) + else ByteString2.fromNonEmptySimpleFragments(this, b) + case bs: ByteString2 => ByteStrings(this, bs) case bs: ByteStrings => ByteStrings(this, bs) } } @@ -842,17 +853,30 @@ object ByteString { private[pekko] object ByteStrings extends Companion { def apply(bytestrings: Vector[ByteString1]): ByteString = - new ByteStrings(bytestrings, bytestrings.foldLeft(0)(_ + _.length)) - - def apply(bytestrings: Vector[ByteString1], length: Int): ByteString = new ByteStrings(bytestrings, length) + apply(bytestrings, bytestrings.foldLeft(0)(_ + _.length)) + + def apply(bytestrings: Vector[ByteString1], length: Int): ByteString = + bytestrings.length match { + case 0 => ByteString.empty + case 1 => bytestrings.head + case 2 => ByteString2(bytestrings(0), bytestrings(1)) + case _ => new ByteStrings(bytestrings, length) + } def apply(b1: ByteString1, b2: ByteString1): ByteString = compare(b1, b2) match { - case 3 => new ByteStrings(Vector(b1, b2), b1.length + b2.length) + case 3 => ByteString2(b1, b2) case 2 => b2 case 1 => b1 case 0 => ByteString.empty } + def apply(b: ByteString, bs: ByteString2): ByteString = compare(b, bs) match { + case 3 => new ByteStrings(addFragments(b, bs), b.length + bs.length) + case 2 => bs + case 1 => b + case 0 => ByteString.empty + } + def apply(b: ByteString1, bs: ByteStrings): ByteString = compare(b, bs) match { case 3 => new ByteStrings(b +: bs.bytestrings, bs.length + b.length) case 2 => bs @@ -860,6 +884,13 @@ object ByteString { case 0 => ByteString.empty } + def apply(bs: ByteString2, b: ByteString): ByteString = compare(bs, b) match { + case 3 => new ByteStrings(addFragments(bs, b), bs.length + b.length) + case 2 => b + case 1 => bs + case 0 => ByteString.empty + } + def apply(bs: ByteStrings, b: ByteString1): ByteString = compare(bs, b) match { case 3 => new ByteStrings(bs.bytestrings :+ b, bs.length + b.length) case 2 => b @@ -867,6 +898,27 @@ object ByteString { case 0 => ByteString.empty } + def apply(bs: ByteStrings, b: ByteString2): ByteString = compare(bs, b) match { + case 3 => new ByteStrings(addFragments(bs, b), bs.length + b.length) + case 2 => b + case 1 => bs + case 0 => ByteString.empty + } + + def apply(b: ByteString2, bs: ByteStrings): ByteString = compare(b, bs) match { + case 3 => new ByteStrings(addFragments(b, bs), b.length + bs.length) + case 2 => bs + case 1 => b + case 0 => ByteString.empty + } + + def apply(b1: ByteString2, b2: ByteString2): ByteString = compare(b1, b2) match { + case 3 => new ByteStrings(addFragments(b1, b2), b1.length + b2.length) + case 2 => b2 + case 1 => b1 + case 0 => ByteString.empty + } + def apply(bs1: ByteStrings, bs2: ByteStrings): ByteString = compare(bs1, bs2) match { case 3 => new ByteStrings(bs1.bytestrings ++ bs2.bytestrings, bs1.length + bs2.length) case 2 => bs2 @@ -881,6 +933,21 @@ object ByteString { else if (b2.isEmpty) 1 else 3 + private[ByteString] def addFragments(first: ByteString, second: ByteString): Vector[ByteString1] = { + val builder = new VectorBuilder[ByteString1] + addFragments(builder, first) + addFragments(builder, second) + builder.result() + } + + private[ByteString] def addFragments(builder: VectorBuilder[ByteString1], byteString: ByteString): Unit = + byteString match { + case b: ByteString1C => builder += b.toByteString1 + case b: ByteString1 => builder += b + case b: ByteString2 => b.addFragmentsTo(builder) + case b: ByteStrings => builder ++= b.bytestrings + } + val SerializationIdentity = 2.toByte def readFromInputStream(is: ObjectInputStream): ByteStrings = { @@ -902,6 +969,338 @@ object ByteString { } } + private[pekko] object ByteString2 { + private[ByteString] def fromNonEmptySimpleFragments(first: ByteString, second: ByteString): ByteString = + new ByteString2(first, second, first.length + second.length) + + def apply(first: ByteString, second: ByteString): ByteString = ByteStrings.compare(first, second) match { + case 3 => + if (isSimpleFragment(first) && isSimpleFragment(second)) + fromNonEmptySimpleFragments(first, second) + else + ByteStrings(ByteStrings.addFragments(first, second), first.length + second.length) + case 2 => second + case 1 => first + case 0 => ByteString.empty + } + + private def isSimpleFragment(byteString: ByteString): Boolean = + byteString match { + case _: ByteString1C | _: ByteString1 => true + case _: ByteString2 | _: ByteStrings => false + } + } + + /** + * A ByteString with exactly two simple fragments. + */ + private[pekko] final class ByteString2 private ( + private val first: ByteString, + private val second: ByteString, + val length: Int) + extends ByteString + with Serializable { + if (first.isEmpty) throw new IllegalArgumentException("first must not be empty") + if (second.isEmpty) throw new IllegalArgumentException("second must not be empty") + + private[this] def firstLength: Int = first.length + + def apply(idx: Int): Byte = + if (0 <= idx && idx < length) { + byteAtUnchecked(idx) + } else throw new IndexOutOfBoundsException(idx.toString) + + /** Avoid `iterator` in performance sensitive code, call ops directly on ByteString instead */ + override def iterator: ByteIterator.MultiByteArrayIterator = + ByteIterator.MultiByteArrayIterator(LazyList(byteArrayIterator(first), byteArrayIterator(second))) + + private def byteArrayIterator(byteString: ByteString): ByteIterator.ByteArrayIterator = + byteString match { + case b: ByteString1C => b.iterator + case b: ByteString1 => b.iterator + case _: ByteString2 | _: ByteStrings => + throw new IllegalStateException("ByteString2 fragments must be compact or sliced ByteStrings") + } + + def ++(that: ByteString): ByteString = { + if (that.isEmpty) this + else if (this.isEmpty) that + else + that match { + case b: ByteString1C => ByteStrings(this, b) + case b: ByteString1 => ByteStrings(this, b) + case bs: ByteString2 => ByteStrings(this, bs) + case bs: ByteStrings => ByteStrings(this, bs) + } + } + + private[pekko] def byteStringCompanion = ByteStrings + + def isCompact: Boolean = false + + override def copyToBuffer(buffer: ByteBuffer): Int = { + val written = first.copyToBuffer(buffer) + if (buffer.hasRemaining) written + second.copyToBuffer(buffer) + else written + } + + def compact: CompactByteString = { + val ar = new Array[Byte](length) + first.copyToArray(ar, 0, firstLength) + second.copyToArray(ar, firstLength, length - firstLength) + ByteString1C(ar) + } + + def asByteBuffer: ByteBuffer = compact.asByteBuffer + + def asByteBuffers: scala.collection.immutable.Iterable[ByteBuffer] = + List(first.asByteBuffer, second.asByteBuffer) + + override def asInputStream: InputStream = + new SequenceInputStream(Iterator(first.asInputStream, second.asInputStream).asJavaEnumeration) + + def decodeString(charset: String): String = compact.decodeString(charset) + + def decodeString(charset: Charset): String = compact.decodeString(charset) + + override def decodeBase64: ByteString = compact.decodeBase64 + + override def encodeBase64: ByteString = compact.encodeBase64 + + private[pekko] def writeToOutputStream(os: ObjectOutputStream): Unit = { + os.writeInt(2) + first.writeToOutputStream(os) + second.writeToOutputStream(os) + } + + override def take(n: Int): ByteString = + if (n <= 0) ByteString.empty + else if (n >= length) this + else { + if (n <= firstLength) first.take(n) + else ByteString2(first, second.take(n - firstLength)) + } + + override def dropRight(n: Int): ByteString = + if (0 < n && n < length) { + val secondLength = length - firstLength + if (n < secondLength) ByteString2(first, second.dropRight(n)) + else first.dropRight(n - secondLength) + } else if (n >= length) ByteString.empty + else this + + override def slice(from: Int, until: Int): ByteString = { + val lo = math.max(from, 0) + val hi = math.min(until, length) + if (lo >= hi) ByteString.empty + else if (lo == 0 && hi == length) this + else drop(lo).take(hi - lo) + } + + override def drop(n: Int): ByteString = + if (n <= 0) this + else if (n >= length) ByteString.empty + else { + if (n < firstLength) ByteString2(first.drop(n), second) + else second.drop(n - firstLength) + } + + override def indexOf[B >: Byte](elem: B, from: Int): Int = + if (from >= length) -1 + else { + val start = math.max(from, 0) + if (start < firstLength) { + val firstIndex = first.indexOf(elem, start) + if (firstIndex >= 0) firstIndex + else { + val secondIndex = second.indexOf(elem, 0) + if (secondIndex >= 0) firstLength + secondIndex else -1 + } + } else { + val secondIndex = second.indexOf(elem, start - firstLength) + if (secondIndex >= 0) firstLength + secondIndex else -1 + } + } + + override def indexOf(elem: Byte, from: Int): Int = + if (from >= length) -1 + else { + val start = math.max(from, 0) + if (start < firstLength) { + val firstIndex = first.indexOf(elem, start) + if (firstIndex >= 0) firstIndex + else { + val secondIndex = second.indexOf(elem, 0) + if (secondIndex >= 0) firstLength + secondIndex else -1 + } + } else { + val secondIndex = second.indexOf(elem, start - firstLength) + if (secondIndex >= 0) firstLength + secondIndex else -1 + } + } + + override def indexOf(elem: Byte, from: Int, to: Int): Int = { + val start = math.max(from, 0) + val end = math.min(to, length) + if (start >= end) -1 + else { + if (start < firstLength) { + val firstIndex = first.indexOf(elem, start, math.min(end, firstLength)) + if (firstIndex >= 0) firstIndex + else if (end > firstLength) { + val secondIndex = second.indexOf(elem, 0, end - firstLength) + if (secondIndex >= 0) firstLength + secondIndex else -1 + } else -1 + } else { + val secondIndex = second.indexOf(elem, start - firstLength, end - firstLength) + if (secondIndex >= 0) firstLength + secondIndex else -1 + } + } + } + + override def lastIndexOf[B >: Byte](elem: B, end: Int): Int = + if (end < 0) -1 + else { + val cappedEnd = math.min(end, length - 1) + if (cappedEnd >= firstLength) { + val secondIndex = second.lastIndexOf(elem, cappedEnd - firstLength) + if (secondIndex >= 0) firstLength + secondIndex + else first.lastIndexOf(elem, firstLength - 1) + } else first.lastIndexOf(elem, cappedEnd) + } + + override def lastIndexOf(elem: Byte, end: Int): Int = + if (end < 0) -1 + else { + val cappedEnd = math.min(end, length - 1) + if (cappedEnd >= firstLength) { + val secondIndex = second.lastIndexOf(elem, cappedEnd - firstLength) + if (secondIndex >= 0) firstLength + secondIndex + else first.lastIndexOf(elem, firstLength - 1) + } else first.lastIndexOf(elem, cappedEnd) + } + + override def copyToArray[B >: Byte](dest: Array[B], start: Int, len: Int): Int = { + val totalToCopy = math.max(0, math.min(math.min(len, length), dest.length - start)) + if (totalToCopy > 0) { + val firstCopied = first.copyToArray(dest, start, totalToCopy) + if (firstCopied < totalToCopy) + second.copyToArray(dest, start + firstCopied, totalToCopy - firstCopied) + } + totalToCopy + } + + override def foreach[@specialized U](f: Byte => U): Unit = { + first.foreach(f) + second.foreach(f) + } + + override def startsWith(bytes: Array[Byte], offset: Int): Boolean = { + val needleLen = bytes.length + if (length - offset < needleLen) false + else { + var i = offset + var j = 0 + while (j < needleLen) { + if (byteAt(i, firstLength) != bytes(j)) return false + i += 1 + j += 1 + } + true + } + } + + override def endsWith(bytes: Array[Byte]): Boolean = { + val needleLen = bytes.length + if (length < needleLen) false + else { + var i = length - needleLen + var j = 0 + while (j < needleLen) { + if (byteAt(i, firstLength) != bytes(j)) return false + i += 1 + j += 1 + } + true + } + } + + private[this] def byteAt(offset: Int, firstLength: Int): Byte = + if (offset < firstLength) first.byteAtUnchecked(offset) else second.byteAtUnchecked(offset - firstLength) + + private[pekko] override def byteAtUnchecked(offset: Int): Byte = { + val firstLength = this.firstLength + byteAt(offset, firstLength) + } + + private[pekko] override def readShortBEUnchecked(offset: Int): Short = { + if (offset + java.lang.Short.BYTES <= firstLength) first.readShortBEUnchecked(offset) + else if (offset >= firstLength) second.readShortBEUnchecked(offset - firstLength) + else ((byteAt(offset, firstLength) & 0xFF) << 8 | (byteAt(offset + 1, firstLength) & 0xFF)).toShort + } + + private[pekko] override def readShortLEUnchecked(offset: Int): Short = { + if (offset + java.lang.Short.BYTES <= firstLength) first.readShortLEUnchecked(offset) + else if (offset >= firstLength) second.readShortLEUnchecked(offset - firstLength) + else ((byteAt(offset, firstLength) & 0xFF) | (byteAt(offset + 1, firstLength) & 0xFF) << 8).toShort + } + + private[pekko] override def readIntBEUnchecked(offset: Int): Int = { + if (offset + java.lang.Integer.BYTES <= firstLength) first.readIntBEUnchecked(offset) + else if (offset >= firstLength) second.readIntBEUnchecked(offset - firstLength) + else + (byteAt(offset, firstLength) & 0xFF) << 24 | + (byteAt(offset + 1, firstLength) & 0xFF) << 16 | + (byteAt(offset + 2, firstLength) & 0xFF) << 8 | + (byteAt(offset + 3, firstLength) & 0xFF) + } + + private[pekko] override def readIntLEUnchecked(offset: Int): Int = { + if (offset + java.lang.Integer.BYTES <= firstLength) first.readIntLEUnchecked(offset) + else if (offset >= firstLength) second.readIntLEUnchecked(offset - firstLength) + else + (byteAt(offset, firstLength) & 0xFF) | + (byteAt(offset + 1, firstLength) & 0xFF) << 8 | + (byteAt(offset + 2, firstLength) & 0xFF) << 16 | + (byteAt(offset + 3, firstLength) & 0xFF) << 24 + } + + private[pekko] override def readLongBEUnchecked(offset: Int): Long = { + if (offset + java.lang.Long.BYTES <= firstLength) first.readLongBEUnchecked(offset) + else if (offset >= firstLength) second.readLongBEUnchecked(offset - firstLength) + else + (byteAt(offset, firstLength).toLong & 0xFFL) << 56 | + (byteAt(offset + 1, firstLength).toLong & 0xFFL) << 48 | + (byteAt(offset + 2, firstLength).toLong & 0xFFL) << 40 | + (byteAt(offset + 3, firstLength).toLong & 0xFFL) << 32 | + (byteAt(offset + 4, firstLength).toLong & 0xFFL) << 24 | + (byteAt(offset + 5, firstLength).toLong & 0xFFL) << 16 | + (byteAt(offset + 6, firstLength).toLong & 0xFFL) << 8 | + (byteAt(offset + 7, firstLength).toLong & 0xFFL) + } + + private[pekko] override def readLongLEUnchecked(offset: Int): Long = { + if (offset + java.lang.Long.BYTES <= firstLength) first.readLongLEUnchecked(offset) + else if (offset >= firstLength) second.readLongLEUnchecked(offset - firstLength) + else + (byteAt(offset, firstLength).toLong & 0xFFL) | + (byteAt(offset + 1, firstLength).toLong & 0xFFL) << 8 | + (byteAt(offset + 2, firstLength).toLong & 0xFFL) << 16 | + (byteAt(offset + 3, firstLength).toLong & 0xFFL) << 24 | + (byteAt(offset + 4, firstLength).toLong & 0xFFL) << 32 | + (byteAt(offset + 5, firstLength).toLong & 0xFFL) << 40 | + (byteAt(offset + 6, firstLength).toLong & 0xFFL) << 48 | + (byteAt(offset + 7, firstLength).toLong & 0xFFL) << 56 + } + + private[pekko] def addFragmentsTo(builder: VectorBuilder[ByteString1]): Unit = { + ByteStrings.addFragments(builder, first) + ByteStrings.addFragments(builder, second) + } + + protected def writeReplace(): AnyRef = new SerializationProxy(this) + } + /** * A ByteString with 2 or more fragments. */ @@ -936,6 +1335,7 @@ object ByteString { that match { case b: ByteString1C => ByteStrings(this, b.toByteString1) case b: ByteString1 => ByteStrings(this, b) + case bs: ByteString2 => ByteStrings(this, bs) case bs: ByteStrings => ByteStrings(this, bs) } } @@ -1928,6 +2328,13 @@ sealed abstract class ByteString readLongLEUnchecked(offset) } + /** + * INTERNAL API + * Fast byte access for callers that have already checked bounds. + */ + private[pekko] def byteAtUnchecked(offset: Int): Byte = + apply(offset) + /** * INTERNAL API * Optimized in subclasses when we have byte arrays where we can use {@link SWARUtil} @@ -2110,7 +2517,7 @@ sealed abstract class CompactByteString extends ByteString with Serializable { final class ByteStringBuilder extends Builder[Byte, ByteString] { builder => - import ByteString.{ ByteString1, ByteString1C, ByteStrings } + import ByteString.{ ByteString1, ByteString1C, ByteString2, ByteStrings } private var _length: Int = 0 private val _builder: VectorBuilder[ByteString1] = new VectorBuilder[ByteString1]() private var _temp: Array[Byte] = _ @@ -2190,6 +2597,9 @@ final class ByteStringBuilder extends Builder[Byte, ByteString] { case b: ByteString1 => _builder += b _length += b.length + case b: ByteString2 => + b.addFragmentsTo(_builder) + _length += b.length case bs: ByteStrings => _builder ++= bs.bytestrings _length += bs.length diff --git a/bench-jmh/src/main/scala/org/apache/pekko/util/ByteString_append_Benchmark.scala b/bench-jmh/src/main/scala/org/apache/pekko/util/ByteString_append_Benchmark.scala index 23e2327271..69262bc48d 100644 --- a/bench-jmh/src/main/scala/org/apache/pekko/util/ByteString_append_Benchmark.scala +++ b/bench-jmh/src/main/scala/org/apache/pekko/util/ByteString_append_Benchmark.scala @@ -13,6 +13,7 @@ package org.apache.pekko.util +import java.nio.ByteBuffer import java.util.concurrent.TimeUnit import org.openjdk.jmh.annotations._ @@ -27,6 +28,15 @@ import org.openjdk.jmh.infra.Blackhole class ByteString_append_Benchmark { private val bs = ByteString(Array.ofDim[Byte](10)) + private val header = ByteString.fromArrayUnsafe(Array.tabulate[Byte](9)(i => (i + 1).toByte)) + private val payload = ByteString.fromArrayUnsafe(Array.fill[Byte](4096)(42)) + private val frame = header ++ payload + private val shortHeaderPrefix = ByteString.fromArrayUnsafe(Array[Byte](1, 2)) + private val headerTailAndPayload = ByteString.fromArrayUnsafe(Array.tabulate[Byte](4096 + 7)(i => (i + 3).toByte)) + private val crossBoundaryFrame = shortHeaderPrefix ++ headerTailAndPayload + private val compactFrame = frame.compact + private val outputBuffer = ByteBuffer.allocateDirect(9 + 4096) + private val outputArray = new Array[Byte](9 + 4096) @Benchmark @OperationsPerInvocation(10000) @@ -55,6 +65,44 @@ class ByteString_append_Benchmark { } bh.consume(result) } + + @Benchmark + def appendTwo(): ByteString = + header ++ payload + + @Benchmark + def readTwoPartHeader(): Int = + frame.readIntBE(0) + frame.readIntBE(4) + frame(8) + + @Benchmark + def readTwoPartCrossBoundaryHeader(): Int = + crossBoundaryFrame.readIntBE(0) + crossBoundaryFrame.readIntBE(2) + crossBoundaryFrame(8) + + @Benchmark + def readCompactHeader(): Int = + compactFrame.readIntBE(0) + compactFrame.readIntBE(4) + compactFrame(8) + + @Benchmark + def appendTwoAndReadHeader(): Int = { + val frame = header ++ payload + frame.readIntBE(0) + frame.readIntBE(4) + frame(8) + } + + @Benchmark + def appendTwoAndReadCrossBoundaryHeader(): Int = { + val frame = shortHeaderPrefix ++ headerTailAndPayload + frame.readIntBE(0) + frame.readIntBE(2) + frame(8) + } + + @Benchmark + def appendTwoAndCopyToBuffer(): Int = { + outputBuffer.clear() + (header ++ payload).copyToBuffer(outputBuffer) + } + + @Benchmark + def appendTwoAndCopyToArray(): Int = (header ++ payload).copyToArray(outputArray, 0, outputArray.length) + @Benchmark @OperationsPerInvocation(10000) def builderOne(bh: Blackhole): Unit = { --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
