This is an automated email from the ASF dual-hosted git repository.
pjfanning pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/pekko.git
The following commit(s) were added to refs/heads/main by this push:
new f636931e82 ByteString: fix bugs, type safety, and add performance
improvements with tests (#2868)
f636931e82 is described below
commit f636931e82039f767eac19caa5ba1cb0b6f65f20
Author: PJ Fanning <[email protected]>
AuthorDate: Sat Apr 18 17:04:53 2026 +0200
ByteString: fix bugs, type safety, and add performance improvements with
tests (#2868)
* Fix ByteString bugs, type safety, and add performance improvements with
tests
Agent-Logs-Url:
https://github.com/pjfanning/incubator-pekko/sessions/b3305192-a446-4ba8-96ff-9e758725198a
Co-authored-by: pjfanning <[email protected]>
* Address code review: restore Math.max(0, length) in ByteString1.apply for
defensive clarity
Agent-Logs-Url:
https://github.com/pjfanning/incubator-pekko/sessions/b3305192-a446-4ba8-96ff-9e758725198a
Co-authored-by: pjfanning <[email protected]>
* scalafmt
* Fix calculation of copyLength in writeToBuffer
---------
Co-authored-by: copilot-swe-agent[bot]
<[email protected]>
Co-authored-by: pjfanning <[email protected]>
---
.../org/apache/pekko/util/ByteStringSpec.scala | 155 +++++++++++++++++++++
.../scala/org/apache/pekko/util/ByteString.scala | 138 ++++++++++++++----
2 files changed, 264 insertions(+), 29 deletions(-)
diff --git
a/actor-tests/src/test/scala/org/apache/pekko/util/ByteStringSpec.scala
b/actor-tests/src/test/scala/org/apache/pekko/util/ByteStringSpec.scala
index acfbe83280..b9d0f2f28f 100644
--- a/actor-tests/src/test/scala/org/apache/pekko/util/ByteStringSpec.scala
+++ b/actor-tests/src/test/scala/org/apache/pekko/util/ByteStringSpec.scala
@@ -2091,6 +2091,161 @@ class ByteStringSpec extends AnyWordSpec with Matchers
with Checkers {
an[IndexOutOfBoundsException] should be thrownBy bss.readLongBE(-1)
an[IndexOutOfBoundsException] should be thrownBy bss.readLongLE(-1)
}
+
+ // takeRight(n) boundary cases, exercised on ByteString1C, ByteString1 and ByteStrings.
+ "have correct takeRight behaviour for boundary values" in {
+   // empty ByteString
+   ByteString.empty.takeRight(-1) should ===(ByteString.empty)
+   ByteString.empty.takeRight(0) should ===(ByteString.empty)
+   ByteString.empty.takeRight(1) should ===(ByteString.empty)
+   // n <= 0
+   ByteString1.fromString("abc").takeRight(-1) should ===(ByteString.empty)
+   ByteString1.fromString("abc").takeRight(0) should ===(ByteString.empty)
+   ByteString1C.fromString("abc").takeRight(-1) should ===(ByteString.empty)
+   ByteString1C.fromString("abc").takeRight(0) should ===(ByteString.empty)
+   val bssTR = ByteStrings(ByteString1.fromString("ab"), ByteString1.fromString("cd"))
+   bssTR.takeRight(-1) should ===(ByteString.empty)
+   bssTR.takeRight(0) should ===(ByteString.empty)
+   // n >= length
+   ByteString1.fromString("abc").takeRight(3) should ===(ByteString("abc"))
+   ByteString1.fromString("abc").takeRight(100) should ===(ByteString("abc"))
+   ByteString1C.fromString("abc").takeRight(3) should ===(ByteString("abc"))
+   ByteString1C.fromString("abc").takeRight(100) should ===(ByteString("abc"))
+   bssTR.takeRight(4) should ===(ByteString("abcd"))
+   bssTR.takeRight(100) should ===(ByteString("abcd"))
+   // n in range
+   ByteString1.fromString("abcde").takeRight(2) should ===(ByteString("de"))
+   ByteString1C.fromString("abcde").takeRight(2) should ===(ByteString("de"))
+   bssTR.takeRight(3) should ===(ByteString("bcd"))
+ }
+
+ // Indexing one position below and one position past the valid range must throw on every representation.
+ "throw IndexOutOfBoundsException when apply is called with an out-of-bounds index" in {
+   val bs1C = ByteString1C(Array[Byte](1, 2, 3))
+   an[IndexOutOfBoundsException] should be thrownBy bs1C(-1)
+   an[IndexOutOfBoundsException] should be thrownBy bs1C(3)
+   // ByteString1 view of length 3 starting at array index 1
+   val bs1 = ByteString1(Array[Byte](0, 1, 2, 3, 4), 1, 3)
+   an[IndexOutOfBoundsException] should be thrownBy bs1(-1)
+   an[IndexOutOfBoundsException] should be thrownBy bs1(3)
+   val bss = ByteStrings(ByteString1.fromString("ab"), ByteString1.fromString("cd"))
+   an[IndexOutOfBoundsException] should be thrownBy bss(-1)
+   an[IndexOutOfBoundsException] should be thrownBy bss(4)
+   an[IndexOutOfBoundsException] should be thrownBy ByteString.empty(0)
+ }
+
+ // With start at or beyond dest.length there is no room to copy into, so copyToArray must report 0.
+ "return 0 from copyToArray when start >= destination length" in {
+   val bs1C = ByteString1C(Array[Byte](1, 2, 3))
+   val dest1C = new Array[Byte](3)
+   bs1C.copyToArray(dest1C, 3, 2) should ===(0)
+   bs1C.copyToArray(dest1C, 5, 2) should ===(0)
+   val bs1 = ByteString1(Array[Byte](0, 1, 2, 3, 4), 1, 3)
+   val dest1 = new Array[Byte](3)
+   bs1.copyToArray(dest1, 3, 2) should ===(0)
+   bs1.copyToArray(dest1, 5, 2) should ===(0)
+   val bss = ByteStrings(ByteString1.fromString("ab"), ByteString1.fromString("cd"))
+   val destBss = new Array[Byte](4)
+   bss.copyToArray(destBss, 4, 2) should ===(0)
+   bss.copyToArray(destBss, 6, 2) should ===(0)
+ }
+
+ // A hint below the number of bytes already written must be a harmless no-op.
+ "correctly handle sizeHint with value smaller than already committed bytes" in {
+   val builder = ByteString.newBuilder
+   builder.putBytes(Array[Byte](1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+   // sizeHint smaller than already written should not throw or shrink
+   noException should be thrownBy builder.sizeHint(3)
+   noException should be thrownBy builder.sizeHint(0)
+   noException should be thrownBy builder.sizeHint(-1)
+   builder.result().length should ===(10)
+ }
+
+ // A group size at or above the total length yields one group holding everything; empty input yields no groups.
+ "correctly handle grouped with size larger than length" in {
+   val bs = ByteString1.fromString("abc")
+   bs.grouped(10).toList should ===(List(ByteString("abc")))
+   bs.grouped(3).toList should ===(List(ByteString("abc")))
+   ByteString.empty.grouped(5).toList should ===(List.empty)
+   val bss = ByteStrings(ByteString1.fromString("ab"), ByteString1.fromString("cd"))
+   bss.grouped(100).toList should ===(List(ByteString("abcd")))
+ }
+
+ // map must produce the same result whichever concrete representation holds the bytes.
+ "map bytes correctly on each concrete type" in {
+   val inc: Byte => Byte = b => (b + 1).toByte
+   // ByteString1C
+   ByteString1C(Array[Byte](1, 2, 3)).map(inc) should ===(ByteString(Array[Byte](2, 3, 4)))
+   // ByteString1 with offset
+   ByteString1(Array[Byte](0, 1, 2, 3, 4), 1, 3).map(inc) should ===(ByteString(Array[Byte](2, 3, 4)))
+   // ByteStrings ("abcd" shifted by one is "bcde")
+   val bss = ByteStrings(ByteString1.fromString("ab"), ByteString1.fromString("cd"))
+   bss.map(inc) should ===(ByteString("bcde"))
+   // empty
+   ByteString.empty.map(inc) should ===(ByteString.empty)
+ }
+
+ // foreach must hand every byte to the callback exactly once, in order, on each representation.
+ "foreach visits each byte in order on each concrete type" in {
+   // Records the bytes in the order foreach delivers them.
+   def collect(bs: ByteString): Seq[Byte] = {
+     val buf = scala.collection.mutable.ArrayBuffer.empty[Byte]
+     bs.foreach(buf += _)
+     buf.toSeq
+   }
+   // ByteString1C
+   collect(ByteString1C(Array[Byte](10, 20, 30))) should ===(Seq[Byte](10, 20, 30))
+   // ByteString1 with internal offset
+   collect(ByteString1(Array[Byte](0, 10, 20, 30, 40), 1, 3)) should ===(Seq[Byte](10, 20, 30))
+   // ByteStrings (multi-segment)
+   collect(ByteStrings(ByteString1.fromString("ab"), ByteString1.fromString("cd"))) should ===(
+     Seq[Byte]('a', 'b', 'c', 'd'))
+   // empty
+   collect(ByteString.empty) should ===(Seq.empty[Byte])
+ }
+
+ // slice must clamp both bounds, return the same instance for the full range, and honour the internal offset.
+ "slice returns correct result on ByteString1 with all boundary combinations" in {
+   val bs = ByteString1.fromString("abcde")
+   bs.slice(0, 5) should ===(ByteString("abcde"))
+   (bs.slice(0, 5) eq bs) should ===(true) // identity when full range
+   bs.slice(1, 4) should ===(ByteString("bcd"))
+   bs.slice(0, 0) should ===(ByteString.empty)
+   bs.slice(-5, 3) should ===(ByteString("abc"))
+   bs.slice(3, 100) should ===(ByteString("de"))
+   bs.slice(-5, -1) should ===(ByteString.empty)
+   bs.slice(4, 2) should ===(ByteString.empty) // from > until
+   // ByteString1 with internal offset
+   val bs2 = ByteString1(Array[Byte](0, 1, 2, 3, 4, 5, 6), 2, 4) // [2,3,4,5]
+   bs2.slice(1, 3) should ===(ByteString(Array[Byte](3, 4)))
+   bs2.slice(0, 4) should ===(ByteString(Array[Byte](2, 3, 4, 5)))
+   (bs2.slice(0, 4) eq bs2) should ===(true)
+   bs2.slice(-1, 2) should ===(ByteString(Array[Byte](2, 3)))
+ }
+
+ "indexOfSlice handles non-Byte typed Seq safely" in {
+   // Int is not a supertype of Byte: passing a Seq[Int] makes B widen to a common supertype of
+   // Byte and Int, so the elements are not Bytes. The lookup must not throw ClassCastException.
+   val bs = ByteString("abc")
+   // 'a'.toInt == 97: Scala's == compares a Byte and an Int by numeric value
+   val result = bs.indexOfSlice(Seq[Int](97, 98)) // 'a'=97, 'b'=98
+   result should ===(0)
+   bs.indexOfSlice(Seq[Int](99)) should ===(2) // 'c'=99
+   bs.indexOfSlice(Seq[Int](100)) should ===(-1) // 'd'=100 not present
+ }
+
+ "lastIndexOfSlice handles non-Byte typed Seq safely" in {
+   // The needles are deliberately typed Seq[Int] rather than Seq[Byte].
+   val haystack = ByteString("aabb")
+   val doubleA = Seq[Int](97, 97) // "aa"
+   val singleB = Seq[Int](98) // "b"
+   val singleD = Seq[Int](100) // "d", absent from the haystack
+   haystack.lastIndexOfSlice(doubleA) should ===(0)
+   haystack.lastIndexOfSlice(singleB) should ===(3)
+   haystack.lastIndexOfSlice(singleD) should ===(-1)
+ }
+
+ // Zero and negative lengths must all resolve to the shared ByteString1.empty instance.
+ "ByteString1.apply factory returns canonical empty for non-positive length" in {
+   (ByteString1(Array[Byte](1, 2, 3), 0, 0) should be).theSameInstanceAs(ByteString1.empty)
+   (ByteString1(Array[Byte](1, 2, 3), 0, -1) should be).theSameInstanceAs(ByteString1.empty)
+   (ByteString1(Array[Byte](1, 2, 3), 0, Int.MinValue) should be).theSameInstanceAs(ByteString1.empty)
+ }
+
+ "ByteStringBuilder.sizeHint does not shrink existing capacity" in {
+   val sut = ByteString.newBuilder
+   sut.sizeHint(100)
+   sut.putByte(42)
+   // Hints below the capacity requested above must neither throw nor drop the byte already written.
+   Seq(10, 0).foreach { smallerHint =>
+     noException should be thrownBy sut.sizeHint(smallerHint)
+   }
+   sut.result() should ===(ByteString(42.toByte))
+ }
}
"A ByteStringIterator" must {
diff --git a/actor/src/main/scala/org/apache/pekko/util/ByteString.scala
b/actor/src/main/scala/org/apache/pekko/util/ByteString.scala
index 9b24d88160..b6ba28d0a3 100644
--- a/actor/src/main/scala/org/apache/pekko/util/ByteString.scala
+++ b/actor/src/main/scala/org/apache/pekko/util/ByteString.scala
@@ -415,7 +415,7 @@ object ByteString {
/** INTERNAL API: Specialized for internal use, writing multiple
ByteString1C into the same ByteBuffer. */
private[pekko] def writeToBuffer(buffer: ByteBuffer, offset: Int): Int = {
- val copyLength = Math.min(buffer.remaining, offset + length)
+ val copyLength = Math.min(buffer.remaining, length - offset)
if (copyLength > 0) {
buffer.put(bytes, offset, copyLength)
}
@@ -428,7 +428,7 @@ object ByteString {
}
override def copyToArray[B >: Byte](dest: Array[B], start: Int, len: Int):
Int = {
- val toCopy = math.min(math.min(len, bytes.length), dest.length - start)
+ val toCopy = math.max(0, math.min(math.min(len, bytes.length),
dest.length - start))
if (toCopy > 0) {
Array.copy(bytes, 0, dest, start, toCopy)
}
@@ -439,6 +439,24 @@ object ByteString {
override def asInputStream: InputStream = new
UnsynchronizedByteArrayInputStream(bytes)
+ /** Applies `f` to each byte of the backing array in ascending index order, using a plain index loop. */
+ override def foreach[@specialized U](f: Byte => U): Unit = {
+   val limit = bytes.length
+   var pos = 0
+   while (pos < limit) {
+     f(bytes(pos))
+     pos += 1
+   }
+ }
+
+ /** Transforms every byte with `f` into a freshly allocated array and wraps that array as a `ByteString1C`. */
+ override def map[A](f: Byte => Byte): ByteString = {
+   val size = bytes.length
+   val mapped = new Array[Byte](size)
+   var pos = 0
+   while (pos < size) {
+     mapped(pos) = f(bytes(pos))
+     pos += 1
+   }
+   ByteString1C(mapped)
+ }
+
private[pekko] override def readShortBEUnchecked(offset: Int): Short =
SWARUtil.getShort(bytes, offset, ByteOrder.BIG_ENDIAN)
private[pekko] override def readShortLEUnchecked(offset: Int): Short =
@@ -459,7 +477,7 @@ object ByteString {
def fromString(s: String): ByteString1 =
apply(s.getBytes(StandardCharsets.UTF_8))
def apply(bytes: Array[Byte]): ByteString1 = apply(bytes, 0, bytes.length)
def apply(bytes: Array[Byte], startIndex: Int, length: Int): ByteString1 =
- if (length == 0) empty
+ if (length <= 0) empty
else new ByteString1(bytes, Math.max(0, startIndex), Math.max(0, length))
val SerializationIdentity = 0.toByte
@@ -523,8 +541,13 @@ object ByteString {
if (n >= length) this
else ByteString1(bytes, startIndex, n)
- override def slice(from: Int, until: Int): ByteString =
- drop(from).take(until - Math.max(0, from))
+ /**
+  * Returns the bytes in `[from, until)` after clamping both bounds to `[0, length]`.
+  * An empty or inverted range gives `ByteString.empty`; the full range gives `this`.
+  */
+ override def slice(from: Int, until: Int): ByteString = {
+   val begin = if (from < 0) 0 else from
+   val end = if (until > length) length else until
+   if (begin < end) {
+     // Full-range request: hand back this instance instead of allocating an equal one.
+     if (begin == 0 && end == length) this
+     else ByteString1(bytes, startIndex + begin, end - begin)
+   } else ByteString.empty
+ }
override def copyToBuffer(buffer: ByteBuffer): Int =
writeToBuffer(buffer)
@@ -767,7 +790,7 @@ object ByteString {
override def copyToArray[B >: Byte](dest: Array[B], start: Int, len: Int):
Int = {
// min of the bytes available to copy, bytes there is room for in dest
and the requested number of bytes
- val toCopy = math.min(math.min(len, length), dest.length - start)
+ val toCopy = math.max(0, math.min(math.min(len, length), dest.length -
start))
if (toCopy > 0) {
Array.copy(bytes, startIndex, dest, start, toCopy)
}
@@ -784,6 +807,25 @@ object ByteString {
override def asInputStream: InputStream =
new UnsynchronizedByteArrayInputStream(bytes, startIndex, length)
+ /** Applies `f` to the `length` bytes that begin at `startIndex` in the backing array, in order. */
+ override def foreach[@specialized U](f: Byte => U): Unit = {
+   var offset = 0
+   while (offset < length) {
+     f(bytes(startIndex + offset))
+     offset += 1
+   }
+ }
+
+ /** Transforms the `length` bytes starting at `startIndex` with `f` and returns them as a new `ByteString1C`. */
+ override def map[A](f: Byte => Byte): ByteString = {
+   val mapped = new Array[Byte](length)
+   var src = startIndex
+   var dst = 0
+   while (dst < mapped.length) {
+     mapped(dst) = f(bytes(src))
+     src += 1
+     dst += 1
+   }
+   ByteString1C(mapped)
+ }
+
private[pekko] override def readShortBEUnchecked(offset: Int): Short =
SWARUtil.getShort(bytes, startIndex + offset, ByteOrder.BIG_ENDIAN)
private[pekko] override def readShortLEUnchecked(offset: Int): Short =
@@ -1160,7 +1202,7 @@ object ByteString {
if (bytestrings.size == 1) bytestrings.head.copyToArray(dest, start, len)
else {
// min of the bytes available to copy, bytes there is room for in dest
and the requested number of bytes
- val totalToCopy = math.min(math.min(len, length), dest.length - start)
+ val totalToCopy = math.max(0, math.min(math.min(len, length),
dest.length - start))
if (totalToCopy > 0) {
val bsIterator = bytestrings.iterator
var copied = 0
@@ -1271,7 +1313,8 @@ sealed abstract class ByteString
override def takeWhile(p: Byte => Boolean): ByteString =
iterator.takeWhile(p).toByteString
override def dropWhile(p: Byte => Boolean): ByteString =
iterator.dropWhile(p).toByteString
override def span(p: Byte => Boolean): (ByteString, ByteString) = {
- val (a, b) = iterator.span(p); (a.toByteString, b.toByteString)
+ val (a, b) = iterator.span(p)
+ (a.toByteString, b.toByteString)
}
override def splitAt(n: Int): (ByteString, ByteString) = (take(n), drop(n))
@@ -1361,15 +1404,27 @@ sealed abstract class ByteString
val sliceLength = slice.length
if (sliceLength == 0) if (from > length) -1 else math.max(from, 0)
else {
- val headByte = slice.head.asInstanceOf[Byte]
- @tailrec def rec(from: Int): Int = {
- val startPos = indexOf(headByte, from, length - slice.length + 1)
- if (startPos == -1) -1
- else if (check(startPos)) startPos
- else rec(startPos + 1)
+ slice.head match {
+ case headByte: Byte =>
+ @tailrec def rec(from: Int): Int = {
+ val startPos = indexOf(headByte, from, length - sliceLength + 1)
+ if (startPos == -1) -1
+ else if (check(startPos)) startPos
+ else rec(startPos + 1)
+ }
+ if (sliceLength == 1) indexOf(headByte, from)
+ else rec(math.max(0, from))
+ case headElem =>
+ // Non-Byte head: use generic indexOf which handles any B >: Byte
via equality
+ @tailrec def rec(pos: Int): Int = {
+ val startPos = indexOf(headElem, pos)
+ if (startPos == -1 || startPos > length - sliceLength) -1
+ else if (check(startPos)) startPos
+ else rec(startPos + 1)
+ }
+ if (sliceLength == 1) indexOf(headElem, math.max(0, from))
+ else rec(math.max(0, from))
}
- if (sliceLength == 1) indexOf(headByte, from)
- else rec(math.max(0, from))
}
}
@@ -1428,7 +1483,6 @@ sealed abstract class ByteString
if (sliceLength == 0) if (end < 0) -1 else math.min(end, length)
else if (sliceLength > length) -1
else {
- val tailByte = slice(sliceLength - 1).asInstanceOf[Byte]
// Check all bytes of the slice except the last one (which was matched
by lastIndexOf)
def check(startPos: Int): Boolean = {
var i = startPos
@@ -1444,17 +1498,33 @@ sealed abstract class ByteString
// Cap end to the max valid slice start position to avoid Int overflow
when end is very large
val effectiveEnd = math.min(end, length - sliceLength)
val maxEndPos = effectiveEnd + sliceLength - 1
- @tailrec def rec(currEnd: Int): Int = {
- val endPos = lastIndexOf(tailByte, currEnd)
- if (endPos < sliceLength - 1) -1
- else {
- val startPos = endPos - sliceLength + 1
- if (check(startPos)) startPos
- else rec(endPos - 1)
- }
+ slice(sliceLength - 1) match {
+ case tailByte: Byte =>
+ @tailrec def rec(currEnd: Int): Int = {
+ val endPos = lastIndexOf(tailByte, currEnd)
+ if (endPos < sliceLength - 1) -1
+ else {
+ val startPos = endPos - sliceLength + 1
+ if (check(startPos)) startPos
+ else rec(endPos - 1)
+ }
+ }
+ if (sliceLength == 1) lastIndexOf(tailByte, effectiveEnd)
+ else rec(maxEndPos)
+ case tailElem =>
+ // Non-Byte tail: use generic lastIndexOf which handles any B >:
Byte via equality
+ @tailrec def rec(currEnd: Int): Int = {
+ val endPos = lastIndexOf(tailElem, currEnd)
+ if (endPos < sliceLength - 1) -1
+ else {
+ val startPos = endPos - sliceLength + 1
+ if (check(startPos)) startPos
+ else rec(endPos - 1)
+ }
+ }
+ if (sliceLength == 1) lastIndexOf(tailElem, effectiveEnd)
+ else rec(maxEndPos)
}
- if (sliceLength == 1) lastIndexOf(tailByte, effectiveEnd)
- else rec(maxEndPos)
}
}
@@ -1925,7 +1995,16 @@ sealed abstract class ByteString
(apply(offset + 6).toLong & 0xFF) << 48 |
(apply(offset + 7).toLong & 0xFF) << 56
- def map[A](f: Byte => Byte): ByteString = fromSpecific(super.map(f))
+ /**
+  * Returns a new ByteString in which every byte has been transformed by `f`, preserving order.
+  *
+  * The bytes are read in a single `iterator` pass instead of through `apply(i)`.
+  * NOTE(review): indexed lookup on a multi-fragment ByteString is expected to search for the
+  * owning fragment on every call, making an indexed loop cost O(length * fragments) — confirm
+  * against `ByteStrings.apply`.
+  *
+  * @param f transformation applied to each byte
+  * @return the transformed bytes, assembled with a builder pre-sized to `length`
+  */
+ def map[A](f: Byte => Byte): ByteString = {
+   val out = ByteString.newBuilder
+   out.sizeHint(length)
+   val src = iterator
+   while (src.hasNext) {
+     out += f(src.next())
+   }
+   out.result()
+ }
}
object CompactByteString {
@@ -2058,7 +2137,8 @@ final class ByteStringBuilder extends Builder[Byte,
ByteString] {
def length: Int = _length
override def sizeHint(len: Int): Unit = {
- resizeTemp(len - (_length - _tempLength))
+ val needed = len - (_length - _tempLength)
+ if (needed > _tempCapacity) resizeTemp(needed)
}
private def clearTemp(): Unit = {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]