This is an automated email from the ASF dual-hosted git repository. He-Pin pushed a commit to branch issue-2860-tls-graphstage-path in repository https://gitbox.apache.org/repos/asf/pekko.git
commit ee588bbcd8dcbeeed8c57991b29606f7b99f1336 Author: He-Pin <[email protected]> AuthorDate: Sun Apr 19 18:40:46 2026 +0800 feat(stream): add GraphStage-backed TLS path behind feature switch (#2860) Motivation: TlsModule currently uses an actor-backed island (TLSActor) for every TLS connection. This makes TLS materialize as a separate actor, adding per-message scheduling overhead and preventing the fused-graph optimiser from crossing the TLS boundary. Issue #2860 tracks replacing the legacy actor path with a proper GraphStage. Modification: - Extract TlsUtils from TLSActor (shared SSLEngine session-parameter helpers). - Add TlsGraphStage: a GraphStage with a BidiShape that owns the SSLEngine state machine and handles all handshake sequencing, renegotiation gating, close-notify exchange, and error propagation without a dedicated TLSActor (the stage itself still runs behind an async boundary in its own ActorGraphInterpreter). Key fixes included in the state machine: * shouldCloseOutbound TransferState so a server-role stage can initiate an outbound close even when no user data is pending (prevents deadlock). * After a handshake failure (e.g. certificate_unknown) the first engine.wrap() throws but leaves the engine in NEED_WRAP; a second wrap() call is performed to flush the TLS fatal-alert bytes to the peer, so the peer receives the real error instead of 'closing inbound before receiving peer's close_notify'. - Wire the switch in PhasedFusingActorMaterializer via pekko.stream.materializer.tls.use-legacy-actor (default true, preserving existing behaviour). - Extend TlsSpec to run the full suite against both paths (TlsGraphStageSpec). - Update MaterializerStateSpec to distinguish legacy vs GraphStage actor names. - Add TlsBenchmark in bench-jmh for TLS throughput regression tracking. - Add a runtime-isolation note to the stream-io docs. 
Result: TlsGraphStageSpec: 111/111 tests pass on both TLSv1.2 and TLSv1.3, including: - normal data transfer - half-close / truncation handling - renegotiation sequencing - certificate-check error propagation (certificate_unknown alert reaches peer) - early-failure / cancellation semantics - hostname verification The legacy TLS actor path is unchanged (default). References: https://github.com/apache/pekko/issues/2860 Co-authored-by: Copilot <[email protected]> --- bench-jmh/src/main/resources/keystore | Bin 0 -> 2397 bytes bench-jmh/src/main/resources/truststore | Bin 0 -> 857 bytes .../org/apache/pekko/stream/TlsBenchmark.scala | 135 ++++ docs/src/main/paradox/stream/stream-io.md | 4 + .../scala/org/apache/pekko/stream/io/TlsSpec.scala | 40 +- .../stream/snapshot/MaterializerStateSpec.scala | 98 ++- stream/src/main/resources/reference.conf | 6 + .../impl/PhasedFusingActorMaterializer.scala | 105 ++- .../org/apache/pekko/stream/impl/io/TLSActor.scala | 33 - .../pekko/stream/impl/io/TlsGraphStage.scala | 769 +++++++++++++++++++++ .../org/apache/pekko/stream/impl/io/TlsUtils.scala | 54 ++ .../org/apache/pekko/stream/scaladsl/TLS.scala | 7 +- 12 files changed, 1184 insertions(+), 67 deletions(-) diff --git a/bench-jmh/src/main/resources/keystore b/bench-jmh/src/main/resources/keystore new file mode 100644 index 0000000000..2b0237562b Binary files /dev/null and b/bench-jmh/src/main/resources/keystore differ diff --git a/bench-jmh/src/main/resources/truststore b/bench-jmh/src/main/resources/truststore new file mode 100644 index 0000000000..3cc1983600 Binary files /dev/null and b/bench-jmh/src/main/resources/truststore differ diff --git a/bench-jmh/src/main/scala/org/apache/pekko/stream/TlsBenchmark.scala b/bench-jmh/src/main/scala/org/apache/pekko/stream/TlsBenchmark.scala new file mode 100644 index 0000000000..ff5226b5fb --- /dev/null +++ b/bench-jmh/src/main/scala/org/apache/pekko/stream/TlsBenchmark.scala @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software 
Foundation (ASF) under one or more + * license agreements; and to You under the Apache License, version 2.0: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * This file is part of the Apache Pekko project, which was derived from Akka. + */ + +/* + * Copyright (C) 2018-2022 Lightbend Inc. <https://www.lightbend.com> + */ + +package org.apache.pekko.stream + +import java.security.KeyStore +import java.security.SecureRandom +import java.util.concurrent.TimeUnit +import javax.net.ssl.{ KeyManagerFactory, SSLContext, SSLEngine, TrustManagerFactory } + +import scala.concurrent.Await +import scala.concurrent.duration._ + +import org.openjdk.jmh.annotations._ + +import org.apache.pekko +import pekko.Done +import pekko.actor.ActorSystem +import pekko.stream.scaladsl.{ Flow, Keep, RunnableGraph, Sink, Source, TLS } +import pekko.stream.TLSProtocol.{ SendBytes, SessionBytes, SslTlsInbound, SslTlsOutbound } +import pekko.stream.impl.io.TlsGraphStage +import pekko.util.ByteString + +object TlsBenchmark { + final val OperationsPerInvocation = 256 + private final val Password = "changeme".toCharArray + private final val Protocol = "TLSv1.2" + private final val Payload = ByteString("abcdefgh" * 32) + private final val Ciphers = Array( + "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", + "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384") + + private def initSslContext(): SSLContext = { + val keyStore = KeyStore.getInstance(KeyStore.getDefaultType) + keyStore.load(getClass.getResourceAsStream("/keystore"), Password) + + val trustStore = KeyStore.getInstance(KeyStore.getDefaultType) + trustStore.load(getClass.getResourceAsStream("/truststore"), Password) + + val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm) + keyManagerFactory.init(keyStore, Password) + + val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm) + trustManagerFactory.init(trustStore) + + val context = SSLContext.getInstance(Protocol) + 
context.init(keyManagerFactory.getKeyManagers, trustManagerFactory.getTrustManagers, new SecureRandom) + context + } + + private def createSSLEngine(context: SSLContext, role: TLSRole): SSLEngine = { + val engine = context.createSSLEngine() + engine.setUseClientMode(role == Client) + engine.setEnabledProtocols(Array(Protocol)) + engine.setEnabledCipherSuites(Ciphers) + engine + } + + private def withLegacyActorPath[T](enabled: Boolean)(thunk: => T): T = { + val previous = Option(System.getProperty(TlsGraphStage.UseLegacyActorPath)) + System.setProperty(TlsGraphStage.UseLegacyActorPath, if (enabled) "on" else "off") + try thunk + finally + previous match { + case Some(value) => System.setProperty(TlsGraphStage.UseLegacyActorPath, value) + case None => System.clearProperty(TlsGraphStage.UseLegacyActorPath) + } + } + + /** + * ActorSystem config that selects the TLS implementation explicitly. + * + * The materializer reads `pekko.stream.materializer.tls.use-legacy-actor` from the system config. + * `ConfigFactory.load()` caches both the loaded config and the system properties, so toggling the + * system property between two `ActorSystem(name)` calls is not enough to give the systems + * different TLS paths. + */ + private def tlsSystemConfig(useLegacyActor: Boolean): com.typesafe.config.Config = { + val value = if (useLegacyActor) "on" else "off" + com.typesafe.config.ConfigFactory + .parseString(s"${TlsGraphStage.UseLegacyActorPath} = $value") + .withFallback(com.typesafe.config.ConfigFactory.load()) + } +} + +@State(Scope.Benchmark) +@OutputTimeUnit(TimeUnit.SECONDS) +@BenchmarkMode(Array(Mode.Throughput)) +class TlsBenchmark { + import TlsBenchmark._ + + private val legacySystem = ActorSystem("TlsLegacyBenchmark", tlsSystemConfig(useLegacyActor = true)) + private val graphStageSystem = ActorSystem("TlsGraphStageBenchmark", tlsSystemConfig(useLegacyActor = false)) + + private val sslContext = initSslContext() + + private implicit val legacyMaterializer: Materializer = Materializer(legacySystem) + private implicit val graphStageMaterializer: Materializer = Materializer(graphStageSystem) + + private def echoTlsFlow(closing: TLSClosing) = + TLS(() => createSSLEngine(sslContext, Client), closing) + .atop(TLS(() => createSSLEngine(sslContext, Server), closing).reversed) + .join(Flow[SslTlsInbound].collect { case SessionBytes(_, bytes) => SendBytes(bytes) }) + + private val legacyRoundTrip: RunnableGraph[scala.concurrent.Future[Done]] = + withLegacyActorPath(true) { + Source + .repeat[SslTlsOutbound](SendBytes(Payload)) + .take(OperationsPerInvocation.toLong) + .via(echoTlsFlow(EagerClose)) + .toMat(Sink.ignore)(Keep.right) + } + + private val graphStageRoundTrip:
RunnableGraph[scala.concurrent.Future[Done]] = + withLegacyActorPath(false) { + Source + .repeat[SslTlsOutbound](SendBytes(Payload)) + .take(OperationsPerInvocation.toLong) + .via(echoTlsFlow(EagerClose)) + .toMat(Sink.ignore)(Keep.right) + } + + @Benchmark + @OperationsPerInvocation(OperationsPerInvocation) + def legacyTlsActorPath(): Unit = + Await.result(legacyRoundTrip.run()(legacyMaterializer), 30.seconds) + + @Benchmark + @OperationsPerInvocation(OperationsPerInvocation) + def graphStageTlsPath(): Unit = + Await.result(graphStageRoundTrip.run()(graphStageMaterializer), 30.seconds) + + @TearDown + def shutdown(): Unit = { + legacyMaterializer.shutdown() + graphStageMaterializer.shutdown() + Await.result(legacySystem.terminate(), 5.seconds) + Await.result(graphStageSystem.terminate(), 5.seconds) + } +} diff --git a/docs/src/main/paradox/stream/stream-io.md b/docs/src/main/paradox/stream/stream-io.md index b53ef2efea..0f99810396 100644 --- a/docs/src/main/paradox/stream/stream-io.md +++ b/docs/src/main/paradox/stream/stream-io.md @@ -159,6 +159,10 @@ Java The `SSLEngine` instance can then be used with the binding or outgoing connection factory methods. +The TLS stage runs the `SSLEngine` state machine behind its own asynchronous boundary, in a separate actor +from the surrounding fused stream, so the handshake and wrap/unwrap work is isolated from that stream and +the `SSLEngine` is only ever driven by one actor at a time.
+ ## Streaming File IO Pekko Streams provide simple Sources and Sinks that can work with @apidoc[util.ByteString] instances to perform IO operations diff --git a/stream-tests/src/test/scala/org/apache/pekko/stream/io/TlsSpec.scala b/stream-tests/src/test/scala/org/apache/pekko/stream/io/TlsSpec.scala index c018aa48c4..bcca80fd61 100644 --- a/stream-tests/src/test/scala/org/apache/pekko/stream/io/TlsSpec.scala +++ b/stream-tests/src/test/scala/org/apache/pekko/stream/io/TlsSpec.scala @@ -30,6 +30,7 @@ import pekko.NotUsed import pekko.pattern.{ after => later } import pekko.stream._ import pekko.stream.TLSProtocol._ +import pekko.stream.impl.io.TlsGraphStage import pekko.stream.impl.fusing.GraphStages.SimpleLinearGraphStage import pekko.stream.scaladsl._ import pekko.stream.stage._ @@ -109,13 +110,39 @@ object TlsSpec { pekko.loggers = ["org.apache.pekko.testkit.SilenceAllTestEventListener"] pekko.actor.debug.receive=off """ + + val graphStageConfigOverrides = + s"""$configOverrides + |pekko.stream.materializer.tls.use-legacy-actor = off + |""".stripMargin + + def setLegacyActorPath(enabled: Boolean): Option[String] = { + val previous = Option(System.getProperty(TlsGraphStage.UseLegacyActorPath)) + System.setProperty(TlsGraphStage.UseLegacyActorPath, if (enabled) "on" else "off") + previous + } + + def restoreLegacyActorPath(previous: Option[String]): Unit = + previous match { + case Some(value) => System.setProperty(TlsGraphStage.UseLegacyActorPath, value) + case None => System.clearProperty(TlsGraphStage.UseLegacyActorPath) + } } -class TlsSpec extends StreamSpec(TlsSpec.configOverrides) with WithLogCapturing { +abstract class TlsSpecBase(configOverrides: String, useLegacyActorPath: Boolean) + extends StreamSpec(configOverrides) + with WithLogCapturing { import GraphDSL.Implicits._ import TlsSpec._ import system.dispatcher + private val previousLegacyActorPath = setLegacyActorPath(useLegacyActorPath) + + override protected def afterTermination(): Unit = { + 
restoreLegacyActorPath(previousLegacyActorPath) + super.afterTermination() + } + "SslTls" must { "work for TLSv1.2" must { workFor("TLSv1.2", TLS12Ciphers) } @@ -252,6 +279,7 @@ class TlsSpec extends StreamSpec(TlsSpec.configOverrides) with WithLogCapturing case SessionBytes(_, b) => SendBytes(b) } } + def leftClosing: TLSClosing = IgnoreComplete def rightClosing: TLSClosing = IgnoreComplete @@ -605,3 +633,13 @@ class TlsSpec extends StreamSpec(TlsSpec.configOverrides) with WithLogCapturing } } + +/** + * Tests TLS using the legacy TLSActor-based path (default). + */ +class TlsSpec extends TlsSpecBase(TlsSpec.configOverrides, useLegacyActorPath = true) + +/** + * Tests TLS using the new GraphStage-based path (selected via use-legacy-actor = off in the ActorSystem config). + */ +class TlsGraphStageSpec extends TlsSpecBase(TlsSpec.graphStageConfigOverrides, useLegacyActorPath = false) diff --git a/stream-tests/src/test/scala/org/apache/pekko/stream/snapshot/MaterializerStateSpec.scala b/stream-tests/src/test/scala/org/apache/pekko/stream/snapshot/MaterializerStateSpec.scala index d1bcb75823..07d7c165f2 100644 --- a/stream-tests/src/test/scala/org/apache/pekko/stream/snapshot/MaterializerStateSpec.scala +++ b/stream-tests/src/test/scala/org/apache/pekko/stream/snapshot/MaterializerStateSpec.scala @@ -13,18 +13,59 @@ package org.apache.pekko.stream.snapshot -import java.net.InetSocketAddress -import javax.net.ssl.SSLContext +import javax.net.ssl.{ SSLEngine, SSLSession } import scala.concurrent.Promise +import scala.util.Try import org.apache.pekko -import pekko.stream.{ FlowShape, Materializer } -import pekko.stream.scaladsl.{ Flow, GraphDSL, Keep, Merge, Partition, Sink, Source, Tcp } +import pekko.NotUsed +import pekko.actor.ActorSystem +import pekko.stream._ +import pekko.stream.TLSProtocol.{ SendBytes, SessionBytes, SslTlsInbound, SslTlsOutbound } +import pekko.stream.io.TlsSpec +import pekko.stream.scaladsl.{ Flow, GraphDSL, Keep, Merge, Partition, Sink, Source, TLS } import pekko.stream.testkit.scaladsl.TestSink -import
pekko.testkit.PekkoSpec +import pekko.testkit.{ PekkoSpec, TestKit } +import pekko.util.ByteString class MaterializerStateSpec extends PekkoSpec() { + private val previousLegacyActorPath = TlsSpec.setLegacyActorPath(true) + + override protected def afterTermination(): Unit = { + TlsSpec.restoreLegacyActorPath(previousLegacyActorPath) + super.afterTermination() + } + + private def localTlsFlow(protocol: String): Flow[SslTlsOutbound, SslTlsInbound, NotUsed] = { + val sslContext = TlsSpec.initSslContext(protocol) + + val ciphers = + if (protocol == "TLSv1.3") TlsSpec.TLS13Ciphers.toArray + else TlsSpec.TLS12Ciphers.toArray + + def createSSLEngine(role: TLSRole): SSLEngine = { + val engine = sslContext.createSSLEngine() + engine.setUseClientMode(role == Client) + engine.setEnabledCipherSuites(ciphers) + engine.setEnabledProtocols(Array(protocol)) + engine + } + + def passthroughTls = + Flow[SslTlsInbound].collect { case SessionBytes(_, bytes) => SendBytes(bytes) } + + TLS(() => createSSLEngine(Client), verifySession = (_: SSLSession) => Try(()), IgnoreComplete) + .atop(TLS(() => createSSLEngine(Server), verifySession = (_: SSLSession) => Try(()), IgnoreComplete).reversed) + .join(passthroughTls) + } + + private def startRunningTlsStream()(implicit mat: Materializer): Unit = + Source + .single[SslTlsOutbound](SendBytes(ByteString("ping"))) + .concat(Source.maybe[SslTlsOutbound]) + .via(localTlsFlow("TLSv1.2")) + .runWith(Sink.ignore) "The MaterializerSnapshotting" must { @@ -59,19 +100,40 @@ class MaterializerStateSpec extends PekkoSpec() { promise.success(1) } - "snapshot a running stream that includes a TLSActor" in { - Source.never - .via(Tcp(system).outgoingConnectionWithTls(InetSocketAddress.createUnresolved("pekko.io", 443), - () => { - val engine = SSLContext.getDefault.createSSLEngine("pekko.io", 443) - engine.setUseClientMode(true) - engine - })) - .runWith(Sink.seq) - - val snapshots = MaterializerState.streamSnapshots(system).futureValue - snapshots.size 
should be(2) - snapshots.toString should include("TLS-") + "snapshot a running stream that includes the legacy TLS path" in { + implicit val mat = Materializer(system) + try { + startRunningTlsStream() + + awaitAssert({ + val snapshots = MaterializerState.streamSnapshots(mat).futureValue + snapshots should have size 3 + snapshots.count(_.toString.contains("TLS-for-flow-")) should be(2) + }, remainingOrDefault) + } finally { + mat.shutdown() + } + } + + "snapshot a running stream that includes the async GraphStage TLS path" in { + val previousGraphStageSetting = TlsSpec.setLegacyActorPath(false) + // Select the GraphStage path in the config passed to the ActorSystem itself: Typesafe Config + // caches system properties, so the property toggled above may not reach a new ActorSystem. + val tlsSystem = ActorSystem( + "MaterializerStateGraphStageTlsSpec", + com.typesafe.config.ConfigFactory + .parseString("pekko.stream.materializer.tls.use-legacy-actor = off") + .withFallback(PekkoSpec.testConf)) + + try { + implicit val mat = Materializer(tlsSystem) + + startRunningTlsStream() + + awaitAssert({ + val snapshots = MaterializerState.streamSnapshots(mat).futureValue + snapshots should have size 3 + snapshots.count(_.toString.contains("TlsGraphStage")) should be(2) + (snapshots.toString should not).include("TLS-for-flow-") + }, remainingOrDefault) + } finally { + TestKit.shutdownActorSystem(tlsSystem) + TlsSpec.restoreLegacyActorPath(previousGraphStageSetting) + } } "snapshot a stream that has a stopped stage" in { diff --git a/stream/src/main/resources/reference.conf b/stream/src/main/resources/reference.conf index bba71993ec..b6db7e0489 100644 --- a/stream/src/main/resources/reference.conf +++ b/stream/src/main/resources/reference.conf @@ -83,6 +83,12 @@ pekko { # Allows to accelerate message processing that happening within same actor but keep system responsive. sync-processing-limit = 1000 + tls { + # INTERNAL API: rollout switch for the GraphStage-based TLS runtime. + # Keep this enabled until the replacement path has completed validation.
+ use-legacy-actor = on + } + debug { # Enables the fuzzing mode which increases the chance of race conditions # by aggressively reordering events and making certain operations more diff --git a/stream/src/main/scala/org/apache/pekko/stream/impl/PhasedFusingActorMaterializer.scala b/stream/src/main/scala/org/apache/pekko/stream/impl/PhasedFusingActorMaterializer.scala index 96033f26ed..4b58a1d692 100644 --- a/stream/src/main/scala/org/apache/pekko/stream/impl/PhasedFusingActorMaterializer.scala +++ b/stream/src/main/scala/org/apache/pekko/stream/impl/PhasedFusingActorMaterializer.scala @@ -37,6 +37,7 @@ import pekko.dispatch.Dispatchers import pekko.event.Logging import pekko.event.LoggingAdapter import pekko.stream._ +import pekko.stream.TLSProtocol.{ SslTlsInbound, SslTlsOutbound } import pekko.stream.Attributes.InputBuffer import pekko.stream.impl.Stages.DefaultAttributes import pekko.stream.impl.StreamLayout.AtomicModule @@ -45,15 +46,20 @@ import pekko.stream.impl.fusing.ActorGraphInterpreter.ActorOutputBoundary import pekko.stream.impl.fusing.ActorGraphInterpreter.BatchingActorInputBoundary import pekko.stream.impl.fusing.GraphInterpreter.Connection import pekko.stream.impl.io.TLSActor +import pekko.stream.impl.io.TlsGraphStage import pekko.stream.impl.io.TlsModule +import pekko.stream.scaladsl.{ GraphDSL, Sink, Source } +import pekko.stream.scaladsl.RunnableGraph import pekko.stream.stage.GraphStageLogic import pekko.stream.stage.InHandler import pekko.stream.stage.OutHandler +import pekko.util.ByteString import pekko.util.OptionVal import org.reactivestreams.Processor import org.reactivestreams.Publisher import org.reactivestreams.Subscriber +import org.reactivestreams.Subscription /** * INTERNAL API @@ -972,36 +978,109 @@ private[pekko] object GraphStageIsland { extends PhaseIsland[NotUsed] { def name: String = "TlsModulePhase" + private final class InputFailureTrackingSubscriber[T](delegate: Subscriber[T]) extends Subscriber[T] { + @volatile private var 
failure: Option[Throwable] = None + + def currentFailure: Option[Throwable] = failure + + override def onSubscribe(subscription: Subscription): Unit = delegate.onSubscribe(subscription) + override def onNext(elem: T): Unit = delegate.onNext(elem) + override def onError(cause: Throwable): Unit = { + failure = Some(cause) + delegate.onError(cause) + } + override def onComplete(): Unit = delegate.onComplete() + } + + private var useLegacyActorPath = true + private var tlsModule: TlsModule = _ + private var graphAttributes: Attributes = _ private var tlsActor: ActorRef = _ var publishers: Vector[ActorPublisher[Any]] = _ + private var plainInputProcessor: VirtualProcessor[SslTlsOutbound] = _ + private var cipherInputProcessor: VirtualProcessor[ByteString] = _ + private var cipherOutputProcessor: VirtualProcessor[ByteString] = _ + private var plainOutputProcessor: VirtualProcessor[SslTlsInbound] = _ + private var plainInputTracker: InputFailureTrackingSubscriber[SslTlsOutbound] = _ + private var cipherInputTracker: InputFailureTrackingSubscriber[ByteString] = _ def materializeAtomic(mod: AtomicModule[Shape, Any], attributes: Attributes): (NotUsed, Any) = { val tls = mod.asInstanceOf[TlsModule] + tlsModule = tls + graphAttributes = attributes + useLegacyActorPath = materializer.system.settings.config.getBoolean(TlsGraphStage.UseLegacyActorPath) - val dispatcher = attributes.mandatoryAttribute[ActorAttributes.Dispatcher].dispatcher - val maxInputBuffer = attributes.mandatoryAttribute[Attributes.InputBuffer].max + if (useLegacyActorPath) { + val dispatcher = attributes.mandatoryAttribute[ActorAttributes.Dispatcher].dispatcher + val maxInputBuffer = attributes.mandatoryAttribute[Attributes.InputBuffer].max - val props = - TLSActor.props(maxInputBuffer, tls.createSSLEngine, tls.verifySession, tls.closing) - .withDispatcher(dispatcher) - .withMailbox(PhasedFusingActorMaterializer.MailboxConfigName) + val props = + TLSActor.props(maxInputBuffer, tls.createSSLEngine, 
tls.verifySession, tls.closing) + .withDispatcher(dispatcher) + .withMailbox(PhasedFusingActorMaterializer.MailboxConfigName) - tlsActor = materializer.actorOf(props, "TLS-for-" + islandName) - def factory(id: Int) = new ActorPublisher[Any](tlsActor) { - override val wakeUpMsg: FanOut.SubstreamSubscribePending = FanOut.SubstreamSubscribePending(id) + tlsActor = materializer.actorOf(props, "TLS-for-" + islandName) + def factory(id: Int) = new ActorPublisher[Any](tlsActor) { + override val wakeUpMsg: FanOut.SubstreamSubscribePending = FanOut.SubstreamSubscribePending(id) + } + publishers = Vector.tabulate(2)(factory) + tlsActor ! FanOut.ExposedPublishers(publishers) + } else { + plainInputProcessor = new VirtualProcessor[SslTlsOutbound] + cipherInputProcessor = new VirtualProcessor[ByteString] + cipherOutputProcessor = new VirtualProcessor[ByteString] + plainOutputProcessor = new VirtualProcessor[SslTlsInbound] + plainInputTracker = new InputFailureTrackingSubscriber[SslTlsOutbound](plainInputProcessor) + cipherInputTracker = new InputFailureTrackingSubscriber[ByteString](cipherInputProcessor) + materializeTlsGraph() } - publishers = Vector.tabulate(2)(factory) - tlsActor ! 
FanOut.ExposedPublishers(publishers) (NotUsed, NotUsed) } def assignPort(in: InPort, slot: Int, logic: NotUsed): Unit = () def assignPort(out: OutPort, slot: Int, logic: NotUsed): Unit = () def createPublisher(out: OutPort, logic: NotUsed): Publisher[Any] = - publishers(out.id) + if (useLegacyActorPath) publishers(out.id) + else if (out.id == 0) cipherOutputProcessor.asInstanceOf[Publisher[Any]] + else plainOutputProcessor.asInstanceOf[Publisher[Any]] override def takePublisher(slot: Int, publisher: Publisher[Any], attributes: Attributes): Unit = - publisher.subscribe(FanIn.SubInput[Any](tlsActor, 1 - slot)) + if (useLegacyActorPath) + publisher.subscribe(FanIn.SubInput[Any](tlsActor, 1 - slot)) + else { + if (slot == 0) + publisher.subscribe(plainInputTracker.asInstanceOf[Subscriber[Any]]) + else + publisher.subscribe(cipherInputTracker.asInstanceOf[Subscriber[Any]]) + } + + private def materializeTlsGraph(): Unit = { + val tls = tlsModule + val runnable = RunnableGraph.fromGraph(GraphDSL.create() { implicit b => + import GraphDSL.Implicits._ + + val tlsStage = b.add( + new TlsGraphStage( + tls.createSSLEngine, + tls.verifySession, + tls.closing, + () => plainInputTracker.currentFailure, + () => cipherInputTracker.currentFailure)) + val plainSource = b.add(Source.fromPublisher(plainInputProcessor)) + val cipherSource = b.add(Source.fromPublisher(cipherInputProcessor)) + val cipherSink = b.add(Sink.fromSubscriber(cipherOutputProcessor)) + val plainSink = b.add(Sink.fromSubscriber(plainOutputProcessor)) + + plainSource ~> tlsStage.in1 + tlsStage.out1 ~> cipherSink + cipherSource ~> tlsStage.in2 + tlsStage.out2 ~> plainSink + + ClosedShape + }) + + materializer.materialize(runnable, graphAttributes) + } def onIslandReady(): Unit = () } diff --git a/stream/src/main/scala/org/apache/pekko/stream/impl/io/TLSActor.scala b/stream/src/main/scala/org/apache/pekko/stream/impl/io/TLSActor.scala index 02cdeda571..feb745cd6e 100644 --- 
a/stream/src/main/scala/org/apache/pekko/stream/impl/io/TLSActor.scala +++ b/stream/src/main/scala/org/apache/pekko/stream/impl/io/TLSActor.scala @@ -530,36 +530,3 @@ import pekko.util.ByteString context.stop(self) } } - -/** - * INTERNAL API - */ -@InternalApi private[pekko] object TlsUtils { - def applySessionParameters(engine: SSLEngine, sessionParameters: NegotiateNewSession): Unit = { - sessionParameters.enabledCipherSuites.foreach(cs => engine.setEnabledCipherSuites(cs.toArray)) - sessionParameters.enabledProtocols.foreach(p => engine.setEnabledProtocols(p.toArray)) - - sessionParameters.sslParameters.foreach(engine.setSSLParameters) - - sessionParameters.clientAuth match { - case Some(TLSClientAuth.None) => engine.setNeedClientAuth(false) - case Some(TLSClientAuth.Want) => engine.setWantClientAuth(true) - case Some(TLSClientAuth.Need) => engine.setNeedClientAuth(true) - case _ => // do nothing - } - } - - def cloneParameters(old: SSLParameters): SSLParameters = { - val newParameters = new SSLParameters() - newParameters.setAlgorithmConstraints(old.getAlgorithmConstraints) - newParameters.setCipherSuites(old.getCipherSuites) - newParameters.setEndpointIdentificationAlgorithm(old.getEndpointIdentificationAlgorithm) - newParameters.setNeedClientAuth(old.getNeedClientAuth) - newParameters.setProtocols(old.getProtocols) - newParameters.setServerNames(old.getServerNames) - newParameters.setSNIMatchers(old.getSNIMatchers) - newParameters.setUseCipherSuitesOrder(old.getUseCipherSuitesOrder) - newParameters.setWantClientAuth(old.getWantClientAuth) - newParameters - } -} diff --git a/stream/src/main/scala/org/apache/pekko/stream/impl/io/TlsGraphStage.scala b/stream/src/main/scala/org/apache/pekko/stream/impl/io/TlsGraphStage.scala new file mode 100644 index 0000000000..874323724e --- /dev/null +++ b/stream/src/main/scala/org/apache/pekko/stream/impl/io/TlsGraphStage.scala @@ -0,0 +1,769 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * 
license agreements; and to You under the Apache License, version 2.0: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * This file is part of the Apache Pekko project, which was derived from Akka. + */ + +/* + * Copyright (C) 2015-2022 Lightbend Inc. <https://www.lightbend.com> + */ + +package org.apache.pekko.stream.impl.io + +import java.nio.ByteBuffer +import javax.net.ssl._ +import javax.net.ssl.SSLEngineResult.HandshakeStatus +import javax.net.ssl.SSLEngineResult.HandshakeStatus._ +import javax.net.ssl.SSLEngineResult.Status._ + +import scala.annotation.tailrec +import scala.collection.immutable +import scala.util.{ Failure, Success, Try } +import scala.util.control.NonFatal + +import org.apache.pekko +import pekko.annotation.InternalApi +import pekko.stream._ +import pekko.stream.TLSProtocol._ +import pekko.stream.impl.{ Completed, TransferPhase, TransferState } +import pekko.stream.stage.{ GraphStage, GraphStageLogic, InHandler, OutHandler } +import pekko.util.ByteString + +/** + * INTERNAL API. + */ +@InternalApi private[stream] object TlsGraphStage { + final val UseLegacyActorPath = "pekko.stream.materializer.tls.use-legacy-actor" + final val DefaultDispatcher = "pekko.stream.materializer.dispatcher" + + def useLegacyActorPath: Boolean = + sys.props.get(UseLegacyActorPath) match { + case Some(value) => + value.trim.toLowerCase match { + case "false" | "off" | "0" | "no" => false + case _ => true + } + case None => true + } +} + +/** + * INTERNAL API. 
+ */ +@InternalApi +private[stream] final class TlsGraphStage( + createSSLEngine: () => SSLEngine, + verifySession: SSLSession => Try[Unit], + closing: TLSClosing, + plainInputFailure: () => Option[Throwable], + cipherInputFailure: () => Option[Throwable]) + extends GraphStage[BidiShape[SslTlsOutbound, ByteString, ByteString, SslTlsInbound]] { + import TlsGraphStage._ + + val plainIn: Inlet[SslTlsOutbound] = Inlet("TlsGraphStage.plainIn") + val plainOut: Outlet[SslTlsInbound] = Outlet("TlsGraphStage.plainOut") + val cipherIn: Inlet[ByteString] = Inlet("TlsGraphStage.cipherIn") + val cipherOut: Outlet[ByteString] = Outlet("TlsGraphStage.cipherOut") + + override val shape: BidiShape[SslTlsOutbound, ByteString, ByteString, SslTlsInbound] = + BidiShape(plainIn, cipherOut, cipherIn, plainOut) + + // The dispatcher attribute doubles as an async boundary, which ensures that the + // TLS state machine is materialized into its own ActorGraphInterpreter actor. + override protected def initialAttributes: Attributes = + Attributes.name("TlsGraphStage") and ActorAttributes.dispatcher(DefaultDispatcher) + + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = + new GraphStageLogic(shape) { + + private val maxTLSIterations = 1000 + + private var stopped = false + private var pumping = false + private var pumpAgain = false + private var startupPending = true + private var bridgeFailureCheckActive = true + private var completing = false + + private var unwrapPutBackCounter = 0 + private var lastHandshakeStatus: HandshakeStatus = _ + private var corkUser = true + + private var userInFinished = false + private var transportInFinished = false + private var userOutCancelled = false + private var transportOutCancelled = false + private var userOutTerminated = false + private var transportOutTerminated = false + + private var pendingUserIn: Option[SslTlsOutbound] = None + private var pendingTransportIn: Option[ByteString] = None + + // FIFO buffer for decrypted 
user output. + // + // During TLS handshake the VirtualProcessor bridges connecting the inner TLS sub-graph + // to the outer stream may not have propagated demand to plainOut yet. Requiring + // isAvailable(plainOut) as a precondition for inbound processing would therefore + // dead-lock the handshake. Instead we decouple inbound processing from plainOut + // demand by buffering pending user messages here. userOutAvailable keeps the pump + // from re-entering inbound processing while buffered output still needs to be drained. + private var pendingUserOut = immutable.Queue.empty[SslTlsInbound] + + // Symmetric FIFO buffer for encrypted transport output. + // + // The inner TLS stage can need to emit handshake bytes (for example the + // initial ClientHello) before the outer VirtualProcessor bridge has + // propagated demand to cipherOut. Buffering one element here avoids that + // startup race and lets the handshake make forward progress. As with + // pendingUserOut, transportOutAvailable keeps the pump from generating more + // transport output while buffered bytes still need to be drained. 
+ private var pendingTransportOut = immutable.Queue.empty[ByteString] + + private val transportOutBuffer = ByteBuffer.allocate(16665 + 2048) + private val userOutBuffer = ByteBuffer.allocate(16665 * 2 + 2048) + private val transportInBuffer = ByteBuffer.allocate(16665 + 2048) + private val userInBuffer = ByteBuffer.allocate(16665 + 2048) + + private var engine: SSLEngine = _ + private var currentSession: SSLSession = _ + + private class ChoppingBlock( + name: String, + pendingElement: => Option[Any], + isDepleted: => Boolean, + refill: Any => Unit) + extends TransferState { + override def isReady: Boolean = buffer.nonEmpty || pendingElement.nonEmpty || isDepleted + override def isCompleted: Boolean = false + + private var buffer = ByteString.empty + + def isEmpty: Boolean = buffer.isEmpty + + def chopInto(b: ByteBuffer): Unit = { + b.compact() + if (buffer.isEmpty) { + pendingElement.foreach { next => + next match { + case bs: ByteString => buffer = bs + case SendBytes(bs) => buffer = bs + case n: NegotiateNewSession => + setNewSessionParameters(n) + buffer = ByteString.empty + case other => + throw new IllegalStateException(s"Unexpected TLS input element [$other] on $name") + } + refill(next) + } + } + + val copied = buffer.copyToBuffer(b) + buffer = buffer.drop(copied) + b.flip() + } + + def putBack(b: ByteBuffer): Unit = + if (b.hasRemaining) { + val bs = ByteString(b) + if (bs.nonEmpty) buffer = bs ++ buffer + prepare(b) + } + + def prepare(b: ByteBuffer): Unit = { + b.clear() + b.limit(0) + } + } + + private val userInChoppingBlock = + new ChoppingBlock( + "UserIn", + pendingUserIn, + isUserInDepleted, + _ => { + pendingUserIn = None + tryPullPlainInIfNeeded() + }) + userInChoppingBlock.prepare(userInBuffer) + + private val transportInChoppingBlock = + new ChoppingBlock( + "TransportIn", + pendingTransportIn, + isTransportInDepleted, + _ => { + pendingTransportIn = None + tryPullCipherInIfNeeded() + }) + transportInChoppingBlock.prepare(transportInBuffer) + + 
+    // Ready while the engine's last reported handshake status is NEED_WRAP; completed once the
+    // engine's outbound side is done.
+    private val engineNeedsWrap = new TransferState {
+      override def isReady: Boolean = lastHandshakeStatus == NEED_WRAP
+      override def isCompleted: Boolean = engine.isOutboundDone
+    }
+
+    // Always ready; completed once the engine's inbound side is done.
+    private val engineInboundOpen = new TransferState {
+      override def isReady: Boolean = true
+      override def isCompleted: Boolean = engine.isInboundDone
+    }
+
+    // cipherOut can take more wrapped bytes: it is not terminated and nothing is queued for it.
+    private val transportOutAvailable = new TransferState {
+      override def isReady: Boolean = !(transportOutTerminated || pendingTransportOut.nonEmpty)
+      override def isCompleted: Boolean = transportOutTerminated
+    }
+
+    // plainOut can take more decrypted output as long as nothing is still queued for it.
+    // This deliberately ignores whether plainOut has been pulled yet, so inbound processing is
+    // not tied to downstream demand. The empty-queue check still prevents unbounded re-entry
+    // once output has started to pile up.
+    private val userOutAvailable = new TransferState {
+      override def isReady: Boolean = !(userOutTerminated || pendingUserOut.nonEmpty)
+      override def isCompleted: Boolean = userOutTerminated
+    }
+
+    // Match TLSActor.awaitingClose semantics: this phase should only run when a
+    // fresh transport input element is pending. Once transport input is depleted
+    // it must complete instead of repeatedly unwrapping an empty buffer.
+    private val transportInPending = new TransferState {
+      override def isReady: Boolean = pendingTransportIn.isDefined
+      override def isCompleted: Boolean = isTransportInDepleted
+    }
+
+    // A renegotiation may only start while no handshake is in progress.
+    private def mayStartRenegotiation: Boolean =
+      lastHandshakeStatus == HandshakeStatus.NOT_HANDSHAKING ||
+      lastHandshakeStatus == HandshakeStatus.FINISHED
+
+    // True when user bytes are left over from an earlier element, or when the pending element
+    // can be consumed now. A NegotiateNewSession request waits until renegotiation is allowed.
+    private def userInputReadyForWrap: Boolean =
+      !userInChoppingBlock.isEmpty || (pendingUserIn match {
+        case Some(_: NegotiateNewSession) => mayStartRenegotiation
+        case pending                      => pending.isDefined
+      })
+
+    // User data may be wrapped when the user side is uncorked, the engine is not waiting for
+    // peer data, and there is user input to consume.
+    private val userHasData = new TransferState {
+      override def isReady: Boolean =
+        !corkUser && lastHandshakeStatus != NEED_UNWRAP && userInputReadyForWrap
+      override def isCompleted: Boolean = isUserInDepleted
+    }
+
+    // Ready once downstream of plainOut has cancelled.
+    private val userOutCancelledState = new TransferState {
+      override def isReady: Boolean = userOutCancelled
+      override def isCompleted: Boolean = engine.isInboundDone || userOutTerminated
+    }
+
+    // Composite preconditions used by the transfer phases.
+    private val outbound = (userHasData || engineNeedsWrap) && transportOutAvailable
+    private val inbound = (transportInChoppingBlock && userOutAvailable) || userOutCancelledState
+
+    private val outboundHalfClosed = engineNeedsWrap && transportOutAvailable
+    private val inboundHalfClosed = transportInChoppingBlock && engineInboundOpen
+
+    // Detects the case where user input was depleted before or during the TLS handshake.
+    //
+    // When the server role completes the handshake by wrapping its own Finished message
+    // (inside doWrap → doOutbound → bidirectional.action), corkUser is set to false
+    // and lastHandshakeStatus becomes NOT_HANDSHAKING. After doWrap returns, neither
+    // outbound (no data to wrap, engine no longer NEED_WRAP) nor inbound (no pending
+    // transport data) is ready, so the pump loop exits without calling doOutbound again.
+    // Without this extra predicate, doOutbound would never see isUserInDepleted=true and
+    // mayCloseOutbound=true in the same pump tick, and the stage would hang waiting for
+    // an event that never arrives.
+    //
+    // This condition is intentionally independent of transportOutAvailable: the decision
+    // to close the outbound (engine.closeOutbound + nextPhase(outboundClosed)) does not
+    // write to the transport buffer immediately; the close_notify is sent later via
+    // doWrap inside the outboundClosed phase.
+    private val shouldCloseOutbound = new TransferState {
+      override def isReady: Boolean =
+        !corkUser && isUserInDepleted && userInChoppingBlock.isEmpty && mayCloseOutbound
+      override def isCompleted: Boolean = false
+    }
+
+    // Normal operation: handle inbound data, then (if allowed) outbound data, in one action.
+    private val bidirectional = TransferPhase(outbound || shouldCloseOutbound || inbound) { () =>
+      val continue = doInbound(isOutboundClosed = false, inbound)
+      if (continue) doOutbound(isInboundClosed = false)
+    }
+
+    // Only wraps what the engine still has to send; an SSLException ends the stage.
+    private val flushingOutbound = TransferPhase(outboundHalfClosed) { () =>
+      try doWrap()
+      catch { case _: SSLException => nextPhase(completedPhase) }
+    }
+
+    // Entered after wrap reported CLOSED while inbound is still open: unwraps freshly arrived
+    // transport input with ignoreOutput = true until the engine's inbound side is done.
+    private val awaitingClose = TransferPhase(transportInPending && engineInboundOpen) { () =>
+      transportInChoppingBlock.chopInto(transportInBuffer)
+      try doUnwrap(ignoreOutput = true)
+      catch { case _: SSLException => nextPhase(completedPhase) }
+    }
+
+    // User input is finished: keep handling inbound data and wrap whatever the engine needs.
+    private val outboundClosed = TransferPhase(outboundHalfClosed || inbound) { () =>
+      val continue = doInbound(isOutboundClosed = true, inbound)
+      if (continue && outboundHalfClosed.isReady) {
+        try doWrap()
+        catch { case _: SSLException => nextPhase(completedPhase) }
+      }
+    }
+
+    // plainOut was cancelled but closing.ignoreCancel keeps the outbound direction running.
+    private val inboundClosed = TransferPhase(outbound || inboundHalfClosed) { () =>
+      val continue = doInbound(isOutboundClosed = false, inboundHalfClosed)
+      if (continue) doOutbound(isInboundClosed = true)
+    }
+
+    private val completedPhase = TransferPhase(Completed) { () =>
+      throw new IllegalStateException("The action of completed phase must never be executed")
+    }
+
+    private var transferState: TransferState = bidirectional.precondition
+    private var currentAction: () => Unit = bidirectional.action
+    // The first pump runs asynchronously after preStart; until then output pulls are ignored.
+    private val pumpAsync = getAsyncCallback[Unit] { _ =>
+      startupPending = false
+      pump()
+    }
+    // Queued output is pushed from these callbacks (scheduled by enqueueTransport/enqueueUser).
+    private val drainTransportOutAsync = getAsyncCallback[Unit](_ => drainTransportOut())
+    private val drainUserOutAsync = getAsyncCallback[Unit](_ => drainUserOut())
+
+    setHandler(plainIn, new InHandler {
+      override def onPush(): Unit = {
+        pendingUserIn = Some(grab(plainIn))
+        pump()
+      }
+
+      override def onUpstreamFinish(): Unit = {
+        userInFinished = true
+        pump()
+      }
+
+      override def onUpstreamFailure(ex: Throwable): Unit = failTls(ex)
+    })
+
+    setHandler(cipherIn, new InHandler {
+      override def onPush(): Unit = {
+        pendingTransportIn = Some(grab(cipherIn))
+        pump()
+      }
+
+      override def onUpstreamFinish(): Unit = {
+        transportInFinished = true
+        pump()
+      }
+
+      override def onUpstreamFailure(ex: Throwable): Unit = failTls(ex)
+    })
+
+    setHandler(cipherOut, new OutHandler {
+      override def onPull(): Unit =
+        if (!startupPending) {
+          if (!drainTransportOut()) {
+            if (completing) tryCompleteStage()
+            else pump()
+          }
+        }
+
+      override def onDownstreamFinish(cause: Throwable): Unit = {
+        transportOutCancelled = true
+        transportOutTerminated = true
+        pendingTransportOut = immutable.Queue.empty
+        pump()
+      }
+    })
+
+    setHandler(plainOut, new OutHandler {
+      override def onPull(): Unit = {
+        // Flush any buffered user output before running the pump.
+        // This drains pendingUserOut (set by enqueueUser when isAvailable(plainOut) was false)
+        // and re-enables inbound processing (userOutAvailable.isReady = pendingUserOut.isEmpty).
+        if (!startupPending) {
+          if (!drainUserOut()) {
+            if (completing) tryCompleteStage()
+            else pump()
+          }
+        }
+      }
+
+      override def onDownstreamFinish(cause: Throwable): Unit = {
+        userOutCancelled = true
+        pendingUserOut = immutable.Queue.empty
+        pump()
+      }
+    })
+
+    // Creates the engine and starts the handshake; any non-fatal error fails the stage.
+    override def preStart(): Unit = {
+      try {
+        engine = createSSLEngine()
+        engine.beginHandshake()
+        lastHandshakeStatus = engine.getHandshakeStatus
+        currentSession = engine.getSession
+        tryPullPlainInIfNeeded()
+        tryPullCipherInIfNeeded()
+        pumpAsync.invoke(())
+      } catch {
+        case NonFatal(ex) => failStage(ex)
+      }
+    }
+
+    override def postStop(): Unit = {
+      stopped = true
+      super.postStop()
+    }
+
+    /** Switches the state machine to `phase` (its precondition and action). */
+    private def nextPhase(phase: TransferPhase): Unit = {
+      transferState = phase.precondition
+      currentAction = phase.action
+    }
+
+    private def isUserInDepleted: Boolean = userInFinished && pendingUserIn.isEmpty
+    private def isTransportInDepleted: Boolean = transportInFinished && pendingTransportIn.isEmpty
+
+    private def tryPullPlainInIfNeeded(): Unit =
+      if (pendingUserIn.isEmpty && !isClosed(plainIn) && !hasBeenPulled(plainIn)) pull(plainIn)
+
+    private def tryPullCipherInIfNeeded(): Unit =
+      if (pendingTransportIn.isEmpty && !isClosed(cipherIn) && !hasBeenPulled(cipherIn)) pull(cipherIn)
+
+    // Finish if nothing more can be sent (outbound done, or inbound done with no user bytes
+    // left); otherwise keep flushing what the engine still wants to wrap.
+    private def completeOrFlush(): Unit =
+      if (engine.isOutboundDone || (engine.isInboundDone && userInChoppingBlock.isEmpty)) nextPhase(completedPhase)
+      else nextPhase(flushingOutbound)
+
+    // While bridgeFailureCheckActive is set, fails the stage with the first failure reported by
+    // the injected input-failure probes. Returns true if it did so.
+    private def pollBridgedInputFailures(): Boolean =
+      if (bridgeFailureCheckActive)
+        plainInputFailure().orElse(cipherInputFailure()) match {
+          case Some(ex) =>
+            failTls(ex)
+            true
+          case None => false
+        }
+      else false
+
+    /**
+     * Handles the inbound direction for one phase action: end of transport input, plainOut
+     * cancellation, or unwrapping the next chunk of transport data.
+     *
+     * @param isOutboundClosed whether the calling phase already runs with outbound closed
+     * @param inboundState     the precondition deciding whether transport input may be unwrapped now
+     * @return whether the caller may continue with outbound work in the same action
+     */
+    private def doInbound(isOutboundClosed: Boolean, inboundState: TransferState): Boolean =
+      if (isTransportInDepleted && transportInChoppingBlock.isEmpty) {
+        try engine.closeInbound()
+        catch {
+          case ex: SSLException =>
+            // Transport ended without close_notify: fatal while still handshaking, otherwise
+            // reported to the user as a truncated session.
+            if (corkUser) {
+              failTls(ex)
+              nextPhase(completedPhase)
+              return false
+            } else enqueueUser(SessionTruncated)
+        }
+        lastHandshakeStatus = engine.getHandshakeStatus
+        completeOrFlush()
+        false
+      } else if ((inboundState ne inboundHalfClosed) && userOutCancelled) {
+        if (!isOutboundClosed && closing.ignoreCancel) nextPhase(inboundClosed)
+        else {
+          engine.closeOutbound()
+          lastHandshakeStatus = engine.getHandshakeStatus
+          nextPhase(flushingOutbound)
+        }
+        true
+      } else if (inboundState.isReady) {
+        transportInChoppingBlock.chopInto(transportInBuffer)
+        try {
+          doUnwrap(ignoreOutput = false)
+          true
+        } catch {
+          case ex: SSLException =>
+            failTls(ex, closeTransport = false)
+            // After a handshake failure (e.g. certificate_unknown), the SSLEngine buffers
+            // a TLS fatal alert that must be wrapped and sent to the peer BEFORE we tear
+            // down the connection. Without flushing it the peer's engine only sees a TCP
+            // close and throws "closing inbound before receiving peer's close_notify"
+            // instead of the actual alert. We refresh lastHandshakeStatus here so that
+            // the engineNeedsWrap predicate correctly reflects the engine's current state.
+            lastHandshakeStatus = engine.getHandshakeStatus
+            if (engineNeedsWrap.isReady) {
+              // Engine has a TLS alert queued. Flush it via flushingOutbound;
+              // doWrap() will transition to completedPhase once the engine reports CLOSED.
+              nextPhase(flushingOutbound)
+            } else {
+              engine.closeInbound()
+              completeOrFlush()
+            }
+            false
+        }
+      } else true
+
+    /**
+     * Handles the outbound direction for one phase action: closing outbound once user input is
+     * finished, stopping when cipherOut was cancelled, or wrapping the next piece of user data
+     * (or engine-generated output).
+     *
+     * @param isInboundClosed whether the calling phase already runs with inbound closed
+     */
+    private def doOutbound(isInboundClosed: Boolean): Unit =
+      if (isUserInDepleted && userInChoppingBlock.isEmpty && mayCloseOutbound) {
+        if (!isInboundClosed && closing.ignoreComplete) {
+          nextPhase(outboundClosed)
+        } else {
+          engine.closeOutbound()
+          lastHandshakeStatus = engine.getHandshakeStatus
+          nextPhase(outboundClosed)
+        }
+      } else if (transportOutCancelled) {
+        nextPhase(completedPhase)
+      } else if (outbound.isReady) {
+        if (userHasData.isReady) userInChoppingBlock.chopInto(userInBuffer)
+        try doWrap()
+        catch {
+          case ex: SSLException =>
+            failTls(ex, closeTransport = false)
+            // After a handshake failure (e.g. certificate_unknown), the first engine.wrap()
+            // throws with the cert error but leaves the engine in NEED_WRAP state.
+            // A second wrap() call produces the actual TLS fatal-alert bytes that must
+            // reach the peer. Without sending them, the peer only sees a TCP close and
+            // reports "closing inbound before receiving peer's close_notify" instead of
+            // the real error.
+            lastHandshakeStatus = engine.getHandshakeStatus
+            if (engineNeedsWrap.isReady) {
+              try doWrap() // flushes the TLS alert into pendingTransportOut
+              catch { case _: SSLException => } // ignore any secondary exception
+            }
+            completeOrFlush()
+        }
+      }
+
+    // Outbound may only be closed while no handshake is in progress.
+    private def mayCloseOutbound: Boolean =
+      lastHandshakeStatus match {
+        case HandshakeStatus.NOT_HANDSHAKING | HandshakeStatus.FINISHED => true
+        case _                                                          => false
+      }
+
+    // Queues bytes for cipherOut (dropped once the transport side is terminated) and schedules
+    // an asynchronous drain.
+    private def enqueueTransport(bytes: ByteString): Unit =
+      if (!transportOutTerminated) {
+        pendingTransportOut = pendingTransportOut.enqueue(bytes)
+        drainTransportOutAsync.invoke(())
+      }
+
+    // Queues a message for plainOut (dropped once plainOut is cancelled or terminated) and
+    // schedules an asynchronous drain.
+    private def enqueueUser(message: SslTlsInbound): Unit =
+      if (!userOutCancelled && !userOutTerminated) {
+        pendingUserOut = pendingUserOut.enqueue(message)
+        drainUserOutAsync.invoke(())
+      }
+
+    // Pushes one queued element to cipherOut if possible; returns whether it pushed.
+    private def drainTransportOut(): Boolean =
+      if (
+        !startupPending &&
+        !transportOutTerminated &&
+        !pollBridgedInputFailures() &&
+        isAvailable(cipherOut) &&
+        pendingTransportOut.nonEmpty) {
+        val (bytes, remaining) = pendingTransportOut.dequeue
+        pendingTransportOut = remaining
+        push(cipherOut, bytes)
+        true
+      } else false
+
+    // Pushes one queued message to plainOut if possible; returns whether it pushed.
+    private def drainUserOut(): Boolean =
+      if (
+        !startupPending &&
+        !userOutTerminated &&
+        !pollBridgedInputFailures() &&
+        isAvailable(plainOut) &&
+        pendingUserOut.nonEmpty) {
+        val (msg, remaining) = pendingUserOut.dequeue
+        pendingUserOut = remaining
+        push(plainOut, msg)
+        true
+      } else false
+
+    // Moves whatever wrap produced into the cipherOut queue and resets the buffer.
+    private def flushToTransport(): Unit = {
+      transportOutBuffer.flip()
+      if (transportOutBuffer.hasRemaining) enqueueTransport(ByteString(transportOutBuffer))
+      transportOutBuffer.clear()
+    }
+
+    // Moves whatever unwrap produced into the plainOut queue and resets the buffer. Any
+    // delivered output also resets the NEED_WRAP loop guard used in doUnwrap.
+    private def flushToUser(): Unit = {
+      if (unwrapPutBackCounter > 0) unwrapPutBackCounter = 0
+      userOutBuffer.flip()
+      if (userOutBuffer.hasRemaining) enqueueUser(SessionBytes(currentSession, ByteString(userOutBuffer)))
+      userOutBuffer.clear()
+    }
+
+    // Runs one engine.wrap from userInBuffer into transportOutBuffer and acts on the result.
+    private def doWrap(): Unit = {
+      val result = engine.wrap(userInBuffer, transportOutBuffer)
+      lastHandshakeStatus = result.getHandshakeStatus
+
+      if (lastHandshakeStatus == FINISHED) handshakeFinished()
+      runDelegatedTasks()
+
+      result.getStatus match {
+        case OK =>
+          if (transportOutBuffer.position() == 0 && lastHandshakeStatus == NEED_WRAP)
+            throw new IllegalStateException("SSLEngine trying to loop NEED_WRAP without producing output")
+
+          flushToTransport()
+          userInChoppingBlock.putBack(userInBuffer)
+        case CLOSED =>
+          flushToTransport()
+          if (engine.isInboundDone) nextPhase(completedPhase)
+          else nextPhase(awaitingClose)
+        case status =>
+          failTls(new IllegalStateException(s"unexpected status $status in doWrap()"))
+      }
+    }
+
+    /**
+     * Runs engine.unwrap from transportInBuffer into userOutBuffer, and keeps unwrapping while
+     * the engine consumes input without needing a handshake step.
+     *
+     * @param ignoreOutput discard what this unwrap produced instead of delivering it to plainOut.
+     *                     NOTE(review): the recursive call passes `ignoreOutput = false` (as the
+     *                     legacy TLSActor does), so only the first unwrap of a chunk is discarded
+     *                     — confirm this is intended for awaitingClose.
+     */
+    @tailrec
+    private def doUnwrap(ignoreOutput: Boolean): Unit = {
+      val oldInPosition = transportInBuffer.position()
+      val result = engine.unwrap(transportInBuffer, userOutBuffer)
+      if (ignoreOutput) userOutBuffer.clear()
+      lastHandshakeStatus = result.getHandshakeStatus
+      runDelegatedTasks()
+
+      result.getStatus match {
+        case OK =>
+          result.getHandshakeStatus match {
+            case NEED_WRAP =>
+              // Keep the legacy guard from TLSActor: SyncProcessingLimit does not
+              // protect against an SSLEngine loop that stays within a single callback.
+              unwrapPutBackCounter += 1
+              if (unwrapPutBackCounter > maxTLSIterations) {
+                throw new IllegalStateException(
+                  s"Stuck in unwrap loop, bailing out, last handshake status [$lastHandshakeStatus], " +
+                  s"remaining=${transportInBuffer.remaining}, out=${userOutBuffer.position()}, " +
+                  "(https://github.com/apache/pekko/issues/442)")
+              }
+              transportInChoppingBlock.putBack(transportInBuffer)
+            case FINISHED =>
+              flushToUser()
+              handshakeFinished()
+              transportInChoppingBlock.putBack(transportInBuffer)
+            case NEED_UNWRAP
+                if transportInBuffer.hasRemaining &&
+                userOutBuffer.position() == 0 &&
+                transportInBuffer.position() == oldInPosition =>
+              throw new IllegalStateException("SSLEngine trying to loop NEED_UNWRAP without producing output")
+            case _ =>
+              if (transportInBuffer.hasRemaining) doUnwrap(ignoreOutput = false)
+              else flushToUser()
+          }
+        case CLOSED =>
+          flushToUser()
+          completeOrFlush()
+        case BUFFER_UNDERFLOW =>
+          flushToUser()
+        case BUFFER_OVERFLOW =>
+          flushToUser()
+          transportInChoppingBlock.putBack(transportInBuffer)
+        case null =>
+          failTls(new IllegalStateException("unexpected status 'null' in doUnwrap()"))
+      }
+    }
+
+    // Runs all pending delegated engine tasks on the calling thread, then refreshes the status.
+    @tailrec
+    private def runDelegatedTasks(): Unit = {
+      val task = engine.getDelegatedTask
+      if (task ne null) {
+        task.run()
+        runDelegatedTasks()
+      } else {
+        lastHandshakeStatus = engine.getHandshakeStatus
+      }
+    }
+
+    // Verifies the freshly negotiated session; on success user data is uncorked.
+    private def handshakeFinished(): Unit = {
+      val session = engine.getSession
+      verifySession(session) match {
+        case Success(()) =>
+          currentSession = session
+          corkUser = false
+          flushToUser()
+        case Failure(ex) =>
+          failTls(ex, closeTransport = true)
+      }
+    }
+
+    // Invalidates the current session, reconfigures the engine and starts a new handshake,
+    // corking user data until it finishes.
+    private def setNewSessionParameters(params: NegotiateNewSession): Unit = {
+      currentSession.invalidate()
+      TlsUtils.applySessionParameters(engine, params)
+      engine.beginHandshake()
+      lastHandshakeStatus = engine.getHandshakeStatus
+      corkUser = true
+    }
+
+    /**
+     * Fails plainOut (and cipherOut when `closeTransport`) with `e` and cancels both inputs.
+     *
+     * Only the first call does the full teardown. A later call with `closeTransport = true`
+     * still fails cipherOut. Without that, an error raised after a
+     * `failTls(..., closeTransport = false)` call would be swallowed and could leave cipherOut
+     * open. Such an error is, for example, one that aborts the flush of a fatal alert. The legacy
+     * TLSActor.fail has no once-only guard and always errors the transport when asked to.
+     */
+    private def failTls(e: Throwable, closeTransport: Boolean = true): Unit =
+      if (!stopped) {
+        stopped = true
+        cancelInputs(clearPendingUserOut = true, clearPendingTransportOut = true)
+        failUserOut(e)
+        if (closeTransport) failTransportOut(e)
+      } else if (closeTransport) failTransportOut(e)
+
+    // Drops pending input (and optionally queued output) and cancels both inlets.
+    private def cancelInputs(clearPendingUserOut: Boolean, clearPendingTransportOut: Boolean): Unit = {
+      pendingUserIn = None
+      pendingTransportIn = None
+      if (clearPendingUserOut) {
+        pendingUserOut = immutable.Queue.empty
+      }
+      if (clearPendingTransportOut) {
+        pendingTransportOut = immutable.Queue.empty
+      }
+      if (!isClosed(plainIn)) cancel(plainIn)
+      if (!isClosed(cipherIn)) cancel(cipherIn)
+    }
+
+    // Completes the stage once it is completing and all still-open outputs have been drained.
+    private def tryCompleteStage(): Unit =
+      if (
+        completing &&
+        (transportOutTerminated || pendingTransportOut.isEmpty) &&
+        (userOutCancelled || userOutTerminated || pendingUserOut.isEmpty)) {
+        completeOutputs()
+        completeStage()
+      }
+
+    private def failTransportOut(e: Throwable): Unit =
+      if (!transportOutTerminated) {
+        transportOutTerminated = true
+        fail(cipherOut, e)
+      }
+
+    private def failUserOut(e: Throwable): Unit =
+      if (!userOutCancelled && !userOutTerminated) {
+        userOutTerminated = true
+        fail(plainOut, e)
+      }
+
+    private def completeOutputs(): Unit = {
+      if (!transportOutTerminated) {
+        transportOutTerminated = true
+        complete(cipherOut)
+      }
+      if (!userOutCancelled && !userOutTerminated) {
+        userOutTerminated = true
+        complete(plainOut)
+      }
+    }
+
+    /**
+     * Runs the transfer-phase state machine until no phase action is executable.
+     *
+     * A nested call made while a pump is already running only sets `pumpAgain`, which makes the
+     * running pump loop once more. A failure reported through the bridged input-failure probes
+     * ends the pump immediately. Otherwise, once the current phase reports completion, the
+     * inputs are cancelled, queued output is drained and the stage completes when possible.
+     */
+    private def pump(): Unit =
+      if (pumping) pumpAgain = true
+      else {
+        pumping = true
+        var failedViaBridge = false
+        try {
+          // Explicit loop flag instead of `do { ... } while (...)`, which Scala 3 no longer
+          // supports.
+          var loopAgain = true
+          while (loopAgain) {
+            pumpAgain = false
+            if (pollBridgedInputFailures()) {
+              failedViaBridge = true
+              loopAgain = false
+            } else {
+              tryPullPlainInIfNeeded()
+              tryPullCipherInIfNeeded()
+
+              while (transferState.isExecutable) currentAction()
+              loopAgain = pumpAgain
+            }
+          }
+        } catch {
+          case NonFatal(ex) =>
+            failTls(ex)
+        } finally {
+          if (bridgeFailureCheckActive) bridgeFailureCheckActive = false
+          pumping = false
+        }
+
+        if (!failedViaBridge && transferState.isCompleted) {
+          if (!completing) {
+            completing = true
+            cancelInputs(clearPendingUserOut = false, clearPendingTransportOut = false)
+          }
+          if (!drainTransportOut()) drainUserOut()
+          tryCompleteStage()
+        }
+      }
+  }
+}
diff --git a/stream/src/main/scala/org/apache/pekko/stream/impl/io/TlsUtils.scala b/stream/src/main/scala/org/apache/pekko/stream/impl/io/TlsUtils.scala
new file mode 100644
index 0000000000..3e47809636
--- /dev/null
+++ b/stream/src/main/scala/org/apache/pekko/stream/impl/io/TlsUtils.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * license agreements; and to You under the Apache License, version 2.0:
+ *
+ *   https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * This file is part of the Apache Pekko project, which was derived from Akka.
+ */
+
+/*
+ * Copyright (C) 2015-2022 Lightbend Inc. <https://www.lightbend.com>
+ */
+
+package org.apache.pekko.stream.impl.io
+
+import javax.net.ssl.{ SSLEngine, SSLParameters }
+
+import org.apache.pekko
+import pekko.annotation.InternalApi
+import pekko.stream.TLSClientAuth
+import pekko.stream.TLSProtocol.NegotiateNewSession
+
+/**
+ * INTERNAL API
+ *
+ * SSLEngine configuration helpers used by the TLS stage implementations.
+ */
+@InternalApi private[pekko] object TlsUtils {
+
+  /**
+   * Applies every part of `sessionParameters` that is defined to `engine`; undefined parts
+   * leave the engine's current setting unchanged.
+   */
+  def applySessionParameters(engine: SSLEngine, sessionParameters: NegotiateNewSession): Unit = {
+    sessionParameters.enabledCipherSuites.foreach(cs => engine.setEnabledCipherSuites(cs.toArray))
+    sessionParameters.enabledProtocols.foreach(p => engine.setEnabledProtocols(p.toArray))
+
+    // Applied after the explicit suites/protocols, so non-null suites/protocols carried inside
+    // the SSLParameters override them.
+    sessionParameters.sslParameters.foreach(engine.setSSLParameters)
+
+    sessionParameters.clientAuth match {
+      case Some(TLSClientAuth.None) => engine.setNeedClientAuth(false)
+      case Some(TLSClientAuth.Want) => engine.setWantClientAuth(true)
+      case Some(TLSClientAuth.Need) => engine.setNeedClientAuth(true)
+      case _                        => // do nothing
+    }
+  }
+
+  /**
+   * Returns a fresh copy of `old` with the same algorithm constraints, cipher suites, endpoint
+   * identification algorithm, client-authentication mode, protocols, server names, SNI matchers
+   * and cipher-suite-order preference.
+   *
+   * NOTE(review): other SSLParameters properties (e.g. application protocols, maximum packet
+   * size) are not copied — confirm that is intended.
+   */
+  def cloneParameters(old: SSLParameters): SSLParameters = {
+    val newParameters = new SSLParameters()
+    newParameters.setAlgorithmConstraints(old.getAlgorithmConstraints)
+    newParameters.setCipherSuites(old.getCipherSuites)
+    newParameters.setEndpointIdentificationAlgorithm(old.getEndpointIdentificationAlgorithm)
+    // setNeedClientAuth and setWantClientAuth each clear the other flag, so apply only the
+    // effective mode; calling both (need first) would always drop "need".
+    if (old.getNeedClientAuth) newParameters.setNeedClientAuth(true)
+    else newParameters.setWantClientAuth(old.getWantClientAuth)
+    newParameters.setProtocols(old.getProtocols)
+    newParameters.setServerNames(old.getServerNames)
+    newParameters.setSNIMatchers(old.getSNIMatchers)
+    newParameters.setUseCipherSuitesOrder(old.getUseCipherSuitesOrder)
+    newParameters
+  }
+}
diff --git a/stream/src/main/scala/org/apache/pekko/stream/scaladsl/TLS.scala b/stream/src/main/scala/org/apache/pekko/stream/scaladsl/TLS.scala
index 88c2585d3d..c7e021821c 100644
--- a/stream/src/main/scala/org/apache/pekko/stream/scaladsl/TLS.scala
+++ b/stream/src/main/scala/org/apache/pekko/stream/scaladsl/TLS.scala
@@ -21,7 +21,7 @@ import org.apache.pekko
 import pekko.NotUsed
 import pekko.stream._
 import pekko.stream.TLSProtocol._
-import pekko.stream.impl.io.TlsModule
+import pekko.stream.impl.io.{ TlsGraphStage, TlsModule }
 import pekko.util.ByteString
 
 /**
@@ -77,7 +77,10 @@ object TLS {
       verifySession: SSLSession => Try[Unit],
       closing: TLSClosing): scaladsl.BidiFlow[SslTlsOutbound, ByteString, ByteString, SslTlsInbound, NotUsed] =
     scaladsl.BidiFlow.fromGraph(
-      TlsModule(Attributes.none, () => createSSLEngine(), session => verifySession(session), closing))
+      if (TlsGraphStage.useLegacyActorPath)
+        TlsModule(Attributes.none, () => createSSLEngine(), session => verifySession(session), closing)
+      else
+        new TlsGraphStage(() => createSSLEngine(), session => verifySession(session), closing, () => None, () => None))
 
   /**
    * Create a StreamTls [[pekko.stream.scaladsl.BidiFlow]].
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
