andygrove commented on code in PR #1331:
URL: https://github.com/apache/datafusion-comet/pull/1331#discussion_r1929180456


##########
spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala:
##########
@@ -826,723 +826,790 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
   }
 
   /**
-   * Convert a Spark expression to protobuf.
-   *
-   * @param expr
-   *   The input expression
-   * @param inputs
-   *   The input attributes
-   * @param binding
-   *   Whether to bind the expression to the input attributes
-   * @return
-   *   The protobuf representation of the expression, or None if the expression is not supported
+   * Wrap an expression in a cast.
    */
-  def exprToProto(
+  def castToProto(
       expr: Expression,
-      input: Seq[Attribute],
-      binding: Boolean = true): Option[Expr] = {
-    def castToProto(
-        timeZoneId: Option[String],
-        dt: DataType,
-        childExpr: Option[Expr],
-        evalMode: CometEvalMode.Value): Option[Expr] = {
-      val dataType = serializeDataType(dt)
-
-      if (childExpr.isDefined && dataType.isDefined) {
+      timeZoneId: Option[String],
+      dt: DataType,
+      childExpr: Expr,
+      evalMode: CometEvalMode.Value): Option[Expr] = {
+    serializeDataType(dt) match {
+      case Some(dataType) =>
         val castBuilder = ExprOuterClass.Cast.newBuilder()
-        castBuilder.setChild(childExpr.get)
-        castBuilder.setDatatype(dataType.get)
+        castBuilder.setChild(childExpr)
+        castBuilder.setDatatype(dataType)
         castBuilder.setEvalMode(evalModeToProto(evalMode))
         castBuilder.setAllowIncompat(CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.get())
-        val timeZone = timeZoneId.getOrElse("UTC")
-        castBuilder.setTimezone(timeZone)
-
+        castBuilder.setTimezone(timeZoneId.getOrElse("UTC"))
         Some(
           ExprOuterClass.Expr
             .newBuilder()
             .setCast(castBuilder)
             .build())
-      } else {
-        if (!dataType.isDefined) {
-          withInfo(expr, s"Unsupported datatype ${dt}")
-        } else {
-          withInfo(expr, s"Unsupported expression $childExpr")
-        }
+      case _ =>
+        withInfo(expr, s"Unsupported datatype in castToProto: $dt")
         None
-      }
     }
+  }
 
-    def exprToProtoInternal(expr: Expression, inputs: Seq[Attribute]): Option[Expr] = {
-      SQLConf.get
-
-      def handleCast(
-          child: Expression,
-          inputs: Seq[Attribute],
-          dt: DataType,
-          timeZoneId: Option[String],
-          evalMode: CometEvalMode.Value): Option[Expr] = {
-
-        val childExpr = exprToProtoInternal(child, inputs)
-        if (childExpr.isDefined) {
-          val castSupport =
-            CometCast.isSupported(child.dataType, dt, timeZoneId, evalMode)
-
-          def getIncompatMessage(reason: Option[String]): String =
-            "Comet does not guarantee correct results for cast " +
-              s"from ${child.dataType} to $dt " +
-              s"with timezone $timeZoneId and evalMode $evalMode" +
-              reason.map(str => s" ($str)").getOrElse("")
-
-          castSupport match {
-            case Compatible(_) =>
-              castToProto(timeZoneId, dt, childExpr, evalMode)
-            case Incompatible(reason) =>
-              if (CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.get()) {
-                logWarning(getIncompatMessage(reason))
-                castToProto(timeZoneId, dt, childExpr, evalMode)
-              } else {
-                withInfo(
-                  expr,
-                  s"${getIncompatMessage(reason)}. To enable all incompatible casts, set " +
-                    s"${CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key}=true")
-                None
-              }
-            case Unsupported =>
-              withInfo(
-                expr,
-                s"Unsupported cast from ${child.dataType} to $dt " +
-                  s"with timezone $timeZoneId and evalMode $evalMode")
-              None
+  def handleCast(
+      expr: Expression,
+      child: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean,
+      dt: DataType,
+      timeZoneId: Option[String],
+      evalMode: CometEvalMode.Value): Option[Expr] = {
+
+    val childExpr = exprToProtoInternal(child, inputs, binding)
+    if (childExpr.isDefined) {
+      val castSupport =
+        CometCast.isSupported(child.dataType, dt, timeZoneId, evalMode)
+
+      def getIncompatMessage(reason: Option[String]): String =
+        "Comet does not guarantee correct results for cast " +
+          s"from ${child.dataType} to $dt " +
+          s"with timezone $timeZoneId and evalMode $evalMode" +
+          reason.map(str => s" ($str)").getOrElse("")
+
+      castSupport match {
+        case Compatible(_) =>
+          castToProto(expr, timeZoneId, dt, childExpr.get, evalMode)
+        case Incompatible(reason) =>
+          if (CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.get()) {
+            logWarning(getIncompatMessage(reason))
+            castToProto(expr, timeZoneId, dt, childExpr.get, evalMode)
+          } else {
+            withInfo(
+              expr,
+              s"${getIncompatMessage(reason)}. To enable all incompatible casts, set " +
+                s"${CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key}=true")
+            None
           }
-        } else {
-          withInfo(expr, child)
+        case Unsupported =>
+          withInfo(
+            expr,
+            s"Unsupported cast from ${child.dataType} to $dt " +
+              s"with timezone $timeZoneId and evalMode $evalMode")
           None
-        }
       }
+    } else {
+      withInfo(expr, child)
+      None
+    }
+  }
 
-      expr match {
-        case a @ Alias(_, _) =>
-          val r = exprToProtoInternal(a.child, inputs)
-          if (r.isEmpty) {
-            withInfo(expr, a.child)
-          }
-          r
-
-        case cast @ Cast(_: Literal, dataType, _, _) =>
-          // This can happen after promoting decimal precisions
-          val value = cast.eval()
-          exprToProtoInternal(Literal(value, dataType), inputs)
-
-        case UnaryExpression(child) if expr.prettyName == "trycast" =>
-          val timeZoneId = SQLConf.get.sessionLocalTimeZone
-          handleCast(child, inputs, expr.dataType, Some(timeZoneId), CometEvalMode.TRY)
-
-        case c @ Cast(child, dt, timeZoneId, _) =>
-          handleCast(child, inputs, dt, timeZoneId, evalMode(c))
-
-        case add @ Add(left, right, _) if supportedDataType(left.dataType) =>
-          createMathExpression(
-            left,
-            right,
-            inputs,
-            add.dataType,
-            getFailOnError(add),
-            (builder, mathExpr) => builder.setAdd(mathExpr))
-
-        case add @ Add(left, _, _) if !supportedDataType(left.dataType) =>
-          withInfo(add, s"Unsupported datatype ${left.dataType}")
-          None
-
-        case sub @ Subtract(left, right, _) if supportedDataType(left.dataType) =>
-          createMathExpression(
-            left,
-            right,
-            inputs,
-            sub.dataType,
-            getFailOnError(sub),
-            (builder, mathExpr) => builder.setSubtract(mathExpr))
-
-        case sub @ Subtract(left, _, _) if !supportedDataType(left.dataType) =>
-          withInfo(sub, s"Unsupported datatype ${left.dataType}")
-          None
-
-        case mul @ Multiply(left, right, _)
-            if supportedDataType(left.dataType) && !decimalBeforeSpark34(left.dataType) =>
-          createMathExpression(
-            left,
-            right,
-            inputs,
-            mul.dataType,
-            getFailOnError(mul),
-            (builder, mathExpr) => builder.setMultiply(mathExpr))
-
-        case mul @ Multiply(left, _, _) =>
-          if (!supportedDataType(left.dataType)) {
-            withInfo(mul, s"Unsupported datatype ${left.dataType}")
-          }
-          if (decimalBeforeSpark34(left.dataType)) {
-            withInfo(mul, "Decimal support requires Spark 3.4 or later")
-          }
-          None
-
-        case div @ Divide(left, right, _)
-            if supportedDataType(left.dataType) && !decimalBeforeSpark34(left.dataType) =>
-          // Datafusion now throws an exception for dividing by zero
-          // See https://github.com/apache/arrow-datafusion/pull/6792
-          // For now, use NullIf to swap zeros with nulls.
-          val rightExpr = nullIfWhenPrimitive(right)
-
-          createMathExpression(
-            left,
-            rightExpr,
-            inputs,
-            div.dataType,
-            getFailOnError(div),
-            (builder, mathExpr) => builder.setDivide(mathExpr))
-
-        case div @ Divide(left, _, _) =>
-          if (!supportedDataType(left.dataType)) {
-            withInfo(div, s"Unsupported datatype ${left.dataType}")
-          }
-          if (decimalBeforeSpark34(left.dataType)) {
-            withInfo(div, "Decimal support requires Spark 3.4 or later")
-          }
-          None
-
-        case rem @ Remainder(left, right, _)
-            if supportedDataType(left.dataType) && !decimalBeforeSpark34(left.dataType) =>
-          val rightExpr = nullIfWhenPrimitive(right)
-
-          createMathExpression(
-            left,
-            rightExpr,
-            inputs,
-            rem.dataType,
-            getFailOnError(rem),
-            (builder, mathExpr) => builder.setRemainder(mathExpr))
-
-        case rem @ Remainder(left, _, _) =>
-          if (!supportedDataType(left.dataType)) {
-            withInfo(rem, s"Unsupported datatype ${left.dataType}")
-          }
-          if (decimalBeforeSpark34(left.dataType)) {
-            withInfo(rem, "Decimal support requires Spark 3.4 or later")
-          }
-          None
-
-        case EqualTo(left, right) =>
-          createBinaryExpr(
-            left,
-            right,
-            inputs,
-            (builder, binaryExpr) => builder.setEq(binaryExpr))
-
-        case Not(EqualTo(left, right)) =>
-          createBinaryExpr(
-            left,
-            right,
-            inputs,
-            (builder, binaryExpr) => builder.setNeq(binaryExpr))
-
-        case EqualNullSafe(left, right) =>
-          createBinaryExpr(
-            left,
-            right,
-            inputs,
-            (builder, binaryExpr) => builder.setEqNullSafe(binaryExpr))
+  /**
+   * Convert a Spark expression to protobuf.
+   *
+   * @param expr
+   *   The input expression
+   * @param inputs
+   *   The input attributes
+   * @param binding
+   *   Whether to bind the expression to the input attributes
+   * @return
+   *   The protobuf representation of the expression, or None if the expression is not supported
+   */
+  def exprToProto(
+      expr: Expression,
+      input: Seq[Attribute],

Review Comment:
   @kazuyukitanimura FYI, I fixed the above.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org
For additional commands, e-mail: github-h...@datafusion.apache.org

Reply via email to