This is an automated email from the ASF dual-hosted git repository.
nicholasjiang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new 42f1a08c1 [CELEBORN-2284] Fix TLS Memory Leak
42f1a08c1 is described below
commit 42f1a08c1bcf0ba117d05d91a766d1214e5b881f
Author: Aravind Patnam <[email protected]>
AuthorDate: Mon Mar 30 14:20:21 2026 +0800
[CELEBORN-2284] Fix TLS Memory Leak
### What changes were proposed in this pull request?
While running jobs with TLS enabled, we encountered memory leaks that cause
worker OOMs.
```
26/02/13 21:02:52,779 ERROR [push-server-9-9] ResourceLeakDetector: LEAK:
ByteBuf.release() was not called before it's garbage-collected. See
https://netty.io/wiki/reference-counted-objects.html for more information.
Recent access records:
Created at:
io.netty.buffer.AbstractByteBufAllocator.compositeDirectBuffer(AbstractByteBufAllocator.java:224)
io.netty.buffer.AbstractByteBufAllocator.compositeBuffer(AbstractByteBufAllocator.java:202)
org.apache.celeborn.common.network.util.TransportFrameDecoder.decodeNext(TransportFrameDecoder.java:143)
org.apache.celeborn.common.network.util.TransportFrameDecoder.channelRead(TransportFrameDecoder.java:66)
io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:444)
io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:420)
io.netty.channel.AbstractChannelHandlerContext.fireChannelRead(AbstractChannelHandlerContext.java:412)
io.netty.handler.ssl.SslHandler.unwrap(SslHandler.java:1475)
io.netty.handler.ssl.SslHandler.decodeJdkCompatible(SslHandler.java:1338)
io.netty.handler.ssl.SslHandler.decode(SslHandler.java:1387)
io.netty.handler.codec.ByteToMessageDecoder.decodeRemovalReentryProtection(ByteToMessageDecoder.java:529)
io.netty.handler.codec.ByteToMessageDecoder.callDecode(ByteToMessageDecoder.java:468)
io.netty.handler.codec.ByteToMessageDecoder.channelRead(ByteToMessageDecoder.java:290)
io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:444)
io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:420)
io.netty.channel.AbstractChannelHandlerContext.fireChannelRead(AbstractChannelHandlerContext.java:412)
io.netty.channel.DefaultChannelPipeline$HeadContext.channelRead(DefaultChannelPipeline.java:1410)
io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:440)
io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:420)
io.netty.channel.DefaultChannelPipeline.fireChannelRead(DefaultChannelPipeline.java:919)
io.netty.channel.nio.AbstractNioByteChannel$NioByteUnsafe.read(AbstractNioByteChannel.java:166)
io.netty.channel.nio.NioEventLoop.processSelectedKey(NioEventLoop.java:788)
io.netty.channel.nio.NioEventLoop.processSelectedKeysOptimized(NioEventLoop.java:724)
io.netty.channel.nio.NioEventLoop.processSelectedKeys(NioEventLoop.java:650)
io.netty.channel.nio.NioEventLoop.run(NioEventLoop.java:562)
io.netty.util.concurrent.SingleThreadEventExecutor$4.run(SingleThreadEventExecutor.java:997)
io.netty.util.internal.ThreadExecutorMap$2.run(ThreadExecutorMap.java:74)
io.netty.util.concurrent.FastThreadLocalRunnable.run(FastThreadLocalRunnable.java:30)
java.base/java.lang.Thread.run(Thread.java:840)
```
When a Celeborn worker receives a PushData or PushMergedData message, it
replicates that frame to a secondary worker for fault tolerance. On an
SSL-enabled cluster this replication goes through SslMessageEncoder.encode().
Here is what happens inside SslMessageEncoder.encode():
- The encoder asks the message body for an SSL-friendly copy by calling
convertToNettyForSsl(). For shuffle data, the body is a NettyManagedBuffer —
data already loaded in off-heap memory. This call runs
buf.duplicate().retain(), which creates a second reference to the same memory
and increments the reference count from 1 to 2.
- The encoder places this second reference inside a composite buffer and
hands it to Netty for writing.
- Netty writes the composite to the network, then releases it —
decrementing the count from 2 to 1.
- Nothing releases the original NettyManagedBuffer's hold on the data, so
the count stays at 1 forever.
- This results in every replicated PushData frame leaking a chunk of
off-heap memory, eventually causing OOM and worker crash.
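The retain/release arithmetic above can be sketched with a toy reference counter. This is plain Java with illustrative names (`RefCountedBuf` is a stand-in, not the real Netty `ByteBuf` API):

```java
// Toy stand-in for a reference-counted buffer such as Netty's ByteBuf.
// Class and method names are illustrative, not the real Netty API.
class RefCountedBuf {
    private int refCnt = 1; // a freshly allocated buffer starts at 1

    RefCountedBuf retain() {
        refCnt++;
        return this;
    }

    // Returns true only when the count reaches zero and the memory is freed.
    boolean release() {
        return --refCnt == 0;
    }

    int refCnt() {
        return refCnt;
    }
}

public class SslLeakSketch {
    public static void main(String[] args) {
        RefCountedBuf body = new RefCountedBuf(); // NettyManagedBuffer's hold: count = 1
        RefCountedBuf copy = body.retain();       // convertToNettyForSsl(): count = 2
        boolean freed = copy.release();           // Netty writes and releases: count = 1
        System.out.println(freed);                // prints "false"
        System.out.println(body.refCnt());        // prints "1": nothing ever frees it, so it leaks
    }
}
```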
The fix is to release the original message body, so that encoding leaves the
net reference count unchanged. The second reference — now living inside the
composite buffer in out — keeps the memory alive while Netty writes it to the
network. When Netty finishes and releases the composite, the count reaches 0
and the memory is freed cleanly.
This is exactly what the non-SSL MessageEncoder already does via
MessageWithHeader.deallocate() — the SSL path simply needed to replicate that
behavior explicitly.
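The fixed arithmetic can be sketched with a toy reference counter. Again plain Java with illustrative names, not the real Netty API:

```java
// Toy stand-in for a reference-counted buffer such as Netty's ByteBuf.
// Class and method names are illustrative, not the real Netty API.
class RefCountedBuf {
    private int refCnt = 1;

    RefCountedBuf retain() {
        refCnt++;
        return this;
    }

    // Returns true only when the count reaches zero and the memory is freed.
    boolean release() {
        return --refCnt == 0;
    }

    int refCnt() {
        return refCnt;
    }
}

public class SslFixSketch {
    public static void main(String[] args) {
        RefCountedBuf body = new RefCountedBuf(); // count = 1
        RefCountedBuf copy = body.retain();       // convertToNettyForSsl(): count = 2
        copy.release();                           // Netty writes the composite and releases it: count = 1
        boolean freed = body.release();           // the fix: drop the original hold too: count = 0
        System.out.println(freed);                // prints "true": memory freed cleanly, no leak
    }
}
```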
### Why are the changes needed?
Fixes a memory leak that causes worker OOMs when TLS is enabled.
### Does this PR resolve a correctness bug?
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
Already tested and running internally in production.
Also added unit tests.
Closes #3630 from akpatnam25/CELEBORN-2284.
Authored-by: Aravind Patnam <[email protected]>
Signed-off-by: SteNicholas <[email protected]>
---
.../protocol/EncryptedMessageWithHeader.java | 82 ++++---
.../common/network/protocol/SslMessageEncoder.java | 15 +-
.../protocol/EncryptedMessageWithHeaderSuiteJ.java | 45 +++-
.../network/protocol/SslMessageEncoderSuiteJ.java | 85 +++++++
.../cluster/SslClusterReadWriteLeakSuite.scala | 253 +++++++++++++++++++++
5 files changed, 423 insertions(+), 57 deletions(-)
diff --git a/common/src/main/java/org/apache/celeborn/common/network/protocol/EncryptedMessageWithHeader.java b/common/src/main/java/org/apache/celeborn/common/network/protocol/EncryptedMessageWithHeader.java
index df2ab1a92..38777f5f8 100644
--- a/common/src/main/java/org/apache/celeborn/common/network/protocol/EncryptedMessageWithHeader.java
+++ b/common/src/main/java/org/apache/celeborn/common/network/protocol/EncryptedMessageWithHeader.java
@@ -25,17 +25,19 @@ import javax.annotation.Nullable;
import com.google.common.base.Preconditions;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
+import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.stream.ChunkedInput;
import io.netty.handler.stream.ChunkedStream;
+import io.netty.util.ReferenceCountUtil;
import org.apache.celeborn.common.network.buffer.ManagedBuffer;
/**
* A wrapper message that holds two separate pieces (a header and a body).
*
- * <p>The header must be a ByteBuf, while the body can be any InputStream or ChunkedStream Based on
- * common/network-common/org.apache.spark.network.protocol.EncryptedMessageWithHeader
+ * <p>The header must be a ByteBuf, while the body can be a ByteBuf, InputStream, or ChunkedStream.
+ * Based on common/network-common/org.apache.spark.network.protocol.EncryptedMessageWithHeader
*/
public class EncryptedMessageWithHeader implements ChunkedInput<ByteBuf> {
@@ -61,8 +63,8 @@ public class EncryptedMessageWithHeader implements ChunkedInput<ByteBuf> {
public EncryptedMessageWithHeader(
@Nullable ManagedBuffer managedBuffer, ByteBuf header, Object body, long bodyLength) {
Preconditions.checkArgument(
- body instanceof InputStream || body instanceof ChunkedStream,
- "Body must be an InputStream or a ChunkedStream.");
+ body instanceof ByteBuf || body instanceof InputStream || body instanceof ChunkedStream,
+ "Body must be a ByteBuf, an InputStream, or a ChunkedStream.");
this.managedBuffer = managedBuffer;
this.header = header;
this.headerLength = header.readableBytes();
@@ -82,40 +84,47 @@ public class EncryptedMessageWithHeader implements ChunkedInput<ByteBuf> {
return null;
}
- if (totalBytesTransferred < headerLength) {
- totalBytesTransferred += headerLength;
- return header.retain();
- } else if (body instanceof InputStream) {
- InputStream stream = (InputStream) body;
- int available = stream.available();
- if (available <= 0) {
- available = (int) (length() - totalBytesTransferred);
- } else {
- available = (int) Math.min(available, length() - totalBytesTransferred);
- }
- ByteBuf buffer = allocator.buffer(available);
- int toRead = Math.min(available, buffer.writableBytes());
- int read = buffer.writeBytes(stream, toRead);
- if (read >= 0) {
- totalBytesTransferred += read;
- return buffer;
- } else {
- throw new EOFException("Unable to read bytes from InputStream");
- }
- } else if (body instanceof ChunkedStream) {
- ChunkedStream stream = (ChunkedStream) body;
- long old = stream.transferredBytes();
- ByteBuf buffer = stream.readChunk(allocator);
- long read = stream.transferredBytes() - old;
- if (read >= 0) {
- totalBytesTransferred += read;
- assert (totalBytesTransferred <= length());
- return buffer;
+ if (body instanceof ByteBuf) {
+ // For ByteBuf bodies, return header + body as a single composite buffer.
+ ByteBuf bodyBuf = (ByteBuf) body;
+ totalBytesTransferred = headerLength + bodyLength;
+ return Unpooled.wrappedBuffer(header.retain(), bodyBuf.retain());
+ } else {
+ if (totalBytesTransferred < headerLength) {
+ totalBytesTransferred += headerLength;
+ return header.retain();
+ } else if (body instanceof InputStream) {
+ InputStream stream = (InputStream) body;
+ int available = stream.available();
+ if (available <= 0) {
+ available = (int) (length() - totalBytesTransferred);
+ } else {
+ available = (int) Math.min(available, length() - totalBytesTransferred);
+ }
+ ByteBuf buffer = allocator.buffer(available);
+ int toRead = Math.min(available, buffer.writableBytes());
+ int read = buffer.writeBytes(stream, toRead);
+ if (read >= 0) {
+ totalBytesTransferred += read;
+ return buffer;
+ } else {
+ throw new EOFException("Unable to read bytes from InputStream");
+ }
+ } else if (body instanceof ChunkedStream) {
+ ChunkedStream stream = (ChunkedStream) body;
+ long old = stream.transferredBytes();
+ ByteBuf buffer = stream.readChunk(allocator);
+ long read = stream.transferredBytes() - old;
+ if (read >= 0) {
+ totalBytesTransferred += read;
+ assert (totalBytesTransferred <= length());
+ return buffer;
+ } else {
+ throw new EOFException("Unable to read bytes from ChunkedStream");
+ }
} else {
- throw new EOFException("Unable to read bytes from ChunkedStream");
+ return null;
}
- } else {
- return null;
}
}
@@ -137,6 +146,7 @@ public class EncryptedMessageWithHeader implements ChunkedInput<ByteBuf> {
@Override
public void close() throws Exception {
header.release();
+ ReferenceCountUtil.release(body);
if (managedBuffer != null) {
managedBuffer.release();
}
diff --git a/common/src/main/java/org/apache/celeborn/common/network/protocol/SslMessageEncoder.java b/common/src/main/java/org/apache/celeborn/common/network/protocol/SslMessageEncoder.java
index 508b6a13d..cb3f0ed5c 100644
--- a/common/src/main/java/org/apache/celeborn/common/network/protocol/SslMessageEncoder.java
+++ b/common/src/main/java/org/apache/celeborn/common/network/protocol/SslMessageEncoder.java
@@ -17,15 +17,12 @@
package org.apache.celeborn.common.network.protocol;
-import java.io.InputStream;
import java.util.List;
import io.netty.buffer.ByteBuf;
-import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.MessageToMessageEncoder;
-import io.netty.handler.stream.ChunkedStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -89,15 +86,9 @@ public final class SslMessageEncoder extends MessageToMessageEncoder<Message> {
assert header.writableBytes() == 0;
if (body != null && bodyLength > 0) {
- if (body instanceof ByteBuf) {
- out.add(Unpooled.wrappedBuffer(header, (ByteBuf) body));
- } else if (body instanceof InputStream || body instanceof ChunkedStream) {
- // For now, assume the InputStream is doing proper chunking.
- out.add(new EncryptedMessageWithHeader(in.body(), header, body, bodyLength));
- } else {
- throw new IllegalArgumentException(
- "Body must be a ByteBuf, ChunkedStream or an InputStream");
- }
+ // We transfer ownership of the reference on in.body() to EncryptedMessageWithHeader.
+ // This reference will be freed when EncryptedMessageWithHeader.close() is called.
+ out.add(new EncryptedMessageWithHeader(in.body(), header, body, bodyLength));
} else {
out.add(header);
}
diff --git a/common/src/test/java/org/apache/celeborn/common/network/protocol/EncryptedMessageWithHeaderSuiteJ.java b/common/src/test/java/org/apache/celeborn/common/network/protocol/EncryptedMessageWithHeaderSuiteJ.java
index 508e9d45d..1b96b307a 100644
--- a/common/src/test/java/org/apache/celeborn/common/network/protocol/EncryptedMessageWithHeaderSuiteJ.java
+++ b/common/src/test/java/org/apache/celeborn/common/network/protocol/EncryptedMessageWithHeaderSuiteJ.java
@@ -19,7 +19,6 @@ package org.apache.celeborn.common.network.protocol;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;
import java.io.ByteArrayInputStream;
@@ -146,15 +145,43 @@ public class EncryptedMessageWithHeaderSuiteJ {
assertEquals(0, header.refCnt());
}
+ // Tests the case where the body is a ByteBuf and that we manage the refcounts of the
+ // header, body, and managed buffer properly
@Test
- public void testByteBufIsNotSupported() throws Exception {
- // Validate that ByteBufs are not supported. This test can be updated
- // when we add support for them
+ public void testByteBufBodyFromManagedBuffer() throws Exception {
+ byte[] randomData = new byte[128];
+ new Random().nextBytes(randomData);
+ ByteBuf sourceBuffer = Unpooled.copiedBuffer(randomData);
+ // convertToNettyForSsl() returns buf.duplicate().retain(), simulate that here
+ ByteBuf body = sourceBuffer.duplicate().retain();
ByteBuf header = Unpooled.copyLong(42);
- assertThrows(
- IllegalArgumentException.class,
- () -> {
- EncryptedMessageWithHeader msg = new EncryptedMessageWithHeader(null, header, header, 4);
- });
+
+ long expectedHeaderValue = header.getLong(header.readerIndex());
+ assertEquals(1, header.refCnt());
+ assertEquals(2, sourceBuffer.refCnt()); // original + duplicate retain
+ ManagedBuffer managedBuf = new NettyManagedBuffer(sourceBuffer);
+
+ EncryptedMessageWithHeader msg =
+ new EncryptedMessageWithHeader(managedBuf, header, body, managedBuf.size());
+ ByteBufAllocator allocator = ByteBufAllocator.DEFAULT;
+
+ assertFalse(msg.isEndOfInput());
+
+ // Single read should return header + body as a composite buffer
+ ByteBuf result = msg.readChunk(allocator);
+ assertEquals(header.capacity() + randomData.length, result.readableBytes());
+ assertEquals(expectedHeaderValue, result.readLong());
+ for (int i = 0; i < randomData.length; i++) {
+ assertEquals(randomData[i], result.readByte());
+ }
+ assertTrue(msg.isEndOfInput());
+
+ // Release the chunk (simulates Netty writing it out)
+ result.release();
+
+ // Closing the message should release the source buffer via managedBuffer.release()
+ msg.close();
+ assertEquals(0, sourceBuffer.refCnt());
+ assertEquals(0, header.refCnt());
}
}
diff --git a/common/src/test/java/org/apache/celeborn/common/network/protocol/SslMessageEncoderSuiteJ.java b/common/src/test/java/org/apache/celeborn/common/network/protocol/SslMessageEncoderSuiteJ.java
new file mode 100644
index 000000000..0eaffc3ba
--- /dev/null
+++ b/common/src/test/java/org/apache/celeborn/common/network/protocol/SslMessageEncoderSuiteJ.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.network.protocol;
+
+import static org.junit.Assert.*;
+import static org.mockito.Mockito.*;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.buffer.UnpooledByteBufAllocator;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.util.ReferenceCountUtil;
+import org.junit.Test;
+
+import org.apache.celeborn.common.network.buffer.NettyManagedBuffer;
+
+/**
+ * Verifies reference counting correctness in SslMessageEncoder.encode() for messages with a
+ * NettyManagedBuffer body.
+ *
+ * <p>When convertToNettyForSsl() returns a ByteBuf, the encoder wraps it in an
+ * EncryptedMessageWithHeader whose close() releases the ManagedBuffer. This mirrors the non-SSL
+ * MessageEncoder which uses MessageWithHeader.deallocate().
+ */
+public class SslMessageEncoderSuiteJ {
+
+ /**
+ * Core regression test: encoding a PushData with a NettyManagedBuffer body must leave the
+ * underlying ByteBuf at refCnt=0 after Netty reads and closes the EncryptedMessageWithHeader.
+ */
+ @Test
+ public void testNettyManagedBufferBodyIsReleasedAfterEncoding() throws Exception {
+ ByteBuf bodyBuf = Unpooled.copyLong(1L);
+ assertEquals(1, bodyBuf.refCnt());
+
+ PushData pushData =
+ new PushData((byte) 0, "shuffleKey", "partitionId", new NettyManagedBuffer(bodyBuf));
+
+ ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
+ when(ctx.alloc()).thenReturn(UnpooledByteBufAllocator.DEFAULT);
+
+ List<Object> out = new ArrayList<>();
+ SslMessageEncoder.INSTANCE.encode(ctx, pushData, out);
+
+ assertEquals(1, out.size());
+ assertTrue(out.get(0) instanceof EncryptedMessageWithHeader);
+
+ EncryptedMessageWithHeader msg = (EncryptedMessageWithHeader) out.get(0);
+
+ // convertToNettyForSsl() called retain on a duplicate, so refCnt is 2
+ // (original + duplicate). The ManagedBuffer has not been released yet — that
+ // happens when close() is called.
+ assertEquals(2, bodyBuf.refCnt());
+
+ // Simulate Netty's ChunkedWriteHandler: read the chunk, then release it.
+ ByteBuf chunk = msg.readChunk(UnpooledByteBufAllocator.DEFAULT);
+ assertNotNull(chunk);
+ assertTrue(msg.isEndOfInput());
+ ReferenceCountUtil.release(chunk);
+
+ // Simulate Netty closing the ChunkedInput after transfer completes.
+ msg.close();
+
+ // After close(), the ManagedBuffer is released, bringing refCnt to 0.
+ assertEquals(0, bodyBuf.refCnt());
+ }
+}
diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/SslClusterReadWriteLeakSuite.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/SslClusterReadWriteLeakSuite.scala
new file mode 100644
index 000000000..8c40e4a37
--- /dev/null
+++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/SslClusterReadWriteLeakSuite.scala
@@ -0,0 +1,253 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.service.deploy.cluster
+
+import java.io.ByteArrayOutputStream
+import java.nio.charset.StandardCharsets
+import java.util.concurrent.atomic.AtomicInteger
+
+import scala.collection.JavaConverters._
+
+import io.netty.buffer.UnpooledByteBufAllocator
+import io.netty.util.{ResourceLeakDetector, ResourceLeakDetectorFactory}
+import org.apache.commons.lang3.RandomStringUtils
+import org.junit.Assert
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.funsuite.AnyFunSuite
+
+import org.apache.celeborn.client.{LifecycleManager, ShuffleClientImpl}
+import org.apache.celeborn.client.read.MetricsCallback
+import org.apache.celeborn.common.CelebornConf
+import org.apache.celeborn.common.identity.UserIdentifier
+import org.apache.celeborn.common.internal.Logging
+import org.apache.celeborn.common.network.ssl.SslSampleConfigs
+import org.apache.celeborn.common.protocol.{CompressionCodec, TransportModuleConstants}
+import org.apache.celeborn.service.deploy.MiniClusterFeature
+
+/**
+ * Integration test verifying that the SslMessageEncoder memory-leak fix holds under a realistic
+ * mini-cluster workload with TLS enabled on every transport channel.
+ *
+ * <p>The test installs a custom Netty ResourceLeakDetector in PARANOID mode before the cluster
+ * starts, runs a full push+replicate+read shuffle cycle with payloads large enough to require
+ * multi-record TLS framing (> 16 KB per push), then forces GC and asserts that the detector
+ * reported zero leaks.
+ */
+class SslClusterReadWriteLeakSuite
+ extends AnyFunSuite
+ with MiniClusterFeature
+ with BeforeAndAfterAll
+ with Logging {
+
+ private val reportedLeaks = new AtomicInteger(0)
+
+ private var previousLeakLevel: ResourceLeakDetector.Level = _
+ private var previousLeakFactory: ResourceLeakDetectorFactory = _
+ private var testMasterPort: Int = _
+
+ private lazy val serverSslConf: Map[String, String] = {
+ val modules = Seq(
+ TransportModuleConstants.PUSH_MODULE,
+ TransportModuleConstants.REPLICATE_MODULE,
+ TransportModuleConstants.FETCH_MODULE)
+ modules
+ .flatMap(m => SslSampleConfigs.createDefaultConfigMapForModule(m).asScala.toSeq)
+ .toMap
+ }
+
+ override def beforeAll(): Unit = {
+
+ // Capture the original leak detection settings so we can restore them in afterAll().
+ previousLeakLevel = ResourceLeakDetector.getLevel
+ previousLeakFactory = ResourceLeakDetectorFactory.instance()
+
+ // Install the leak-counting detector BEFORE any Netty buffers are allocated so that
+ // AbstractByteBuf.leakDetector (a static final field) is initialized with our instance
+ // rather than the default one.
+ installLeakCountingDetector()
+
+ testMasterPort = selectRandomPort()
+ val masterInternalPort = selectRandomPort()
+
+ val masterConf = Map(
+ CelebornConf.MASTER_HOST.key -> "localhost",
+ CelebornConf.PORT_MAX_RETRY.key -> "0",
+ CelebornConf.MASTER_PORT.key -> testMasterPort.toString,
+ CelebornConf.MASTER_ENDPOINTS.key -> s"localhost:$testMasterPort",
+ CelebornConf.MASTER_INTERNAL_PORT.key -> masterInternalPort.toString,
+ CelebornConf.MASTER_INTERNAL_ENDPOINTS.key -> s"localhost:$masterInternalPort") ++ serverSslConf
+
+ val workerConf = Map(
+ CelebornConf.MASTER_ENDPOINTS.key -> s"localhost:$testMasterPort",
+ CelebornConf.MASTER_INTERNAL_ENDPOINTS.key -> s"localhost:$masterInternalPort") ++ serverSslConf
+
+ setupMiniClusterWithRandomPorts(masterConf, workerConf)
+ }
+
+ override def afterAll(): Unit = {
+ shutdownMiniCluster()
+ // Restore the original leak detection settings so that other suites running
+ // in the same JVM (forkMode=once) are not affected.
+ ResourceLeakDetector.setLevel(previousLeakLevel)
+ ResourceLeakDetectorFactory.setResourceLeakDetectorFactory(previousLeakFactory)
+ }
+
+ // ---------------------------------------------------------------------------
+
+ test("SSL mini-cluster: push+replicate+fetch large data produces no ByteBuf memory leaks") {
+ // Verify that our custom leak-counting detector is the one actually installed in
+ // AbstractByteBuf.leakDetector (a static final field). In a shared JVM (scalatest
+ // forkMode=once), an earlier suite may have triggered class loading of AbstractByteBuf,
+ // causing the default detector to be installed instead of ours. In that case, skip
+ // the test rather than silently producing false negatives.
+ val field = classOf[io.netty.buffer.AbstractByteBuf].getDeclaredField("leakDetector")
+ field.setAccessible(true)
+ val detector = field.get(null)
+ assume(
+ detector.getClass.getEnclosingClass == classOf[SslClusterReadWriteLeakSuite],
+ "Leak-counting detector is not active — AbstractByteBuf.leakDetector was " +
+ "initialised by an earlier test in this JVM. Skipping leak assertions.")
+
+ val app = "app-ssl-leak-test"
+ val clientConf = buildSslClientConf(app)
+ val lifecycleManager = new LifecycleManager(app, clientConf)
+ val shuffleClient = new ShuffleClientImpl(app, clientConf, UserIdentifier("mock", "mock"))
+ shuffleClient.setupLifecycleManagerRef(lifecycleManager.self)
+
+ try {
+ // Payloads > 16 KB force TransportFrameDecoder.decodeNext() to assemble a
+ // CompositeByteBuf from multiple TLS records, the scenario that previously
+ // leaked a direct ByteBuf in SslMessageEncoder.encode().
+ val payload32k = RandomStringUtils.random(32 * 1024).getBytes(StandardCharsets.UTF_8)
+ val payload64k = RandomStringUtils.random(64 * 1024).getBytes(StandardCharsets.UTF_8)
+ val payloadSmall = RandomStringUtils.random(512).getBytes(StandardCharsets.UTF_8)
+
+ // Push via the primary push path (exercises push + replicate channels).
+ shuffleClient.pushData(1, 0, 0, 0, payload32k, 0, payload32k.length, 1, 1)
+ shuffleClient.pushData(1, 0, 0, 0, payload64k, 0, payload64k.length, 1, 1)
+
+ // Also exercise mergeData + pushMergedData.
+ shuffleClient.mergeData(1, 0, 0, 0, payload32k, 0, payload32k.length, 1, 1)
+ shuffleClient.mergeData(1, 0, 0, 0, payloadSmall, 0, payloadSmall.length, 1, 1)
+ shuffleClient.pushMergedData(1, 0, 0)
+ Thread.sleep(500)
+
+ shuffleClient.mapperEnd(1, 0, 0, 1, 1)
+
+ // Read back via the fetch channel and verify total byte count.
+ val expectedBytes =
+ payload32k.length + payload64k.length + payload32k.length + payloadSmall.length
+
+ val metricsCallback = new MetricsCallback {
+ override def incBytesRead(bytesWritten: Long): Unit = {}
+ override def incReadTime(time: Long): Unit = {}
+ }
+ val inputStream =
+ shuffleClient.readPartition(1, 0, 0, 0L, 0, Integer.MAX_VALUE, metricsCallback)
+ try {
+ val output = new ByteArrayOutputStream()
+ var b = inputStream.read()
+ while (b != -1) {
+ output.write(b)
+ b = inputStream.read()
+ }
+
+ Assert.assertEquals(expectedBytes, output.size())
+ } finally {
+ inputStream.close()
+ }
+ } finally {
+ Thread.sleep(2000) // let in-flight replication finish before shutdown
+ shuffleClient.shutdown()
+ lifecycleManager.rpcEnv.shutdown()
+ }
+
+ // Trigger GC and make Netty poll its queue.
+ triggerLeakDetection()
+
+ Assert.assertEquals(0, reportedLeaks.get())
+ }
+
+ /**
+ * Installs a custom ResourceLeakDetectorFactory whose detectors override
+ * reportTracedLeak/reportUntracedLeak to count every leak report in reportedLeaks.
+ * Must be called before any ByteBuf is allocated so that AbstractByteBuf.leakDetector
+ * (static final) is initialised with our instance.
+ */
+ private def installLeakCountingDetector(): Unit = {
+ ResourceLeakDetector.setLevel(ResourceLeakDetector.Level.PARANOID)
+ val counter = reportedLeaks
+ ResourceLeakDetectorFactory.setResourceLeakDetectorFactory(
+ new ResourceLeakDetectorFactory() {
+ override def newResourceLeakDetector[T](
+ resource: Class[T],
+ samplingInterval: Int,
+ maxActive: Long): ResourceLeakDetector[T] = {
+ new ResourceLeakDetector[T](resource, samplingInterval) {
+ override protected def reportTracedLeak(
+ resourceType: String,
+ records: String): Unit = {
+ super.reportTracedLeak(resourceType, records)
+ counter.incrementAndGet()
+ }
+ override protected def reportUntracedLeak(resourceType: String): Unit = {
+ super.reportUntracedLeak(resourceType)
+ counter.incrementAndGet()
+ }
+ }
+ }
+ })
+ }
+
+ /**
+ * Builds client CelebornConf with SSL enabled on the "data" module, matching the production
+ * client-side configuration (spark.celeborn.ssl.data.enabled=true). ShuffleClientImpl uses the
+ * DATA_MODULE ("data") for all its data-plane connections (push + fetch) to workers.
+ */
+ private def buildSslClientConf(app: String): CelebornConf = {
+ val clientSslConf =
+ SslSampleConfigs.createDefaultConfigMapForModule(TransportModuleConstants.DATA_MODULE).asScala.toMap
+
+ val conf = new CelebornConf()
+ .set(CelebornConf.MASTER_ENDPOINTS.key, s"localhost:$testMasterPort")
+ .set(CelebornConf.CLIENT_PUSH_REPLICATE_ENABLED.key, "true")
+ .set(CelebornConf.CLIENT_PUSH_BUFFER_MAX_SIZE.key, "256K")
+ .set("celeborn.data.io.numConnectionsPerPeer", "1")
+ .set(CelebornConf.SHUFFLE_COMPRESSION_CODEC.key, CompressionCodec.NONE.name)
+ clientSslConf.foreach { case (k, v) => conf.set(k, v) }
+ conf
+ }
+
+ /**
+ * Forces several rounds of GC and allocates direct buffers in between so that Netty's
+ * ResourceLeakDetector (PARANOID mode) polls its PhantomReference queue and reports any
+ * buffers that were GC'd without being released.
+ */
+ private def triggerLeakDetection(): Unit = {
+ for (_ <- 1 to 5) {
+ System.gc()
+ System.runFinalization()
+ Thread.sleep(500)
+ // Each directBuffer() allocation causes the detector to drain its ref queue.
+ val bufs = (1 to 200).map(_ => UnpooledByteBufAllocator.DEFAULT.directBuffer(1))
+ bufs.foreach(_.release())
+ Thread.sleep(200)
+ }
+ }
+}