This is an automated email from the ASF dual-hosted git repository.

He-Pin pushed a commit to branch optimize-grpc-unary-fast-paths
in repository https://gitbox.apache.org/repos/asf/pekko-grpc.git
commit dc784a7ef9cda6d2c8d005475af5a18fdd1199e0 Author: 虎鸣 <[email protected]> AuthorDate: Mon Apr 27 12:51:22 2026 +0800 Optimize unary gRPC marshalling fast paths --- .../pekko/grpc/GrpcMarshallingBenchmark.scala | 50 ++++++++++++++++-- .../twirl/templates/ScalaServer/Handler.scala.txt | 54 +++++++++++++++---- .../pekko/grpc/javadsl/GrpcMarshalling.scala | 14 ++++- .../pekko/grpc/javadsl/GrpcMarshallingSpec.scala | 61 ++++++++++++++++++++++ 4 files changed, 164 insertions(+), 15 deletions(-) diff --git a/benchmarks/src/main/scala/org/apache/pekko/grpc/GrpcMarshallingBenchmark.scala b/benchmarks/src/main/scala/org/apache/pekko/grpc/GrpcMarshallingBenchmark.scala index c81c9c28..4304cc13 100644 --- a/benchmarks/src/main/scala/org/apache/pekko/grpc/GrpcMarshallingBenchmark.scala +++ b/benchmarks/src/main/scala/org/apache/pekko/grpc/GrpcMarshallingBenchmark.scala @@ -13,17 +13,22 @@ package org.apache.pekko.grpc +import scala.concurrent.Await +import scala.concurrent.duration.Duration + +import com.google.protobuf.{ Any => JavaAny, ByteString => JavaByteString } import org.apache.pekko import pekko.actor.ActorSystem -import pekko.grpc.internal.{ GrpcProtocolNative, Identity } +import pekko.grpc.internal.{ AbstractGrpcProtocol, GrpcProtocolNative, Identity } import pekko.grpc.scaladsl.{ GrpcMarshalling, ScalapbProtobufSerializer } -import pekko.http.scaladsl.model.HttpResponse +import pekko.http.scaladsl.model.{ HttpEntity, HttpResponse } +import pekko.stream.SystemMaterializer import pekko.stream.scaladsl.Source import io.grpc.reflection.v1.reflection._ import org.openjdk.jmh.annotations._ // Microbenchmarks for GrpcMarshalling. -// Does not actually benchmarks the actual marshalling because we dont consume the HttpResponse +// Does not actually benchmark response marshalling because we don't consume the HttpResponse. 
class GrpcMarshallingBenchmark extends CommonBenchmark { implicit val system: ActorSystem = ActorSystem("bench") implicit val writer: GrpcProtocol.GrpcProtocolWriter = GrpcProtocolNative.newWriter(Identity) @@ -31,14 +36,49 @@ class GrpcMarshallingBenchmark extends CommonBenchmark { implicit val serializer: ScalapbProtobufSerializer[ServerReflectionRequest] = ServerReflection.Serializers.ServerReflectionRequestSerializer + val request = ServerReflectionRequest() + val entity: HttpEntity.Strict = + HttpEntity.Strict( + GrpcProtocolNative.contentType, + AbstractGrpcProtocol.encodeFrameData(serializer.serialize(request), isCompressed = false, isTrailer = false)) + + val javaSerializer = new pekko.grpc.javadsl.GoogleProtobufSerializer(JavaAny.parser()) + val javaRequest: JavaAny = + JavaAny.newBuilder().setTypeUrl("benchmark").setValue(JavaByteString.copyFromUtf8("payload")).build() + val javaEntity: pekko.http.javadsl.model.HttpEntity = + HttpEntity.Strict( + GrpcProtocolNative.contentType, + AbstractGrpcProtocol.encodeFrameData(javaSerializer.serialize(javaRequest), isCompressed = false, + isTrailer = false)) + + val mat = SystemMaterializer(system).materializer + @Benchmark def marshall(): HttpResponse = { - GrpcMarshalling.marshal(ServerReflectionRequest()) + GrpcMarshalling.marshal(request) } @Benchmark def marshallStream(): HttpResponse = { - GrpcMarshalling.marshalStream(Source.repeat(ServerReflectionRequest()).take(10000)) + GrpcMarshalling.marshalStream(Source.repeat(request).take(10000)) + } + + @Benchmark + def unmarshallStrict(): ServerReflectionRequest = { + Await.result(GrpcMarshalling.unmarshal(entity), Duration.Inf) + } + + @Benchmark + def unmarshallJavaStrict(): JavaAny = { + pekko.grpc.javadsl.GrpcMarshalling.unmarshal(javaEntity, javaSerializer, mat, reader).toCompletableFuture.get() + } + + @Benchmark + def unmarshallJavaStrictStreamed(): JavaAny = { + pekko.grpc.javadsl.GrpcMarshalling + .unmarshal(javaEntity.getDataBytes, javaSerializer, mat, 
reader) + .toCompletableFuture + .get() } @TearDown diff --git a/codegen/src/main/twirl/templates/ScalaServer/Handler.scala.txt b/codegen/src/main/twirl/templates/ScalaServer/Handler.scala.txt index 42c46584..4433b2ed 100644 --- a/codegen/src/main/twirl/templates/ScalaServer/Handler.scala.txt +++ b/codegen/src/main/twirl/templates/ScalaServer/Handler.scala.txt @@ -55,7 +55,7 @@ object @{serviceName}Handler { * several services. */ def apply(implementation: @serviceName)(implicit system: ClassicActorSystemProvider): model.HttpRequest => scala.concurrent.Future[model.HttpResponse] = - partial(implementation).orElse { case _ => notFound } + handler(implementation, @{service.name}.name, GrpcExceptionHandler.defaultMapper) /** * Creates a `HttpRequest` to `HttpResponse` handler that can be used in for example `Http().bindAndHandleAsync` @@ -65,7 +65,7 @@ object @{serviceName}Handler { * several services. */ def apply(implementation: @serviceName, eHandler: ActorSystem => PartialFunction[Throwable, Trailers])(implicit system: ClassicActorSystemProvider): model.HttpRequest => scala.concurrent.Future[model.HttpResponse] = - partial(implementation, @{service.name}.name, eHandler).orElse { case _ => notFound } + handler(implementation, @{service.name}.name, eHandler) /** * Creates a `HttpRequest` to `HttpResponse` handler that can be used in for example `Http().bindAndHandleAsync` @@ -77,7 +77,7 @@ object @{serviceName}Handler { * Registering a gRPC service under a custom prefix is not widely supported and strongly discouraged by the specification. 
*/ def apply(implementation: @serviceName, prefix: String)(implicit system: ClassicActorSystemProvider): model.HttpRequest => scala.concurrent.Future[model.HttpResponse] = - partial(implementation, prefix).orElse { case _ => notFound } + handler(implementation, prefix, GrpcExceptionHandler.defaultMapper) /** * Creates a `HttpRequest` to `HttpResponse` handler that can be used in for example `Http().bindAndHandleAsync` @@ -89,7 +89,7 @@ object @{serviceName}Handler { * Registering a gRPC service under a custom prefix is not widely supported and strongly discouraged by the specification. */ def apply(implementation: @serviceName, prefix: String, eHandler: ActorSystem => PartialFunction[Throwable, Trailers])(implicit system: ClassicActorSystemProvider): model.HttpRequest => scala.concurrent.Future[model.HttpResponse] = - partial(implementation, prefix, eHandler).orElse { case _ => notFound } + handler(implementation, prefix, eHandler) @if(serviceName != "ServerReflection") { @@ -107,6 +107,43 @@ object @{serviceName}Handler { pekko.grpc.scaladsl.ServerReflection.partial(List(@{service.name}))) } + private def methodName(request: model.HttpRequest, prefix: String): String = + request.uri.path match { + case model.Uri.Path.Slash(model.Uri.Path.Segment(`prefix`, model.Uri.Path.Slash(model.Uri.Path.Segment(method, model.Uri.Path.Empty)))) => + method + case _ => + null + } + + private def handler(implementation: @serviceName, prefix: String, eHandler: ActorSystem => PartialFunction[Throwable, Trailers])(implicit system: ClassicActorSystemProvider): model.HttpRequest => scala.concurrent.Future[model.HttpResponse] = { + implicit val mat: Materializer = SystemMaterializer(system).materializer + implicit val ec: ExecutionContext = mat.executionContext + val spi = TelemetryExtension(system).spi + + import @{service.name}.Serializers.@{service.scalaCompatConstants.WildcardImport} + + def handle(request: model.HttpRequest, method: String): 
scala.concurrent.Future[model.HttpResponse] = + GrpcMarshalling.negotiated(request, (reader, writer) => + (method match { + @for(method <- service.methods) { + case "@method.grpcName" => + @{if(powerApis) { "val metadata = MetadataBuilder.fromHeaders(request.headers)" } else { "" }} + @{method.unmarshal}(request.entity)(@{service.scalaCompatConstants.ImplicitUsing}@method.deserializer.name, mat, reader) + .@{if(method.outputStreaming) { "map" } else { "flatMap" }}(implementation.@{method.nameSafe}(_@{if(powerApis) { ", metadata" } else { "" }})) + .map(e => @{method.marshal}(e, eHandler)(@{service.scalaCompatConstants.ImplicitUsing}@method.serializer.name, writer, system)) + } + case m => scala.concurrent.Future.failed(new NotImplementedError(s"Not implemented: $m")) + }) + .recoverWith(GrpcExceptionHandler.from(eHandler(system.classicSystem))(system, writer)) + ).getOrElse(unsupportedMediaType) + + request => { + val method = methodName(request, prefix) + if (method eq null) notFound + else handle(spi.onRequest(prefix, method, request), method) + } + } + /** * Creates a partial `HttpRequest` to `HttpResponse` handler that can be combined with handlers of other * services with `org.apache.pekko.grpc.scaladsl.ServiceHandler.concatOrNotFound` and then used in for example @@ -138,11 +175,10 @@ object @{serviceName}Handler { .recoverWith(GrpcExceptionHandler.from(eHandler(system.classicSystem))(system, writer)) ).getOrElse(unsupportedMediaType) - Function.unlift((req: model.HttpRequest) => req.uri.path match { - case model.Uri.Path.Slash(model.Uri.Path.Segment(`prefix`, model.Uri.Path.Slash(model.Uri.Path.Segment(method, model.Uri.Path.Empty)))) => - Some(handle(spi.onRequest(prefix, method, req), method)) - case _ => - None + Function.unlift((req: model.HttpRequest) => { + val method = methodName(req, prefix) + if (method eq null) None + else Some(handle(spi.onRequest(prefix, method, req), method)) }) } } diff --git 
a/runtime/src/main/scala/org/apache/pekko/grpc/javadsl/GrpcMarshalling.scala b/runtime/src/main/scala/org/apache/pekko/grpc/javadsl/GrpcMarshalling.scala index 07fdee0f..5981ad13 100644 --- a/runtime/src/main/scala/org/apache/pekko/grpc/javadsl/GrpcMarshalling.scala +++ b/runtime/src/main/scala/org/apache/pekko/grpc/javadsl/GrpcMarshalling.scala @@ -30,6 +30,7 @@ import pekko.stream.javadsl.Source import pekko.util.ByteString import scala.annotation.nowarn +import scala.util.control.NonFatal object GrpcMarshalling { @@ -56,7 +57,12 @@ object GrpcMarshalling { u: ProtobufSerializer[T], mat: Materializer, reader: GrpcProtocolReader): CompletionStage[T] = - unmarshal(entity.getDataBytes, u, mat, reader) + entity match { + case strict: pekko.http.scaladsl.model.HttpEntity.Strict => + completedOrFailed(u.deserialize(reader.decodeSingleFrame(strict.data))) + case _ => + unmarshal(entity.getDataBytes, u, mat, reader) + } def unmarshalStream[T]( data: Source[ByteString, AnyRef], @@ -103,4 +109,10 @@ object GrpcMarshalling { future.completeExceptionally(error) future } + + private def completedOrFailed[R](value: => R): CompletionStage[R] = + try CompletableFuture.completedFuture(value) + catch { + case NonFatal(error) => failure(error) + } } diff --git a/runtime/src/test/scala/org/apache/pekko/grpc/javadsl/GrpcMarshallingSpec.scala b/runtime/src/test/scala/org/apache/pekko/grpc/javadsl/GrpcMarshallingSpec.scala new file mode 100644 index 00000000..592834de --- /dev/null +++ b/runtime/src/test/scala/org/apache/pekko/grpc/javadsl/GrpcMarshallingSpec.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.pekko.grpc.javadsl + +import java.util.concurrent.TimeUnit + +import scala.concurrent.Await +import scala.concurrent.duration._ + +import com.google.protobuf.{ Any => ProtobufAny, ByteString => ProtobufByteString } +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec + +import org.apache.pekko +import pekko.actor.ActorSystem +import pekko.grpc.internal.{ AbstractGrpcProtocol, GrpcProtocolNative, Identity } +import pekko.http.scaladsl.model.HttpEntity +import pekko.stream.SystemMaterializer + +class GrpcMarshallingSpec extends AnyWordSpec with Matchers { + "The javadsl GrpcMarshalling" should { + "unmarshal a strict unary entity" in { + val system = ActorSystem("GrpcMarshallingSpec") + try { + val mat = SystemMaterializer(system).materializer + val serializer = new GoogleProtobufSerializer(ProtobufAny.parser()) + val message = + ProtobufAny.newBuilder().setTypeUrl("benchmark").setValue(ProtobufByteString.copyFromUtf8("payload")).build() + val entity = + HttpEntity.Strict( + GrpcProtocolNative.contentType, + AbstractGrpcProtocol.encodeFrameData(serializer.serialize(message), isCompressed = false, isTrailer = false)) + + val result = + GrpcMarshalling + .unmarshal(entity, serializer, mat, GrpcProtocolNative.newReader(Identity)) + .toCompletableFuture + .get(10, TimeUnit.SECONDS) + + result should be(message) + } finally { + Await.result(system.terminate(), 10.seconds) + } + } + } +} --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] 
For additional commands, e-mail: [email protected]
