This is an automated email from the ASF dual-hosted git repository.

nicholasjiang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 42f1a08c1 [CELEBORN-2284] Fix TLS Memory Leak
42f1a08c1 is described below

commit 42f1a08c1bcf0ba117d05d91a766d1214e5b881f
Author: Aravind Patnam <[email protected]>
AuthorDate: Mon Mar 30 14:20:21 2026 +0800

    [CELEBORN-2284] Fix TLS Memory Leak
    
    ### What changes were proposed in this pull request?
    While running jobs with TLS enabled, we encountered memory leaks that caused worker OOMs.
    ```
    26/02/13 21:02:52,779 ERROR [push-server-9-9] ResourceLeakDetector: LEAK: 
ByteBuf.release() was not called before it's garbage-collected. See 
https://netty.io/wiki/reference-counted-objects.html for more information.
    Recent access records:
    Created at:
            
io.netty.buffer.AbstractByteBufAllocator.compositeDirectBuffer(AbstractByteBufAllocator.java:224)
            
io.netty.buffer.AbstractByteBufAllocator.compositeBuffer(AbstractByteBufAllocator.java:202)
            
org.apache.celeborn.common.network.util.TransportFrameDecoder.decodeNext(TransportFrameDecoder.java:143)
            
org.apache.celeborn.common.network.util.TransportFrameDecoder.channelRead(TransportFrameDecoder.java:66)
            
io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:444)
            
io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:420)
            
io.netty.channel.AbstractChannelHandlerContext.fireChannelRead(AbstractChannelHandlerContext.java:412)
            io.netty.handler.ssl.SslHandler.unwrap(SslHandler.java:1475)
            
io.netty.handler.ssl.SslHandler.decodeJdkCompatible(SslHandler.java:1338)
            io.netty.handler.ssl.SslHandler.decode(SslHandler.java:1387)
            
io.netty.handler.codec.ByteToMessageDecoder.decodeRemovalReentryProtection(ByteToMessageDecoder.java:529)
            
io.netty.handler.codec.ByteToMessageDecoder.callDecode(ByteToMessageDecoder.java:468)
            
io.netty.handler.codec.ByteToMessageDecoder.channelRead(ByteToMessageDecoder.java:290)
            
io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:444)
            
io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:420)
            
io.netty.channel.AbstractChannelHandlerContext.fireChannelRead(AbstractChannelHandlerContext.java:412)
            
io.netty.channel.DefaultChannelPipeline$HeadContext.channelRead(DefaultChannelPipeline.java:1410)
            
io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:440)
            
io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:420)
            
io.netty.channel.DefaultChannelPipeline.fireChannelRead(DefaultChannelPipeline.java:919)
            
io.netty.channel.nio.AbstractNioByteChannel$NioByteUnsafe.read(AbstractNioByteChannel.java:166)
            
io.netty.channel.nio.NioEventLoop.processSelectedKey(NioEventLoop.java:788)
            
io.netty.channel.nio.NioEventLoop.processSelectedKeysOptimized(NioEventLoop.java:724)
            
io.netty.channel.nio.NioEventLoop.processSelectedKeys(NioEventLoop.java:650)
            io.netty.channel.nio.NioEventLoop.run(NioEventLoop.java:562)
            
io.netty.util.concurrent.SingleThreadEventExecutor$4.run(SingleThreadEventExecutor.java:997)
            
io.netty.util.internal.ThreadExecutorMap$2.run(ThreadExecutorMap.java:74)
            
io.netty.util.concurrent.FastThreadLocalRunnable.run(FastThreadLocalRunnable.java:30)
            java.base/java.lang.Thread.run(Thread.java:840)
    
    ```
    
    When a Celeborn worker receives a PushData or PushMergedData message, it replicates that frame to a secondary worker for fault tolerance. On an SSL-enabled cluster this replication goes through SslMessageEncoder.encode(). Here is what happens inside SslMessageEncoder.encode():
    
    - The encoder asks the message body for an SSL-friendly copy by calling 
convertToNettyForSsl(). For shuffle data, the body is a NettyManagedBuffer — 
data already loaded in off-heap memory. This call runs 
buf.duplicate().retain(), which creates a second reference to the same memory 
and increments the reference count from 1 to 2.
    
    - The encoder places this second reference inside a composite buffer and 
hands it to Netty for writing.
    
    - Netty writes the composite to the network, then releases it — 
decrementing the count from 2 to 1.
    
    - Nothing releases the original NettyManagedBuffer's hold on the data, so 
the count stays at 1 forever.
    
    - This results in every replicated PushData frame leaking a chunk of 
off-heap memory, eventually causing OOM and worker crash.
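
    The reference-count arithmetic above can be replayed with a small stand-alone sketch. This is a toy model only: `RefCountedBuf` is a hypothetical stand-in for Netty's reference-counted ByteBuf, not Netty's actual API.

```java
import java.util.concurrent.atomic.AtomicInteger;

public class TlsLeakModel {
    // Toy stand-in for a reference-counted Netty ByteBuf (illustrative only).
    static class RefCountedBuf {
        final AtomicInteger refCnt = new AtomicInteger(1);
        // Models buf.duplicate().retain(): a second reference to the same memory.
        RefCountedBuf retainedDuplicate() { refCnt.incrementAndGet(); return this; }
        // The backing memory is freed only when the count reaches 0.
        void release() { refCnt.decrementAndGet(); }
    }

    // Replays the leaking sequence: the count goes 1 -> 2 -> 1 and never reaches 0.
    static int refCntAfterLeakyEncode() {
        RefCountedBuf body = new RefCountedBuf();             // NettyManagedBuffer holds ref #1
        RefCountedBuf inComposite = body.retainedDuplicate(); // encoder adds ref #2 (1 -> 2)
        inComposite.release();                                // Netty writes, then releases (2 -> 1)
        return body.refCnt.get();                             // stuck at 1: off-heap memory leaks
    }

    public static void main(String[] args) {
        System.out.println("refCnt after leaky encode: " + refCntAfterLeakyEncode());
    }
}
```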
    
    The fix for this issue is to release the original message body once ownership has been transferred, so that the encoder's net effect on the reference count is zero. The second reference — now living inside the composite buffer in out — keeps the memory alive while Netty writes it to the network. When Netty finishes and releases the composite, the count reaches 0 and the memory is freed cleanly.
    
    This is exactly what the non-SSL MessageEncoder already does via 
MessageWithHeader.deallocate() — the SSL path simply needed to replicate that 
behavior explicitly.
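
    Under the same toy model (hypothetical `RefCountedBuf`, not the actual Celeborn or Netty classes), the fix amounts to one extra release of the original hold, so the count reaches zero:

```java
import java.util.concurrent.atomic.AtomicInteger;

public class TlsLeakFixModel {
    // Toy stand-in for a reference-counted Netty ByteBuf (illustrative only).
    static class RefCountedBuf {
        final AtomicInteger refCnt = new AtomicInteger(1);
        RefCountedBuf retainedDuplicate() { refCnt.incrementAndGet(); return this; }
        void release() { refCnt.decrementAndGet(); }
    }

    // Fixed sequence: ownership of the original reference is handed to the wrapper,
    // whose close() releases it, so the count goes 1 -> 2 -> 1 -> 0.
    static int refCntAfterFixedEncode() {
        RefCountedBuf body = new RefCountedBuf();             // NettyManagedBuffer holds ref #1
        RefCountedBuf inComposite = body.retainedDuplicate(); // second ref lives in the composite
        inComposite.release();                                // Netty finishes writing (2 -> 1)
        body.release();                                       // wrapper's close() drops ref #1 (1 -> 0)
        return body.refCnt.get();                             // 0: memory freed cleanly
    }

    public static void main(String[] args) {
        System.out.println("refCnt after fixed encode: " + refCntAfterFixedEncode());
    }
}
```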
    
    ### Why are the changes needed?
    Fixes the memory leak described above.
    
    ### Does this PR resolve a correctness bug?
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Already deployed and tested internally in production.
    Also added unit tests.
    
    Closes #3630 from akpatnam25/CELEBORN-2284.
    
    Authored-by: Aravind Patnam <[email protected]>
    Signed-off-by: SteNicholas <[email protected]>
---
 .../protocol/EncryptedMessageWithHeader.java       |  82 ++++---
 .../common/network/protocol/SslMessageEncoder.java |  15 +-
 .../protocol/EncryptedMessageWithHeaderSuiteJ.java |  45 +++-
 .../network/protocol/SslMessageEncoderSuiteJ.java  |  85 +++++++
 .../cluster/SslClusterReadWriteLeakSuite.scala     | 253 +++++++++++++++++++++
 5 files changed, 423 insertions(+), 57 deletions(-)

diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/EncryptedMessageWithHeader.java
 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/EncryptedMessageWithHeader.java
index df2ab1a92..38777f5f8 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/EncryptedMessageWithHeader.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/EncryptedMessageWithHeader.java
@@ -25,17 +25,19 @@ import javax.annotation.Nullable;
 import com.google.common.base.Preconditions;
 import io.netty.buffer.ByteBuf;
 import io.netty.buffer.ByteBufAllocator;
+import io.netty.buffer.Unpooled;
 import io.netty.channel.ChannelHandlerContext;
 import io.netty.handler.stream.ChunkedInput;
 import io.netty.handler.stream.ChunkedStream;
+import io.netty.util.ReferenceCountUtil;
 
 import org.apache.celeborn.common.network.buffer.ManagedBuffer;
 
 /**
  * A wrapper message that holds two separate pieces (a header and a body).
  *
- * <p>The header must be a ByteBuf, while the body can be any InputStream or 
ChunkedStream Based on
- * 
common/network-common/org.apache.spark.network.protocol.EncryptedMessageWithHeader
+ * <p>The header must be a ByteBuf, while the body can be a ByteBuf, 
InputStream, or ChunkedStream.
+ * Based on 
common/network-common/org.apache.spark.network.protocol.EncryptedMessageWithHeader
  */
 public class EncryptedMessageWithHeader implements ChunkedInput<ByteBuf> {
 
@@ -61,8 +63,8 @@ public class EncryptedMessageWithHeader implements 
ChunkedInput<ByteBuf> {
   public EncryptedMessageWithHeader(
       @Nullable ManagedBuffer managedBuffer, ByteBuf header, Object body, long 
bodyLength) {
     Preconditions.checkArgument(
-        body instanceof InputStream || body instanceof ChunkedStream,
-        "Body must be an InputStream or a ChunkedStream.");
+        body instanceof ByteBuf || body instanceof InputStream || body 
instanceof ChunkedStream,
+        "Body must be a ByteBuf, an InputStream, or a ChunkedStream.");
     this.managedBuffer = managedBuffer;
     this.header = header;
     this.headerLength = header.readableBytes();
@@ -82,40 +84,47 @@ public class EncryptedMessageWithHeader implements 
ChunkedInput<ByteBuf> {
       return null;
     }
 
-    if (totalBytesTransferred < headerLength) {
-      totalBytesTransferred += headerLength;
-      return header.retain();
-    } else if (body instanceof InputStream) {
-      InputStream stream = (InputStream) body;
-      int available = stream.available();
-      if (available <= 0) {
-        available = (int) (length() - totalBytesTransferred);
-      } else {
-        available = (int) Math.min(available, length() - 
totalBytesTransferred);
-      }
-      ByteBuf buffer = allocator.buffer(available);
-      int toRead = Math.min(available, buffer.writableBytes());
-      int read = buffer.writeBytes(stream, toRead);
-      if (read >= 0) {
-        totalBytesTransferred += read;
-        return buffer;
-      } else {
-        throw new EOFException("Unable to read bytes from InputStream");
-      }
-    } else if (body instanceof ChunkedStream) {
-      ChunkedStream stream = (ChunkedStream) body;
-      long old = stream.transferredBytes();
-      ByteBuf buffer = stream.readChunk(allocator);
-      long read = stream.transferredBytes() - old;
-      if (read >= 0) {
-        totalBytesTransferred += read;
-        assert (totalBytesTransferred <= length());
-        return buffer;
+    if (body instanceof ByteBuf) {
+      // For ByteBuf bodies, return header + body as a single composite buffer.
+      ByteBuf bodyBuf = (ByteBuf) body;
+      totalBytesTransferred = headerLength + bodyLength;
+      return Unpooled.wrappedBuffer(header.retain(), bodyBuf.retain());
+    } else {
+      if (totalBytesTransferred < headerLength) {
+        totalBytesTransferred += headerLength;
+        return header.retain();
+      } else if (body instanceof InputStream) {
+        InputStream stream = (InputStream) body;
+        int available = stream.available();
+        if (available <= 0) {
+          available = (int) (length() - totalBytesTransferred);
+        } else {
+          available = (int) Math.min(available, length() - 
totalBytesTransferred);
+        }
+        ByteBuf buffer = allocator.buffer(available);
+        int toRead = Math.min(available, buffer.writableBytes());
+        int read = buffer.writeBytes(stream, toRead);
+        if (read >= 0) {
+          totalBytesTransferred += read;
+          return buffer;
+        } else {
+          throw new EOFException("Unable to read bytes from InputStream");
+        }
+      } else if (body instanceof ChunkedStream) {
+        ChunkedStream stream = (ChunkedStream) body;
+        long old = stream.transferredBytes();
+        ByteBuf buffer = stream.readChunk(allocator);
+        long read = stream.transferredBytes() - old;
+        if (read >= 0) {
+          totalBytesTransferred += read;
+          assert (totalBytesTransferred <= length());
+          return buffer;
+        } else {
+          throw new EOFException("Unable to read bytes from ChunkedStream");
+        }
       } else {
-        throw new EOFException("Unable to read bytes from ChunkedStream");
+        return null;
       }
-    } else {
-      return null;
     }
   }
 
@@ -137,6 +146,7 @@ public class EncryptedMessageWithHeader implements 
ChunkedInput<ByteBuf> {
   @Override
   public void close() throws Exception {
     header.release();
+    ReferenceCountUtil.release(body);
     if (managedBuffer != null) {
       managedBuffer.release();
     }
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/SslMessageEncoder.java
 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/SslMessageEncoder.java
index 508b6a13d..cb3f0ed5c 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/SslMessageEncoder.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/SslMessageEncoder.java
@@ -17,15 +17,12 @@
 
 package org.apache.celeborn.common.network.protocol;
 
-import java.io.InputStream;
 import java.util.List;
 
 import io.netty.buffer.ByteBuf;
-import io.netty.buffer.Unpooled;
 import io.netty.channel.ChannelHandler;
 import io.netty.channel.ChannelHandlerContext;
 import io.netty.handler.codec.MessageToMessageEncoder;
-import io.netty.handler.stream.ChunkedStream;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -89,15 +86,9 @@ public final class SslMessageEncoder extends 
MessageToMessageEncoder<Message> {
     assert header.writableBytes() == 0;
 
     if (body != null && bodyLength > 0) {
-      if (body instanceof ByteBuf) {
-        out.add(Unpooled.wrappedBuffer(header, (ByteBuf) body));
-      } else if (body instanceof InputStream || body instanceof ChunkedStream) 
{
-        // For now, assume the InputStream is doing proper chunking.
-        out.add(new EncryptedMessageWithHeader(in.body(), header, body, 
bodyLength));
-      } else {
-        throw new IllegalArgumentException(
-            "Body must be a ByteBuf, ChunkedStream or an InputStream");
-      }
+      // We transfer ownership of the reference on in.body() to 
EncryptedMessageWithHeader.
+      // This reference will be freed when EncryptedMessageWithHeader.close() 
is called.
+      out.add(new EncryptedMessageWithHeader(in.body(), header, body, 
bodyLength));
     } else {
       out.add(header);
     }
diff --git 
a/common/src/test/java/org/apache/celeborn/common/network/protocol/EncryptedMessageWithHeaderSuiteJ.java
 
b/common/src/test/java/org/apache/celeborn/common/network/protocol/EncryptedMessageWithHeaderSuiteJ.java
index 508e9d45d..1b96b307a 100644
--- 
a/common/src/test/java/org/apache/celeborn/common/network/protocol/EncryptedMessageWithHeaderSuiteJ.java
+++ 
b/common/src/test/java/org/apache/celeborn/common/network/protocol/EncryptedMessageWithHeaderSuiteJ.java
@@ -19,7 +19,6 @@ package org.apache.celeborn.common.network.protocol;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertThrows;
 import static org.junit.Assert.assertTrue;
 
 import java.io.ByteArrayInputStream;
@@ -146,15 +145,43 @@ public class EncryptedMessageWithHeaderSuiteJ {
     assertEquals(0, header.refCnt());
   }
 
+  // Tests the case where the body is a ByteBuf and that we manage the 
refcounts of the
+  // header, body, and managed buffer properly
   @Test
-  public void testByteBufIsNotSupported() throws Exception {
-    // Validate that ByteBufs are not supported. This test can be updated
-    // when we add support for them
+  public void testByteBufBodyFromManagedBuffer() throws Exception {
+    byte[] randomData = new byte[128];
+    new Random().nextBytes(randomData);
+    ByteBuf sourceBuffer = Unpooled.copiedBuffer(randomData);
+    // convertToNettyForSsl() returns buf.duplicate().retain(), simulate that 
here
+    ByteBuf body = sourceBuffer.duplicate().retain();
     ByteBuf header = Unpooled.copyLong(42);
-    assertThrows(
-        IllegalArgumentException.class,
-        () -> {
-          EncryptedMessageWithHeader msg = new 
EncryptedMessageWithHeader(null, header, header, 4);
-        });
+
+    long expectedHeaderValue = header.getLong(header.readerIndex());
+    assertEquals(1, header.refCnt());
+    assertEquals(2, sourceBuffer.refCnt()); // original + duplicate retain
+    ManagedBuffer managedBuf = new NettyManagedBuffer(sourceBuffer);
+
+    EncryptedMessageWithHeader msg =
+        new EncryptedMessageWithHeader(managedBuf, header, body, 
managedBuf.size());
+    ByteBufAllocator allocator = ByteBufAllocator.DEFAULT;
+
+    assertFalse(msg.isEndOfInput());
+
+    // Single read should return header + body as a composite buffer
+    ByteBuf result = msg.readChunk(allocator);
+    assertEquals(header.capacity() + randomData.length, 
result.readableBytes());
+    assertEquals(expectedHeaderValue, result.readLong());
+    for (int i = 0; i < randomData.length; i++) {
+      assertEquals(randomData[i], result.readByte());
+    }
+    assertTrue(msg.isEndOfInput());
+
+    // Release the chunk (simulates Netty writing it out)
+    result.release();
+
+    // Closing the message should release the source buffer via 
managedBuffer.release()
+    msg.close();
+    assertEquals(0, sourceBuffer.refCnt());
+    assertEquals(0, header.refCnt());
   }
 }
diff --git 
a/common/src/test/java/org/apache/celeborn/common/network/protocol/SslMessageEncoderSuiteJ.java
 
b/common/src/test/java/org/apache/celeborn/common/network/protocol/SslMessageEncoderSuiteJ.java
new file mode 100644
index 000000000..0eaffc3ba
--- /dev/null
+++ 
b/common/src/test/java/org/apache/celeborn/common/network/protocol/SslMessageEncoderSuiteJ.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.network.protocol;
+
+import static org.junit.Assert.*;
+import static org.mockito.Mockito.*;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.buffer.UnpooledByteBufAllocator;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.util.ReferenceCountUtil;
+import org.junit.Test;
+
+import org.apache.celeborn.common.network.buffer.NettyManagedBuffer;
+
+/**
+ * Verifies reference counting correctness in SslMessageEncoder.encode() for 
messages with a
+ * NettyManagedBuffer body.
+ *
+ * <p>When convertToNettyForSsl() returns a ByteBuf, the encoder wraps it in an
+ * EncryptedMessageWithHeader whose close() releases the ManagedBuffer. This 
mirrors the non-SSL
+ * MessageEncoder which uses MessageWithHeader.deallocate().
+ */
+public class SslMessageEncoderSuiteJ {
+
+  /**
+   * Core regression test: encoding a PushData with a NettyManagedBuffer body 
must leave the
+   * underlying ByteBuf at refCnt=0 after Netty reads and closes the 
EncryptedMessageWithHeader.
+   */
+  @Test
+  public void testNettyManagedBufferBodyIsReleasedAfterEncoding() throws 
Exception {
+    ByteBuf bodyBuf = Unpooled.copyLong(1L);
+    assertEquals(1, bodyBuf.refCnt());
+
+    PushData pushData =
+        new PushData((byte) 0, "shuffleKey", "partitionId", new 
NettyManagedBuffer(bodyBuf));
+
+    ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
+    when(ctx.alloc()).thenReturn(UnpooledByteBufAllocator.DEFAULT);
+
+    List<Object> out = new ArrayList<>();
+    SslMessageEncoder.INSTANCE.encode(ctx, pushData, out);
+
+    assertEquals(1, out.size());
+    assertTrue(out.get(0) instanceof EncryptedMessageWithHeader);
+
+    EncryptedMessageWithHeader msg = (EncryptedMessageWithHeader) out.get(0);
+
+    // convertToNettyForSsl() called retain on a duplicate, so refCnt is 2
+    // (original + duplicate). The ManagedBuffer has not been released yet — 
that
+    // happens when close() is called.
+    assertEquals(2, bodyBuf.refCnt());
+
+    // Simulate Netty's ChunkedWriteHandler: read the chunk, then release it.
+    ByteBuf chunk = msg.readChunk(UnpooledByteBufAllocator.DEFAULT);
+    assertNotNull(chunk);
+    assertTrue(msg.isEndOfInput());
+    ReferenceCountUtil.release(chunk);
+
+    // Simulate Netty closing the ChunkedInput after transfer completes.
+    msg.close();
+
+    // After close(), the ManagedBuffer is released, bringing refCnt to 0.
+    assertEquals(0, bodyBuf.refCnt());
+  }
+}
diff --git 
a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/SslClusterReadWriteLeakSuite.scala
 
b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/SslClusterReadWriteLeakSuite.scala
new file mode 100644
index 000000000..8c40e4a37
--- /dev/null
+++ 
b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/SslClusterReadWriteLeakSuite.scala
@@ -0,0 +1,253 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.service.deploy.cluster
+
+import java.io.ByteArrayOutputStream
+import java.nio.charset.StandardCharsets
+import java.util.concurrent.atomic.AtomicInteger
+
+import scala.collection.JavaConverters._
+
+import io.netty.buffer.UnpooledByteBufAllocator
+import io.netty.util.{ResourceLeakDetector, ResourceLeakDetectorFactory}
+import org.apache.commons.lang3.RandomStringUtils
+import org.junit.Assert
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.funsuite.AnyFunSuite
+
+import org.apache.celeborn.client.{LifecycleManager, ShuffleClientImpl}
+import org.apache.celeborn.client.read.MetricsCallback
+import org.apache.celeborn.common.CelebornConf
+import org.apache.celeborn.common.identity.UserIdentifier
+import org.apache.celeborn.common.internal.Logging
+import org.apache.celeborn.common.network.ssl.SslSampleConfigs
+import org.apache.celeborn.common.protocol.{CompressionCodec, 
TransportModuleConstants}
+import org.apache.celeborn.service.deploy.MiniClusterFeature
+
+/**
+ * Integration test verifying that the SslMessageEncoder memory-leak fix holds 
under a realistic
+ * mini-cluster workload with TLS enabled on every transport channel.
+ *
+ * <p>The test installs a custom Netty ResourceLeakDetector in PARANOID mode 
before the cluster
+ * starts, runs a full push+replicate+read shuffle cycle with payloads large 
enough to require
+ * multi-record TLS framing (> 16 KB per push), then forces GC and asserts 
that the detector
+ * reported zero leaks.
+ */
+class SslClusterReadWriteLeakSuite
+  extends AnyFunSuite
+  with MiniClusterFeature
+  with BeforeAndAfterAll
+  with Logging {
+
+  private val reportedLeaks = new AtomicInteger(0)
+
+  private var previousLeakLevel: ResourceLeakDetector.Level = _
+  private var previousLeakFactory: ResourceLeakDetectorFactory = _
+  private var testMasterPort: Int = _
+
+  private lazy val serverSslConf: Map[String, String] = {
+    val modules = Seq(
+      TransportModuleConstants.PUSH_MODULE,
+      TransportModuleConstants.REPLICATE_MODULE,
+      TransportModuleConstants.FETCH_MODULE)
+    modules
+      .flatMap(m => 
SslSampleConfigs.createDefaultConfigMapForModule(m).asScala.toSeq)
+      .toMap
+  }
+
+  override def beforeAll(): Unit = {
+
+    // Capture the original leak detection settings so we can restore them in 
afterAll().
+    previousLeakLevel = ResourceLeakDetector.getLevel
+    previousLeakFactory = ResourceLeakDetectorFactory.instance()
+
+    // Install the leak-counting detector BEFORE any Netty buffers are 
allocated so that
+    // AbstractByteBuf.leakDetector (a static final field) is initialized with 
our instance
+    // rather than the default one.
+    installLeakCountingDetector()
+
+    testMasterPort = selectRandomPort()
+    val masterInternalPort = selectRandomPort()
+
+    val masterConf = Map(
+      CelebornConf.MASTER_HOST.key -> "localhost",
+      CelebornConf.PORT_MAX_RETRY.key -> "0",
+      CelebornConf.MASTER_PORT.key -> testMasterPort.toString,
+      CelebornConf.MASTER_ENDPOINTS.key -> s"localhost:$testMasterPort",
+      CelebornConf.MASTER_INTERNAL_PORT.key -> masterInternalPort.toString,
+      CelebornConf.MASTER_INTERNAL_ENDPOINTS.key -> 
s"localhost:$masterInternalPort") ++ serverSslConf
+
+    val workerConf = Map(
+      CelebornConf.MASTER_ENDPOINTS.key -> s"localhost:$testMasterPort",
+      CelebornConf.MASTER_INTERNAL_ENDPOINTS.key -> 
s"localhost:$masterInternalPort") ++ serverSslConf
+
+    setupMiniClusterWithRandomPorts(masterConf, workerConf)
+  }
+
+  override def afterAll(): Unit = {
+    shutdownMiniCluster()
+    // Restore the original leak detection settings so that other suites 
running
+    // in the same JVM (forkMode=once) are not affected.
+    ResourceLeakDetector.setLevel(previousLeakLevel)
+    
ResourceLeakDetectorFactory.setResourceLeakDetectorFactory(previousLeakFactory)
+  }
+
+  // 
---------------------------------------------------------------------------
+
+  test("SSL mini-cluster: push+replicate+fetch large data produces no ByteBuf 
memory leaks") {
+    // Verify that our custom leak-counting detector is the one actually 
installed in
+    // AbstractByteBuf.leakDetector (a static final field). In a shared JVM 
(scalatest
+    // forkMode=once), an earlier suite may have triggered class loading of 
AbstractByteBuf,
+    // causing the default detector to be installed instead of ours. In that 
case, skip
+    // the test rather than silently producing false negatives.
+    val field = 
classOf[io.netty.buffer.AbstractByteBuf].getDeclaredField("leakDetector")
+    field.setAccessible(true)
+    val detector = field.get(null)
+    assume(
+      detector.getClass.getEnclosingClass == 
classOf[SslClusterReadWriteLeakSuite],
+      "Leak-counting detector is not active — AbstractByteBuf.leakDetector was 
" +
+        "initialised by an earlier test in this JVM. Skipping leak 
assertions.")
+
+    val app = "app-ssl-leak-test"
+    val clientConf = buildSslClientConf(app)
+    val lifecycleManager = new LifecycleManager(app, clientConf)
+    val shuffleClient = new ShuffleClientImpl(app, clientConf, 
UserIdentifier("mock", "mock"))
+    shuffleClient.setupLifecycleManagerRef(lifecycleManager.self)
+
+    try {
+      // Payloads > 16 KB force TransportFrameDecoder.decodeNext() to assemble 
a
+      // CompositeByteBuf from multiple TLS records – this is to prevent a 
leaked direct ByteBuf
+      // in SslMessageEncoder.encode().
+      val payload32k = RandomStringUtils.random(32 * 
1024).getBytes(StandardCharsets.UTF_8)
+      val payload64k = RandomStringUtils.random(64 * 
1024).getBytes(StandardCharsets.UTF_8)
+      val payloadSmall = 
RandomStringUtils.random(512).getBytes(StandardCharsets.UTF_8)
+
+      // Push via the primary push path (exercises push + replicate channels).
+      shuffleClient.pushData(1, 0, 0, 0, payload32k, 0, payload32k.length, 1, 
1)
+      shuffleClient.pushData(1, 0, 0, 0, payload64k, 0, payload64k.length, 1, 
1)
+
+      // Also exercise mergeData + pushMergedData.
+      shuffleClient.mergeData(1, 0, 0, 0, payload32k, 0, payload32k.length, 1, 
1)
+      shuffleClient.mergeData(1, 0, 0, 0, payloadSmall, 0, 
payloadSmall.length, 1, 1)
+      shuffleClient.pushMergedData(1, 0, 0)
+      Thread.sleep(500)
+
+      shuffleClient.mapperEnd(1, 0, 0, 1, 1)
+
+      // Read back via the fetch channel and verify total byte count.
+      val expectedBytes =
+        payload32k.length + payload64k.length + payload32k.length + 
payloadSmall.length
+
+      val metricsCallback = new MetricsCallback {
+        override def incBytesRead(bytesWritten: Long): Unit = {}
+        override def incReadTime(time: Long): Unit = {}
+      }
+      val inputStream =
+        shuffleClient.readPartition(1, 0, 0, 0L, 0, Integer.MAX_VALUE, 
metricsCallback)
+      try {
+        val output = new ByteArrayOutputStream()
+        var b = inputStream.read()
+        while (b != -1) {
+          output.write(b)
+          b = inputStream.read()
+        }
+
+        Assert.assertEquals(expectedBytes, output.size())
+      } finally {
+        inputStream.close()
+      }
+    } finally {
+      Thread.sleep(2000) // let in-flight replication finish before shutdown
+      shuffleClient.shutdown()
+      lifecycleManager.rpcEnv.shutdown()
+    }
+
+    // Trigger GC and make Netty poll its queue.
+    triggerLeakDetection()
+
+    Assert.assertEquals(0, reportedLeaks.get())
+  }
+
+  /**
+   * Installs a custom ResourceLeakDetectorFactory whose detectors override
+   * reportTracedLeak/reportUntracedLeak to count every leak report in 
reportedLeaks.
+   * Must be called before any ByteBuf is allocated so that 
AbstractByteBuf.leakDetector
+   * (static final) is initialised with our instance.
+   */
+  private def installLeakCountingDetector(): Unit = {
+    ResourceLeakDetector.setLevel(ResourceLeakDetector.Level.PARANOID)
+    val counter = reportedLeaks
+    ResourceLeakDetectorFactory.setResourceLeakDetectorFactory(
+      new ResourceLeakDetectorFactory() {
+        override def newResourceLeakDetector[T](
+            resource: Class[T],
+            samplingInterval: Int,
+            maxActive: Long): ResourceLeakDetector[T] = {
+          new ResourceLeakDetector[T](resource, samplingInterval) {
+            override protected def reportTracedLeak(
+                resourceType: String,
+                records: String): Unit = {
+              super.reportTracedLeak(resourceType, records)
+              counter.incrementAndGet()
+            }
+            override protected def reportUntracedLeak(resourceType: String): 
Unit = {
+              super.reportUntracedLeak(resourceType)
+              counter.incrementAndGet()
+            }
+          }
+        }
+      })
+  }
+
+  /**
+   * Builds client CelebornConf with SSL enabled on the "data" module, 
matching the production
+   * client-side configuration (spark.celeborn.ssl.data.enabled=true).  
ShuffleClientImpl uses the
+   * DATA_MODULE ("data") for all its data-plane connections (push + fetch) to 
workers.
+   */
+  private def buildSslClientConf(app: String): CelebornConf = {
+    val clientSslConf =
+      SslSampleConfigs.createDefaultConfigMapForModule(
+        TransportModuleConstants.DATA_MODULE).asScala.toMap
+
+    val conf = new CelebornConf()
+      .set(CelebornConf.MASTER_ENDPOINTS.key, s"localhost:$testMasterPort")
+      .set(CelebornConf.CLIENT_PUSH_REPLICATE_ENABLED.key, "true")
+      .set(CelebornConf.CLIENT_PUSH_BUFFER_MAX_SIZE.key, "256K")
+      .set("celeborn.data.io.numConnectionsPerPeer", "1")
+      .set(CelebornConf.SHUFFLE_COMPRESSION_CODEC.key, 
CompressionCodec.NONE.name)
+    clientSslConf.foreach { case (k, v) => conf.set(k, v) }
+    conf
+  }
+
+  /**
+   * Forces several rounds of GC and allocates direct buffers in between so 
that Netty's
+   * ResourceLeakDetector (PARANOID mode) polls its PhantomReference queue and 
reports any
+   * buffers that were GC'd without being released.
+   */
+  private def triggerLeakDetection(): Unit = {
+    for (_ <- 1 to 5) {
+      System.gc()
+      System.runFinalization()
+      Thread.sleep(500)
+      // Each directBuffer() allocation causes the detector to drain its ref 
queue.
+      val bufs = (1 to 200).map(_ => 
UnpooledByteBufAllocator.DEFAULT.directBuffer(1))
+      bufs.foreach(_.release())
+      Thread.sleep(200)
+    }
+  }
+}
