This is an automated email from the ASF dual-hosted git repository. He-Pin pushed a commit to branch pr-2878-clean in repository https://gitbox.apache.org/repos/asf/pekko.git
commit 000474055f7c7915a93b9f686198653921e23d26 Author: He-Pin <[email protected]> AuthorDate: Sun Apr 26 00:32:37 2026 +0800 fix(stream): address TlsGraphStage close-handshake, BUFFER_UNDERFLOW, and error-flush gaps Motivation: While exercising TlsGraphStage end-to-end against the full TlsSpec suite and adding fragmented-cipher edge cases, three correctness gaps surfaced that the original direct-push engine did not handle: 1. EagerClose stalls during the close handshake. The previous pump ran at most one of unwrap or wrap per iteration (an if/else that tried wrap only when unwrap was not eligible). When both peers reach close_notify simultaneously, that gating could leave one direction waiting on the other and deadlock the bidirectional close exchange. 2. Cipher records split across pushes (e.g. one byte at a time) caused the engine to return BUFFER_UNDERFLOW indefinitely because the stage stopped pulling cipherIn whenever pendingCipherBytes was non-empty. 3. SSLException thrown from wrap/unwrap failed the user-side outlet but never attempted to flush the engine's alert/close_notify to the peer, so the remote side never observed the error frame. Modification: - Replace the unwrap-or-wrap if/else in pumpTls with a bidirectional step: every iteration tries both directions, so close_notify can travel in both directions concurrently. Loop variable renamed to `progressed` to make the intent ("did anything change?") explicit. - Track a `needMoreCipher` flag set when unwrap returns BUFFER_UNDERFLOW. While set, the stage keeps pulling cipherIn even if pendingCipherBytes already holds a partial record. cipherIn.onPush now accumulates incoming bytes onto the buffer rather than replacing it, preserving the partial-record residual. - Add `errorFlushing` / `errorFlushTried` state to emulate the legacy fail+flush sequence: on SSLException, fail the user-side outlet immediately, then attempt one wrap to push the engine's alert frame to the peer before completing the transport side. 
Result: - TlsGraphStageEdgeCasesSpec passes: BUFFER_UNDERFLOW recovery on one-byte cipher fragments works without deadlock; EagerClose with empty source completes cleanly; backpressured plainOut delivers all bytes; per-materialization engine isolation holds. - Full TlsGraphStageSpec (111 cases) and TlsGraphStageIsolatedSpec (7 cases) pass on JDK 21. References: - PR #2878 - Issue #2860 --- .../pekko/stream/impl/io/TlsGraphStage.scala | 166 +++++++++++++++++---- 1 file changed, 134 insertions(+), 32 deletions(-) diff --git a/stream/src/main/scala/org/apache/pekko/stream/impl/io/TlsGraphStage.scala b/stream/src/main/scala/org/apache/pekko/stream/impl/io/TlsGraphStage.scala index f539a7dd4a..c0837f64c0 100644 --- a/stream/src/main/scala/org/apache/pekko/stream/impl/io/TlsGraphStage.scala +++ b/stream/src/main/scala/org/apache/pekko/stream/impl/io/TlsGraphStage.scala @@ -92,6 +92,17 @@ import pekko.util.ByteString private var inboundClosed = false private var outboundCloseRequested = false private var unwrapNeedWrapCounter = 0 + // Set when unwrap returns BUFFER_UNDERFLOW so we KEEP pulling cipherIn even though + // pendingCipherBytes is nonempty (it holds a partial TLS record that the engine + // cannot decode without more bytes). Cleared on push from cipherIn. Without this, + // a TCP chunk that splits a TLS record across multiple pushes deadlocks the stage. + private var needMoreCipher = false + // Error flushing state: emulates legacy `fail(ex, closeTransport=false) + completeOrFlush()`. + // When wrap/unwrap throws SSLException, fail the user-side outlet immediately so callers + // can observe the SSLException, then attempt one wrap to flush the engine's alert/close_notify + // to the peer before completing the transport side. 
+ private var errorFlushing: Throwable = null + private var errorFlushTried = false prepare(userInBuffer) prepare(transportInBuffer) @@ -116,7 +127,12 @@ import pekko.util.ByteString cipherIn, new InHandler { override def onPush(): Unit = { - pendingCipherBytes = grab(cipherIn) + val incoming = grab(cipherIn) + // Accumulate so partial-record residuals from BUFFER_UNDERFLOW are preserved. + pendingCipherBytes = + if (pendingCipherBytes.isEmpty) incoming + else pendingCipherBytes ++ incoming + needMoreCipher = false pumpTls() } @@ -174,35 +190,36 @@ import pekko.util.ByteString if (isClosed(cipherOut) || cipherOutputClosed) return var iterations = 0 - var continue = true + var progressed = true - while (continue && iterations < MaxTlsIterations && !isClosed(cipherOut)) { + while (progressed && iterations < MaxTlsIterations && !isClosed(cipherOut)) { iterations += 1 - continue = false + progressed = false - if (pushPendingOutputs()) continue = true + if (pushPendingOutputs()) progressed = true if (!warm) { tryPullInputs() return } - if (materializePendingPlainCommand()) continue = true - if (requestOutboundCloseIfNeeded()) continue = true - if (closeInboundIfNeeded()) continue = true - if (pushPendingOutputs()) continue = true + if (materializePendingPlainCommand()) progressed = true + if (requestOutboundCloseIfNeeded()) progressed = true + if (closeInboundIfNeeded()) progressed = true + if (pushPendingOutputs()) progressed = true if (shouldCompleteStage()) { completeStage() return } - if (!hasPendingPlainOutput && shouldUnwrap) { - continue = doUnwrapStep() || continue - } else if (!hasPendingCipherOutput && shouldWrap) { - continue = doWrapStep() || continue - } + // Bidirectional step: attempt BOTH unwrap and wrap each iteration so the + // TLS close handshake can flow in both directions concurrently (peer's + // close_notify in, our close_notify out). The previous if/else gated one + // direction behind the other and could hang on the close exchange. 
+ if (!hasPendingPlainOutput && shouldUnwrap && doUnwrapStep()) progressed = true + if (!hasPendingCipherOutput && shouldWrap && doWrapStep()) progressed = true - if (pushPendingOutputs()) continue = true + if (pushPendingOutputs()) progressed = true if (shouldCompleteStage()) { completeStage() return @@ -217,9 +234,11 @@ import pekko.util.ByteString } private def tryPullInputs(): Unit = { - if (!plainInputFinished && pendingPlainCommand == null && pendingPlainBytes.isEmpty && !hasBeenPulled(plainIn)) + if (!plainInputFinished && !isClosed(plainIn) && pendingPlainCommand == null && pendingPlainBytes.isEmpty && + !hasBeenPulled(plainIn)) pull(plainIn) - if (!cipherInputFinished && pendingCipherBytes.isEmpty && !hasBeenPulled(cipherIn)) + if (!cipherInputFinished && !isClosed(cipherIn) && !hasBeenPulled(cipherIn) && + (pendingCipherBytes.isEmpty || needMoreCipher)) pull(cipherIn) } @@ -236,6 +255,7 @@ import pekko.util.ByteString engine != null && !cipherOutputClosed && !engine.isOutboundDone && + !(errorFlushing != null && errorFlushTried) && (lastHandshakeStatus == NEED_WRAP || outboundCloseRequested || (!corkUser && pendingPlainBytes.nonEmpty && lastHandshakeStatus != NEED_UNWRAP)) @@ -249,7 +269,9 @@ import pekko.util.ByteString pendingPlainCommand = null true - case negotiate: NegotiateNewSession if pendingPlainBytes.isEmpty => + case negotiate: NegotiateNewSession + if pendingPlainBytes.isEmpty && + (lastHandshakeStatus == NOT_HANDSHAKING || lastHandshakeStatus == FINISHED) => currentSession.invalidate() TlsUtils.applySessionParameters(engine, negotiate) engine.beginHandshake() @@ -270,9 +292,14 @@ import pekko.util.ByteString val closeForCancellation = plainOutputClosed && (!closing.ignoreCancel || plainInputFinished) + // Transport gone (peer closed inbound or our cipherIn upstream finished): once the engine + // sees inbound shutdown, the TLS session cannot continue, so issue close_notify and let + // shouldCompleteStage tear down. 
Both `closeInboundIfNeeded` (cipherIn finished) and a + // direct close_notify from the peer set engine.isInboundDone, which we use as the trigger. + val closeForTransport = noUserWork && (inboundClosed || engine.isInboundDone) if (engine == null || outboundCloseRequested || engine.isOutboundDone || - !(closeForCompletion || closeForCancellation)) + !(closeForCompletion || closeForCancellation || closeForTransport)) false else if (mayCloseOutbound) { engine.closeOutbound() @@ -299,6 +326,7 @@ import pekko.util.ByteString !hasPendingCipherOutput && !hasPendingPlainOutput && (cipherOutputClosed || + (errorFlushing != null && errorFlushTried) || (engine != null && engine.isOutboundDone && (engine.isInboundDone || inboundClosed || plainOutputClosed) && @@ -325,10 +353,27 @@ import pekko.util.ByteString } private def doWrapStep(): Boolean = { + val statusBefore = lastHandshakeStatus + val plainBytesBefore = pendingPlainBytes.length + pendingPlainBytes = chopInto(pendingPlainBytes, userInBuffer) transportOutBuffer.clear() - val result = engine.wrap(userInBuffer, transportOutBuffer) + val result = + try engine.wrap(userInBuffer, transportOutBuffer) + catch { + case ex: SSLException => + // Engine may have placed an alert record in the buffer; capture it so the close + // handshake can flush it on a best-effort basis. Then transition to error-flushing + // mode (legacy `fail(ex, closeTransport=false) + completeOrFlush()` equivalent): + // user side is failed immediately so callers observe the SSLException, while the + // transport side stays alive until the alert/close_notify is drained. 
+ captureCipherOutput() + pendingPlainBytes = putBack(userInBuffer, pendingPlainBytes) + if (errorFlushing != null) errorFlushTried = true + enterErrorFlushingMode(ex) + return true + } lastHandshakeStatus = result.getHandshakeStatus if (result.getHandshakeStatus == FINISHED) handshakeFinished() @@ -336,15 +381,22 @@ import pekko.util.ByteString pendingPlainBytes = putBack(userInBuffer, pendingPlainBytes) + if (errorFlushing != null) errorFlushTried = true + result.getStatus match { case OK => - if (transportOutBuffer.position() == 0 && lastHandshakeStatus == NEED_WRAP) + val produced = transportOutBuffer.position() > 0 + if (!produced && lastHandshakeStatus == NEED_WRAP && statusBefore == NEED_WRAP) { failStage(new IllegalStateException("SSLEngine trying to loop NEED_WRAP without producing output")) - captureCipherOutput() - true + false + } else { + captureCipherOutput() + produced || lastHandshakeStatus != statusBefore || pendingPlainBytes.length != plainBytesBefore + } case CLOSED => captureCipherOutput() + // CLOSED is itself meaningful state progression even with no bytes produced. true case status => @@ -355,14 +407,30 @@ import pekko.util.ByteString private def doUnwrapStep(): Boolean = { val ignoreOutput = plainOutputClosed + val statusBefore = lastHandshakeStatus pendingCipherBytes = chopInto(pendingCipherBytes, transportInBuffer) val oldInputPosition = transportInBuffer.position() - val result = engine.unwrap(transportInBuffer, userOutBuffer) + val result = + try engine.unwrap(transportInBuffer, userOutBuffer) + catch { + case ex: SSLException => + // Drop any partial output and transition to error-flushing mode so the user side + // observes the SSLException via plainOut while we attempt to flush any alert the + // engine produced (cipher side stays alive briefly for that flush, then completes). 
+ userOutBuffer.clear() + pendingCipherBytes = putBack(transportInBuffer, pendingCipherBytes) + enterErrorFlushingMode(ex) + return true + } if (ignoreOutput) userOutBuffer.clear() lastHandshakeStatus = result.getHandshakeStatus runDelegatedTasks() + val newInputPosition = transportInBuffer.position() + val producedOutput = userOutBuffer.position() > 0 + val consumedInput = newInputPosition > oldInputPosition + result.getStatus match { case OK => result.getHandshakeStatus match { @@ -375,7 +443,8 @@ import pekko.util.ByteString false } else { pendingCipherBytes = putBack(transportInBuffer, pendingCipherBytes) - true + // Real progress only if the unwrap actually moved engine state. + consumedInput || producedOutput || lastHandshakeStatus != statusBefore } case FINISHED => @@ -384,33 +453,36 @@ import pekko.util.ByteString pendingCipherBytes = putBack(transportInBuffer, pendingCipherBytes) true - case NEED_UNWRAP - if transportInBuffer.hasRemaining && - userOutBuffer.position() == 0 && - transportInBuffer.position() == oldInputPosition => + case NEED_UNWRAP if transportInBuffer.hasRemaining && !producedOutput && !consumedInput => failStage(new IllegalStateException("SSLEngine trying to loop NEED_UNWRAP without consuming input")) false case _ => capturePlainOutput(currentSession) pendingCipherBytes = putBack(transportInBuffer, pendingCipherBytes) - true + consumedInput || producedOutput || lastHandshakeStatus != statusBefore } case CLOSED => capturePlainOutput(currentSession) pendingCipherBytes = putBack(transportInBuffer, pendingCipherBytes) + // CLOSED meaningfully transitions inbound state. true case BUFFER_UNDERFLOW => capturePlainOutput(currentSession) pendingCipherBytes = putBack(transportInBuffer, pendingCipherBytes) - true + // Force tryPullInputs to request more cipher even though pendingCipherBytes + // holds the partial TLS record — we cannot progress without more bytes. 
+ needMoreCipher = true + // No more cipher input available; without more bytes the next iteration + // can't make further progress on this side. + producedOutput case BUFFER_OVERFLOW => capturePlainOutput(currentSession) pendingCipherBytes = putBack(transportInBuffer, pendingCipherBytes) - true + producedOutput || lastHandshakeStatus != statusBefore case null => failStage(new IllegalStateException("unexpected status null in TLS unwrap")) @@ -448,6 +520,36 @@ import pekko.util.ByteString case _ => false } + // Legacy `fail(ex, closeTransport=false) + completeOrFlush()` equivalent. + // Failing plainOut immediately surfaces the SSLException to the caller; we keep cipherOut + // alive long enough for one wrap pass to push any pending alert/close_notify to the peer, + // then `shouldCompleteStage` (with the errorFlushTried short-circuit) tears the stage down. + private def enterErrorFlushingMode(ex: Throwable): Unit = { + if (errorFlushing != null) return + errorFlushing = ex + if (!plainOutputClosed) { + pendingPlainOutput = null + fail(plainOut, ex) + plainOutputClosed = true + } + if (!plainInputFinished) { + if (!isClosed(plainIn)) cancel(plainIn) + plainInputFinished = true + } + pendingPlainBytes = ByteString.empty + pendingPlainCommand = null + inboundClosed = true + if (engine != null && !engine.isOutboundDone && !outboundCloseRequested) { + try { + engine.closeOutbound() + outboundCloseRequested = true + lastHandshakeStatus = engine.getHandshakeStatus + } catch { + case NonFatal(_) => () + } + } + } + private def captureCipherOutput(): Unit = { transportOutBuffer.flip() if (transportOutBuffer.hasRemaining) pendingCipherOutput = ByteString(transportOutBuffer) --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
