This is an automated email from the ASF dual-hosted git repository.

He-Pin pushed a commit to branch pr-2878-clean
in repository https://gitbox.apache.org/repos/asf/pekko.git

commit 000474055f7c7915a93b9f686198653921e23d26
Author: He-Pin <[email protected]>
AuthorDate: Sun Apr 26 00:32:37 2026 +0800

    fix(stream): address TlsGraphStage close-handshake, BUFFER_UNDERFLOW, and error-flush gaps
    
    Motivation:
    While exercising TlsGraphStage end-to-end against the full TlsSpec suite
    and adding fragmented-cipher edge cases, three correctness gaps surfaced
    that the original direct-push engine did not handle:
    
    1. EagerClose stalls during the close handshake. The previous pump
       ran at most one direction per iteration (unwrap when eligible,
       otherwise wrap). When both peers reach close_notify simultaneously,
       that gating could leave one direction waiting on the other and
       deadlock the bidirectional close exchange.
    2. Cipher records split across pushes (e.g. one byte at a time) caused
       the engine to return BUFFER_UNDERFLOW indefinitely because the stage
       stopped pulling cipherIn whenever pendingCipherBytes was non-empty.
    3. SSLException thrown from wrap/unwrap failed the user-side outlet but
       never attempted to flush the engine's alert/close_notify to the peer,
       so the remote side never observed the error frame.
    
    Modification:
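    - Replace the unwrap-or-wrap if/else in pumpTls with a bidirectional
      step: every iteration tries both directions, so close_notify can
      travel in both directions concurrently. Loop variable renamed to
      `progressed` to make the intent ("did anything change?") explicit.
      Because both directions now run on every iteration, doWrapStep and
      doUnwrapStep report progress only when they consumed input, produced
      output, or changed the handshake status (CLOSED still counts as
      progress). Previously they returned true in almost every case, which
      would keep the loop spinning until MaxTlsIterations.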
    - Track a `needMoreCipher` flag set when unwrap returns
      BUFFER_UNDERFLOW. While set, the stage keeps pulling cipherIn even
      if pendingCipherBytes already holds a partial record. cipherIn.onPush
      now accumulates incoming bytes onto the buffer rather than replacing
      it, preserving the partial-record residual.
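    - Add `errorFlushing` / `errorFlushTried` state to emulate the legacy
      fail+flush sequence: on SSLException, fail the user-side outlet
      immediately, then attempt one wrap to push the engine's alert frame
      to the peer before completing the transport side.
    - Related adjustments:
      - Request closeOutbound once the transport side is gone
        (inboundClosed or engine.isInboundDone) and `noUserWork` holds.
      - Only start a NegotiateNewSession when the engine is
        NOT_HANDSHAKING or FINISHED.
      - Never pull plainIn/cipherIn once they are closed.
      - Report a NEED_WRAP loop only when the status was already
        NEED_WRAP before a wrap that produced no output.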
    
    Result:
    - TlsGraphStageEdgeCasesSpec passes: BUFFER_UNDERFLOW recovery on
      one-byte cipher fragments works without deadlock; EagerClose with
      empty source completes cleanly; backpressured plainOut delivers all
      bytes; per-materialization engine isolation holds.
    - Full TlsGraphStageSpec (111 cases) and TlsGraphStageIsolatedSpec
      (7 cases) pass on JDK 21.
    
    References:
    - PR #2878
    - Issue #2860
---
 .../pekko/stream/impl/io/TlsGraphStage.scala       | 166 +++++++++++++++++----
 1 file changed, 134 insertions(+), 32 deletions(-)

diff --git 
a/stream/src/main/scala/org/apache/pekko/stream/impl/io/TlsGraphStage.scala 
b/stream/src/main/scala/org/apache/pekko/stream/impl/io/TlsGraphStage.scala
index f539a7dd4a..c0837f64c0 100644
--- a/stream/src/main/scala/org/apache/pekko/stream/impl/io/TlsGraphStage.scala
+++ b/stream/src/main/scala/org/apache/pekko/stream/impl/io/TlsGraphStage.scala
@@ -92,6 +92,17 @@ import pekko.util.ByteString
       private var inboundClosed = false
       private var outboundCloseRequested = false
       private var unwrapNeedWrapCounter = 0
+      // Set when unwrap returns BUFFER_UNDERFLOW so we KEEP pulling cipherIn 
even though
+      // pendingCipherBytes is nonempty (it holds a partial TLS record that 
the engine
+      // cannot decode without more bytes). Cleared on push from cipherIn. 
Without this,
+      // a TCP chunk that splits a TLS record across multiple pushes deadlocks 
the stage.
+      private var needMoreCipher = false
+      // Error flushing state: emulates legacy `fail(ex, closeTransport=false) 
+ completeOrFlush()`.
+      // When wrap/unwrap throws SSLException, fail the user-side outlet 
immediately so callers
+      // can observe the SSLException, then attempt one wrap to flush the 
engine's alert/close_notify
+      // to the peer before completing the transport side.
+      private var errorFlushing: Throwable = null
+      private var errorFlushTried = false
 
       prepare(userInBuffer)
       prepare(transportInBuffer)
@@ -116,7 +127,12 @@ import pekko.util.ByteString
         cipherIn,
         new InHandler {
           override def onPush(): Unit = {
-            pendingCipherBytes = grab(cipherIn)
+            val incoming = grab(cipherIn)
+            // Accumulate so partial-record residuals from BUFFER_UNDERFLOW 
are preserved.
+            pendingCipherBytes =
+              if (pendingCipherBytes.isEmpty) incoming
+              else pendingCipherBytes ++ incoming
+            needMoreCipher = false
             pumpTls()
           }
 
@@ -174,35 +190,36 @@ import pekko.util.ByteString
         if (isClosed(cipherOut) || cipherOutputClosed) return
 
         var iterations = 0
-        var continue = true
+        var progressed = true
 
-        while (continue && iterations < MaxTlsIterations && 
!isClosed(cipherOut)) {
+        while (progressed && iterations < MaxTlsIterations && 
!isClosed(cipherOut)) {
           iterations += 1
-          continue = false
+          progressed = false
 
-          if (pushPendingOutputs()) continue = true
+          if (pushPendingOutputs()) progressed = true
           if (!warm) {
             tryPullInputs()
             return
           }
 
-          if (materializePendingPlainCommand()) continue = true
-          if (requestOutboundCloseIfNeeded()) continue = true
-          if (closeInboundIfNeeded()) continue = true
-          if (pushPendingOutputs()) continue = true
+          if (materializePendingPlainCommand()) progressed = true
+          if (requestOutboundCloseIfNeeded()) progressed = true
+          if (closeInboundIfNeeded()) progressed = true
+          if (pushPendingOutputs()) progressed = true
 
           if (shouldCompleteStage()) {
             completeStage()
             return
           }
 
-          if (!hasPendingPlainOutput && shouldUnwrap) {
-            continue = doUnwrapStep() || continue
-          } else if (!hasPendingCipherOutput && shouldWrap) {
-            continue = doWrapStep() || continue
-          }
+          // Bidirectional step: attempt BOTH unwrap and wrap each iteration 
so the
+          // TLS close handshake can flow in both directions concurrently 
(peer's
+          // close_notify in, our close_notify out). The previous if/else 
gated one
+          // direction behind the other and could hang on the close exchange.
+          if (!hasPendingPlainOutput && shouldUnwrap && doUnwrapStep()) 
progressed = true
+          if (!hasPendingCipherOutput && shouldWrap && doWrapStep()) 
progressed = true
 
-          if (pushPendingOutputs()) continue = true
+          if (pushPendingOutputs()) progressed = true
           if (shouldCompleteStage()) {
             completeStage()
             return
@@ -217,9 +234,11 @@ import pekko.util.ByteString
       }
 
       private def tryPullInputs(): Unit = {
-        if (!plainInputFinished && pendingPlainCommand == null && 
pendingPlainBytes.isEmpty && !hasBeenPulled(plainIn))
+        if (!plainInputFinished && !isClosed(plainIn) && pendingPlainCommand 
== null && pendingPlainBytes.isEmpty &&
+          !hasBeenPulled(plainIn))
           pull(plainIn)
-        if (!cipherInputFinished && pendingCipherBytes.isEmpty && 
!hasBeenPulled(cipherIn))
+        if (!cipherInputFinished && !isClosed(cipherIn) && 
!hasBeenPulled(cipherIn) &&
+          (pendingCipherBytes.isEmpty || needMoreCipher))
           pull(cipherIn)
       }
 
@@ -236,6 +255,7 @@ import pekko.util.ByteString
         engine != null &&
         !cipherOutputClosed &&
         !engine.isOutboundDone &&
+        !(errorFlushing != null && errorFlushTried) &&
         (lastHandshakeStatus == NEED_WRAP ||
         outboundCloseRequested ||
         (!corkUser && pendingPlainBytes.nonEmpty && lastHandshakeStatus != 
NEED_UNWRAP))
@@ -249,7 +269,9 @@ import pekko.util.ByteString
               pendingPlainCommand = null
               true
 
-            case negotiate: NegotiateNewSession if pendingPlainBytes.isEmpty =>
+            case negotiate: NegotiateNewSession
+                if pendingPlainBytes.isEmpty &&
+                (lastHandshakeStatus == NOT_HANDSHAKING || lastHandshakeStatus 
== FINISHED) =>
               currentSession.invalidate()
               TlsUtils.applySessionParameters(engine, negotiate)
               engine.beginHandshake()
@@ -270,9 +292,14 @@ import pekko.util.ByteString
         val closeForCancellation =
           plainOutputClosed &&
           (!closing.ignoreCancel || plainInputFinished)
+        // Transport gone (peer closed inbound or our cipherIn upstream 
finished): once the engine
+        // sees inbound shutdown, the TLS session cannot continue, so issue 
close_notify and let
+        // shouldCompleteStage tear down. Both `closeInboundIfNeeded` 
(cipherIn finished) and a
+        // direct close_notify from the peer set engine.isInboundDone, which 
we use as the trigger.
+        val closeForTransport = noUserWork && (inboundClosed || 
engine.isInboundDone)
 
         if (engine == null || outboundCloseRequested || engine.isOutboundDone 
||
-          !(closeForCompletion || closeForCancellation))
+          !(closeForCompletion || closeForCancellation || closeForTransport))
           false
         else if (mayCloseOutbound) {
           engine.closeOutbound()
@@ -299,6 +326,7 @@ import pekko.util.ByteString
         !hasPendingCipherOutput &&
         !hasPendingPlainOutput &&
         (cipherOutputClosed ||
+        (errorFlushing != null && errorFlushTried) ||
         (engine != null &&
         engine.isOutboundDone &&
         (engine.isInboundDone || inboundClosed || plainOutputClosed) &&
@@ -325,10 +353,27 @@ import pekko.util.ByteString
       }
 
       private def doWrapStep(): Boolean = {
+        val statusBefore = lastHandshakeStatus
+        val plainBytesBefore = pendingPlainBytes.length
+
         pendingPlainBytes = chopInto(pendingPlainBytes, userInBuffer)
         transportOutBuffer.clear()
 
-        val result = engine.wrap(userInBuffer, transportOutBuffer)
+        val result =
+          try engine.wrap(userInBuffer, transportOutBuffer)
+          catch {
+            case ex: SSLException =>
+              // Engine may have placed an alert record in the buffer; capture 
it so the close
+              // handshake can flush it on a best-effort basis. Then 
transition to error-flushing
+              // mode (legacy `fail(ex, closeTransport=false) + 
completeOrFlush()` equivalent):
+              // user side is failed immediately so callers observe the 
SSLException, while the
+              // transport side stays alive until the alert/close_notify is 
drained.
+              captureCipherOutput()
+              pendingPlainBytes = putBack(userInBuffer, pendingPlainBytes)
+              if (errorFlushing != null) errorFlushTried = true
+              enterErrorFlushingMode(ex)
+              return true
+          }
         lastHandshakeStatus = result.getHandshakeStatus
 
         if (result.getHandshakeStatus == FINISHED) handshakeFinished()
@@ -336,15 +381,22 @@ import pekko.util.ByteString
 
         pendingPlainBytes = putBack(userInBuffer, pendingPlainBytes)
 
+        if (errorFlushing != null) errorFlushTried = true
+
         result.getStatus match {
           case OK =>
-            if (transportOutBuffer.position() == 0 && lastHandshakeStatus == 
NEED_WRAP)
+            val produced = transportOutBuffer.position() > 0
+            if (!produced && lastHandshakeStatus == NEED_WRAP && statusBefore 
== NEED_WRAP) {
               failStage(new IllegalStateException("SSLEngine trying to loop 
NEED_WRAP without producing output"))
-            captureCipherOutput()
-            true
+              false
+            } else {
+              captureCipherOutput()
+              produced || lastHandshakeStatus != statusBefore || 
pendingPlainBytes.length != plainBytesBefore
+            }
 
           case CLOSED =>
             captureCipherOutput()
+            // CLOSED is itself meaningful state progression even with no 
bytes produced.
             true
 
           case status =>
@@ -355,14 +407,30 @@ import pekko.util.ByteString
 
       private def doUnwrapStep(): Boolean = {
         val ignoreOutput = plainOutputClosed
+        val statusBefore = lastHandshakeStatus
         pendingCipherBytes = chopInto(pendingCipherBytes, transportInBuffer)
         val oldInputPosition = transportInBuffer.position()
 
-        val result = engine.unwrap(transportInBuffer, userOutBuffer)
+        val result =
+          try engine.unwrap(transportInBuffer, userOutBuffer)
+          catch {
+            case ex: SSLException =>
+              // Drop any partial output and transition to error-flushing mode 
so the user side
+              // observes the SSLException via plainOut while we attempt to 
flush any alert the
+              // engine produced (cipher side stays alive briefly for that 
flush, then completes).
+              userOutBuffer.clear()
+              pendingCipherBytes = putBack(transportInBuffer, 
pendingCipherBytes)
+              enterErrorFlushingMode(ex)
+              return true
+          }
         if (ignoreOutput) userOutBuffer.clear()
         lastHandshakeStatus = result.getHandshakeStatus
         runDelegatedTasks()
 
+        val newInputPosition = transportInBuffer.position()
+        val producedOutput = userOutBuffer.position() > 0
+        val consumedInput = newInputPosition > oldInputPosition
+
         result.getStatus match {
           case OK =>
             result.getHandshakeStatus match {
@@ -375,7 +443,8 @@ import pekko.util.ByteString
                   false
                 } else {
                   pendingCipherBytes = putBack(transportInBuffer, 
pendingCipherBytes)
-                  true
+                  // Real progress only if the unwrap actually moved engine 
state.
+                  consumedInput || producedOutput || lastHandshakeStatus != 
statusBefore
                 }
 
               case FINISHED =>
@@ -384,33 +453,36 @@ import pekko.util.ByteString
                 pendingCipherBytes = putBack(transportInBuffer, 
pendingCipherBytes)
                 true
 
-              case NEED_UNWRAP
-                  if transportInBuffer.hasRemaining &&
-                  userOutBuffer.position() == 0 &&
-                  transportInBuffer.position() == oldInputPosition =>
+              case NEED_UNWRAP if transportInBuffer.hasRemaining && 
!producedOutput && !consumedInput =>
                 failStage(new IllegalStateException("SSLEngine trying to loop 
NEED_UNWRAP without consuming input"))
                 false
 
               case _ =>
                 capturePlainOutput(currentSession)
                 pendingCipherBytes = putBack(transportInBuffer, 
pendingCipherBytes)
-                true
+                consumedInput || producedOutput || lastHandshakeStatus != 
statusBefore
             }
 
           case CLOSED =>
             capturePlainOutput(currentSession)
             pendingCipherBytes = putBack(transportInBuffer, pendingCipherBytes)
+            // CLOSED meaningfully transitions inbound state.
             true
 
           case BUFFER_UNDERFLOW =>
             capturePlainOutput(currentSession)
             pendingCipherBytes = putBack(transportInBuffer, pendingCipherBytes)
-            true
+            // Force tryPullInputs to request more cipher even though 
pendingCipherBytes
+            // holds the partial TLS record — we cannot progress without more 
bytes.
+            needMoreCipher = true
+            // No more cipher input available; without more bytes the next 
iteration
+            // can't make further progress on this side.
+            producedOutput
 
           case BUFFER_OVERFLOW =>
             capturePlainOutput(currentSession)
             pendingCipherBytes = putBack(transportInBuffer, pendingCipherBytes)
-            true
+            producedOutput || lastHandshakeStatus != statusBefore
 
           case null =>
             failStage(new IllegalStateException("unexpected status null in TLS 
unwrap"))
@@ -448,6 +520,36 @@ import pekko.util.ByteString
           case _                                                          => 
false
         }
 
+      // Legacy `fail(ex, closeTransport=false) + completeOrFlush()` 
equivalent.
+      // Failing plainOut immediately surfaces the SSLException to the caller; 
we keep cipherOut
+      // alive long enough for one wrap pass to push any pending 
alert/close_notify to the peer,
+      // then `shouldCompleteStage` (with the errorFlushTried short-circuit) 
tears the stage down.
+      private def enterErrorFlushingMode(ex: Throwable): Unit = {
+        if (errorFlushing != null) return
+        errorFlushing = ex
+        if (!plainOutputClosed) {
+          pendingPlainOutput = null
+          fail(plainOut, ex)
+          plainOutputClosed = true
+        }
+        if (!plainInputFinished) {
+          if (!isClosed(plainIn)) cancel(plainIn)
+          plainInputFinished = true
+        }
+        pendingPlainBytes = ByteString.empty
+        pendingPlainCommand = null
+        inboundClosed = true
+        if (engine != null && !engine.isOutboundDone && 
!outboundCloseRequested) {
+          try {
+            engine.closeOutbound()
+            outboundCloseRequested = true
+            lastHandshakeStatus = engine.getHandshakeStatus
+          } catch {
+            case NonFatal(_) => ()
+          }
+        }
+      }
+
       private def captureCipherOutput(): Unit = {
         transportOutBuffer.flip()
         if (transportOutBuffer.hasRemaining) pendingCipherOutput = 
ByteString(transportOutBuffer)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to