viirya commented on code in PR #1331:
URL: https://github.com/apache/datafusion-comet/pull/1331#discussion_r1927454327


##########
spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala:
##########
@@ -1557,918 +1624,935 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
 //            None
 //          }
 
-        case Acos(child) =>
-          val childExpr = exprToProtoInternal(child, inputs)
-          val optExpr = scalarExprToProto("acos", childExpr)
-          optExprWithInfo(optExpr, expr, child)
-
-        case Asin(child) =>
-          val childExpr = exprToProtoInternal(child, inputs)
-          val optExpr = scalarExprToProto("asin", childExpr)
-          optExprWithInfo(optExpr, expr, child)
-
-        case Atan(child) =>
-          val childExpr = exprToProtoInternal(child, inputs)
-          val optExpr = scalarExprToProto("atan", childExpr)
-          optExprWithInfo(optExpr, expr, child)
-
-        case Atan2(left, right) =>
-          val leftExpr = exprToProtoInternal(left, inputs)
-          val rightExpr = exprToProtoInternal(right, inputs)
-          val optExpr = scalarExprToProto("atan2", leftExpr, rightExpr)
-          optExprWithInfo(optExpr, expr, left, right)
-
-        case Hex(child) =>
-          val childExpr = exprToProtoInternal(child, inputs)
-          val optExpr =
-            scalarExprToProtoWithReturnType("hex", StringType, childExpr)
-
-          optExprWithInfo(optExpr, expr, child)
-
-        case e: Unhex =>
-          val unHex = unhexSerde(e)
-
-          val childExpr = exprToProtoInternal(unHex._1, inputs)
-          val failOnErrorExpr = exprToProtoInternal(unHex._2, inputs)
-
-          val optExpr =
-            scalarExprToProtoWithReturnType("unhex", e.dataType, childExpr, 
failOnErrorExpr)
-          optExprWithInfo(optExpr, expr, unHex._1)
-
-        case e @ Ceil(child) =>
-          val childExpr = exprToProtoInternal(child, inputs)
-          child.dataType match {
-            case t: DecimalType if t.scale == 0 => // zero scale is no-op
-              childExpr
-            case t: DecimalType if t.scale < 0 => // Spark disallows negative scale SPARK-30252
-              withInfo(e, s"Decimal type $t has negative scale")
-              None
-            case _ =>
-              val optExpr = scalarExprToProtoWithReturnType("ceil", 
e.dataType, childExpr)
-              optExprWithInfo(optExpr, expr, child)
-          }
-
-        case Cos(child) =>
-          val childExpr = exprToProtoInternal(child, inputs)
-          val optExpr = scalarExprToProto("cos", childExpr)
-          optExprWithInfo(optExpr, expr, child)
-
-        case Exp(child) =>
-          val childExpr = exprToProtoInternal(child, inputs)
-          val optExpr = scalarExprToProto("exp", childExpr)
-          optExprWithInfo(optExpr, expr, child)
-
-        case e @ Floor(child) =>
-          val childExpr = exprToProtoInternal(child, inputs)
-          child.dataType match {
-            case t: DecimalType if t.scale == 0 => // zero scale is no-op
-              childExpr
-            case t: DecimalType if t.scale < 0 => // Spark disallows negative scale SPARK-30252
-              withInfo(e, s"Decimal type $t has negative scale")
-              None
-            case _ =>
-              val optExpr = scalarExprToProtoWithReturnType("floor", 
e.dataType, childExpr)
-              optExprWithInfo(optExpr, expr, child)
-          }
-
-        // The expression for `log` functions is defined as null on numbers less than or equal
-        // to 0. This matches Spark and Hive behavior, where non-positive values eval to null
-        // instead of NaN or -Infinity.
-        case Log(child) =>
-          val childExpr = exprToProtoInternal(nullIfNegative(child), inputs)
-          val optExpr = scalarExprToProto("ln", childExpr)
-          optExprWithInfo(optExpr, expr, child)
-
-        case Log10(child) =>
-          val childExpr = exprToProtoInternal(nullIfNegative(child), inputs)
-          val optExpr = scalarExprToProto("log10", childExpr)
-          optExprWithInfo(optExpr, expr, child)
-
-        case Log2(child) =>
-          val childExpr = exprToProtoInternal(nullIfNegative(child), inputs)
-          val optExpr = scalarExprToProto("log2", childExpr)
-          optExprWithInfo(optExpr, expr, child)
-
-        case Pow(left, right) =>
-          val leftExpr = exprToProtoInternal(left, inputs)
-          val rightExpr = exprToProtoInternal(right, inputs)
-          val optExpr = scalarExprToProto("pow", leftExpr, rightExpr)
-          optExprWithInfo(optExpr, expr, left, right)
-
-        case r: Round =>
-          // _scale is a constant, copied from Spark's RoundBase because it is a protected val
-          val scaleV: Any = r.scale.eval(EmptyRow)
-          val _scale: Int = scaleV.asInstanceOf[Int]
-
-          lazy val childExpr = exprToProtoInternal(r.child, inputs)
-          r.child.dataType match {
-            case t: DecimalType if t.scale < 0 => // Spark disallows negative scale SPARK-30252
-              withInfo(r, "Decimal type has negative scale")
-              None
-            case _ if scaleV == null =>
-              exprToProtoInternal(Literal(null), inputs)
-            case _: ByteType | ShortType | IntegerType | LongType if _scale >= 0 =>
-              childExpr // _scale (i.e. decimal place) >= 0 is a no-op for integer types in Spark
-            case _: FloatType | DoubleType =>
-              // We cannot properly match with the Spark behavior for floating-point numbers.
-              // Spark uses BigDecimal for rounding float/double, and BigDecimal first converts a
-              // double to string internally in order to create its own internal representation.
-              // The problem is BigDecimal uses java.lang.Double.toString() and it has a complicated
-              // rounding algorithm. E.g. -5.81855622136895E8 is actually
-              // -581855622.13689494132995605468750. Note the 5th fractional digit is 4 instead of
-              // 5. Java(Scala)'s toString() rounds it up to -581855622.136895. This makes a
-              // difference when rounding at the 5th digit, i.e. round(-5.81855622136895E8, 5) should be
-              // -5.818556221369E8, instead of -5.8185562213689E8. There is also an example that
-              // toString() does NOT round up. 6.1317116247283497E18 is 6131711624728349696. It can
-              // be rounded up to 6.13171162472835E18 that still represents the same double number.
-              // I.e. 6.13171162472835E18 == 6.1317116247283497E18. However, toString() does not.
-              // That results in round(6.1317116247283497E18, -5) == 6.1317116247282995E18 instead
-              // of 6.1317116247283999E18.
-              withInfo(r, "Comet does not support Spark's BigDecimal rounding")
-              None
-            case _ =>
-              // `scale` must be Int64 type in DataFusion
-              val scaleExpr = exprToProtoInternal(Literal(_scale.toLong, LongType), inputs)
-              val optExpr =
-                scalarExprToProtoWithReturnType("round", r.dataType, childExpr, scaleExpr)
-              optExprWithInfo(optExpr, expr, r.child)
-          }
+      case Acos(child) =>
+        val childExpr = exprToProtoInternal(child, input, binding)
+        val optExpr = scalarExprToProto("acos", childExpr)
+        optExprWithInfo(optExpr, expr, child)
+
+      case Asin(child) =>
+        val childExpr = exprToProtoInternal(child, input, binding)
+        val optExpr = scalarExprToProto("asin", childExpr)
+        optExprWithInfo(optExpr, expr, child)
+
+      case Atan(child) =>
+        val childExpr = exprToProtoInternal(child, input, binding)
+        val optExpr = scalarExprToProto("atan", childExpr)
+        optExprWithInfo(optExpr, expr, child)
+
+      case Atan2(left, right) =>
+        val leftExpr = exprToProtoInternal(left, input, binding)
+        val rightExpr = exprToProtoInternal(right, input, binding)
+        val optExpr = scalarExprToProto("atan2", leftExpr, rightExpr)
+        optExprWithInfo(optExpr, expr, left, right)
+
+      case Hex(child) =>
+        val childExpr = exprToProtoInternal(child, input, binding)
+        val optExpr =
+          scalarExprToProtoWithReturnType("hex", StringType, childExpr)
+
+        optExprWithInfo(optExpr, expr, child)
+
+      case e: Unhex =>
+        val unHex = unhexSerde(e)
+
+        val childExpr = exprToProtoInternal(unHex._1, input, binding)
+        val failOnErrorExpr = exprToProtoInternal(unHex._2, input, binding)
+
+        val optExpr =
+          scalarExprToProtoWithReturnType("unhex", e.dataType, childExpr, 
failOnErrorExpr)
+        optExprWithInfo(optExpr, expr, unHex._1)
+
+      case e @ Ceil(child) =>
+        val childExpr = exprToProtoInternal(child, input, binding)
+        child.dataType match {
+          case t: DecimalType if t.scale == 0 => // zero scale is no-op
+            childExpr
+          case t: DecimalType if t.scale < 0 => // Spark disallows negative scale SPARK-30252
+            withInfo(e, s"Decimal type $t has negative scale")
+            None
+          case _ =>
+            val optExpr = scalarExprToProtoWithReturnType("ceil", e.dataType, 
childExpr)
+            optExprWithInfo(optExpr, expr, child)
+        }
+
+      case Cos(child) =>
+        val childExpr = exprToProtoInternal(child, input, binding)
+        val optExpr = scalarExprToProto("cos", childExpr)
+        optExprWithInfo(optExpr, expr, child)
+
+      case Exp(child) =>
+        val childExpr = exprToProtoInternal(child, input, binding)
+        val optExpr = scalarExprToProto("exp", childExpr)
+        optExprWithInfo(optExpr, expr, child)
+
+      case e @ Floor(child) =>
+        val childExpr = exprToProtoInternal(child, input, binding)
+        child.dataType match {
+          case t: DecimalType if t.scale == 0 => // zero scale is no-op
+            childExpr
+          case t: DecimalType if t.scale < 0 => // Spark disallows negative scale SPARK-30252
+            withInfo(e, s"Decimal type $t has negative scale")
+            None
+          case _ =>
+            val optExpr = scalarExprToProtoWithReturnType("floor", e.dataType, 
childExpr)
+            optExprWithInfo(optExpr, expr, child)
+        }
+
+      // The expression for `log` functions is defined as null on numbers less than or equal
+      // to 0. This matches Spark and Hive behavior, where non-positive values eval to null
+      // instead of NaN or -Infinity.
+      case Log(child) =>
+        val childExpr = exprToProtoInternal(nullIfNegative(child), input, binding)
+        val optExpr = scalarExprToProto("ln", childExpr)
+        optExprWithInfo(optExpr, expr, child)
+
+      case Log10(child) =>
+        val childExpr = exprToProtoInternal(nullIfNegative(child), input, binding)
+        val optExpr = scalarExprToProto("log10", childExpr)
+        optExprWithInfo(optExpr, expr, child)
+
+      case Log2(child) =>
+        val childExpr = exprToProtoInternal(nullIfNegative(child), input, binding)
+        val optExpr = scalarExprToProto("log2", childExpr)
+        optExprWithInfo(optExpr, expr, child)
+
+      case Pow(left, right) =>
+        val leftExpr = exprToProtoInternal(left, input, binding)
+        val rightExpr = exprToProtoInternal(right, input, binding)
+        val optExpr = scalarExprToProto("pow", leftExpr, rightExpr)
+        optExprWithInfo(optExpr, expr, left, right)
+
+      case r: Round =>
+        // _scale is a constant, copied from Spark's RoundBase because it is a protected val
+        val scaleV: Any = r.scale.eval(EmptyRow)
+        val _scale: Int = scaleV.asInstanceOf[Int]
+
+        lazy val childExpr = exprToProtoInternal(r.child, input, binding)
+        r.child.dataType match {
+          case t: DecimalType if t.scale < 0 => // Spark disallows negative scale SPARK-30252
+            withInfo(r, "Decimal type has negative scale")
+            None
+          case _ if scaleV == null =>
+            exprToProtoInternal(Literal(null), input, binding)
+          case _: ByteType | ShortType | IntegerType | LongType if _scale >= 0 =>
+            childExpr // _scale (i.e. decimal place) >= 0 is a no-op for integer types in Spark
+          case _: FloatType | DoubleType =>
+            // We cannot properly match with the Spark behavior for floating-point numbers.
+            // Spark uses BigDecimal for rounding float/double, and BigDecimal first converts a
+            // double to string internally in order to create its own internal representation.
+            // The problem is BigDecimal uses java.lang.Double.toString() and it has a complicated
+            // rounding algorithm. E.g. -5.81855622136895E8 is actually
+            // -581855622.13689494132995605468750. Note the 5th fractional digit is 4 instead of
+            // 5. Java(Scala)'s toString() rounds it up to -581855622.136895. This makes a
+            // difference when rounding at the 5th digit, i.e. round(-5.81855622136895E8, 5) should be
+            // -5.818556221369E8, instead of -5.8185562213689E8. There is also an example that
+            // toString() does NOT round up. 6.1317116247283497E18 is 6131711624728349696. It can
+            // be rounded up to 6.13171162472835E18 that still represents the same double number.
+            // I.e. 6.13171162472835E18 == 6.1317116247283497E18. However, toString() does not.
+            // That results in round(6.1317116247283497E18, -5) == 6.1317116247282995E18 instead
+            // of 6.1317116247283999E18.
+            withInfo(r, "Comet does not support Spark's BigDecimal rounding")
+            None
+          case _ =>
+            // `scale` must be Int64 type in DataFusion
+            val scaleExpr = exprToProtoInternal(Literal(_scale.toLong, LongType), input, binding)
+            val optExpr =
+              scalarExprToProtoWithReturnType("round", r.dataType, childExpr, scaleExpr)
+            optExprWithInfo(optExpr, expr, r.child)
+        }
 
-        // TODO enable once https://github.com/apache/datafusion/issues/11557 is fixed or
-        // when we have a Spark-compatible version implemented in Comet
+      // TODO enable once https://github.com/apache/datafusion/issues/11557 is fixed or
+      // when we have a Spark-compatible version implemented in Comet
 //        case Signum(child) =>
 //          val childExpr = exprToProtoInternal(child, inputs)
 //          val optExpr = scalarExprToProto("signum", childExpr)
 //          optExprWithInfo(optExpr, expr, child)
 
-        case Sin(child) =>
-          val childExpr = exprToProtoInternal(child, inputs)
-          val optExpr = scalarExprToProto("sin", childExpr)
-          optExprWithInfo(optExpr, expr, child)
-
-        case Sqrt(child) =>
-          val childExpr = exprToProtoInternal(child, inputs)
-          val optExpr = scalarExprToProto("sqrt", childExpr)
-          optExprWithInfo(optExpr, expr, child)
-
-        case Tan(child) =>
-          val childExpr = exprToProtoInternal(child, inputs)
-          val optExpr = scalarExprToProto("tan", childExpr)
-          optExprWithInfo(optExpr, expr, child)
-
-        case Ascii(child) =>
-          val castExpr = Cast(child, StringType)
-          val childExpr = exprToProtoInternal(castExpr, inputs)
-          val optExpr = scalarExprToProto("ascii", childExpr)
-          optExprWithInfo(optExpr, expr, castExpr)
-
-        case BitLength(child) =>
-          val castExpr = Cast(child, StringType)
-          val childExpr = exprToProtoInternal(castExpr, inputs)
-          val optExpr = scalarExprToProto("bit_length", childExpr)
-          optExprWithInfo(optExpr, expr, castExpr)
-
-        case If(predicate, trueValue, falseValue) =>
-          val predicateExpr = exprToProtoInternal(predicate, inputs)
-          val trueExpr = exprToProtoInternal(trueValue, inputs)
-          val falseExpr = exprToProtoInternal(falseValue, inputs)
-          if (predicateExpr.isDefined && trueExpr.isDefined && falseExpr.isDefined) {
-            val builder = ExprOuterClass.IfExpr.newBuilder()
-            builder.setIfExpr(predicateExpr.get)
-            builder.setTrueExpr(trueExpr.get)
-            builder.setFalseExpr(falseExpr.get)
-            Some(
-              ExprOuterClass.Expr
-                .newBuilder()
-                .setIf(builder)
-                .build())
-          } else {
-            withInfo(expr, predicate, trueValue, falseValue)
-            None
-          }
+      case Sin(child) =>
+        val childExpr = exprToProtoInternal(child, input, binding)
+        val optExpr = scalarExprToProto("sin", childExpr)
+        optExprWithInfo(optExpr, expr, child)
+
+      case Sqrt(child) =>
+        val childExpr = exprToProtoInternal(child, input, binding)
+        val optExpr = scalarExprToProto("sqrt", childExpr)
+        optExprWithInfo(optExpr, expr, child)
+
+      case Tan(child) =>
+        val childExpr = exprToProtoInternal(child, input, binding)
+        val optExpr = scalarExprToProto("tan", childExpr)
+        optExprWithInfo(optExpr, expr, child)
+
+      case Ascii(child) =>
+        val castExpr = Cast(child, StringType)
+        val childExpr = exprToProtoInternal(castExpr, input, binding)
+        val optExpr = scalarExprToProto("ascii", childExpr)
+        optExprWithInfo(optExpr, expr, castExpr)
+
+      case BitLength(child) =>
+        val castExpr = Cast(child, StringType)
+        val childExpr = exprToProtoInternal(castExpr, input, binding)
+        val optExpr = scalarExprToProto("bit_length", childExpr)
+        optExprWithInfo(optExpr, expr, castExpr)
+
+      case If(predicate, trueValue, falseValue) =>
+        val predicateExpr = exprToProtoInternal(predicate, input, binding)
+        val trueExpr = exprToProtoInternal(trueValue, input, binding)
+        val falseExpr = exprToProtoInternal(falseValue, input, binding)
+        if (predicateExpr.isDefined && trueExpr.isDefined && falseExpr.isDefined) {
+          val builder = ExprOuterClass.IfExpr.newBuilder()
+          builder.setIfExpr(predicateExpr.get)
+          builder.setTrueExpr(trueExpr.get)
+          builder.setFalseExpr(falseExpr.get)
+          Some(
+            ExprOuterClass.Expr
+              .newBuilder()
+              .setIf(builder)
+              .build())
+        } else {
+          withInfo(expr, predicate, trueValue, falseValue)
+          None
+        }
 
-        case CaseWhen(branches, elseValue) =>
-          var allBranches: Seq[Expression] = Seq()
-          val whenSeq = branches.map(elements => {
-            allBranches = allBranches :+ elements._1
-            exprToProtoInternal(elements._1, inputs)
-          })
-          val thenSeq = branches.map(elements => {
-            allBranches = allBranches :+ elements._2
-            exprToProtoInternal(elements._2, inputs)
-          })
-          assert(whenSeq.length == thenSeq.length)
-          if (whenSeq.forall(_.isDefined) && thenSeq.forall(_.isDefined)) {
-            val builder = ExprOuterClass.CaseWhen.newBuilder()
-            builder.addAllWhen(whenSeq.map(_.get).asJava)
-            builder.addAllThen(thenSeq.map(_.get).asJava)
-            if (elseValue.isDefined) {
-              val elseValueExpr =
-                exprToProtoInternal(elseValue.get, inputs)
-              if (elseValueExpr.isDefined) {
-                builder.setElseExpr(elseValueExpr.get)
-              } else {
-                withInfo(expr, elseValue.get)
-                return None
-              }
+      case CaseWhen(branches, elseValue) =>
+        var allBranches: Seq[Expression] = Seq()
+        val whenSeq = branches.map(elements => {
+          allBranches = allBranches :+ elements._1
+          exprToProtoInternal(elements._1, input, binding)
+        })
+        val thenSeq = branches.map(elements => {
+          allBranches = allBranches :+ elements._2
+          exprToProtoInternal(elements._2, input, binding)
+        })
+        assert(whenSeq.length == thenSeq.length)
+        if (whenSeq.forall(_.isDefined) && thenSeq.forall(_.isDefined)) {
+          val builder = ExprOuterClass.CaseWhen.newBuilder()
+          builder.addAllWhen(whenSeq.map(_.get).asJava)
+          builder.addAllThen(thenSeq.map(_.get).asJava)
+          if (elseValue.isDefined) {
+            val elseValueExpr =
+              exprToProtoInternal(elseValue.get, input, binding)
+            if (elseValueExpr.isDefined) {
+              builder.setElseExpr(elseValueExpr.get)
+            } else {
+              withInfo(expr, elseValue.get)
+              return None
             }
-            Some(
-              ExprOuterClass.Expr
-                .newBuilder()
-                .setCaseWhen(builder)
-                .build())
-          } else {
-            withInfo(expr, allBranches: _*)
-            None
-          }
-        case ConcatWs(children) =>
-          var childExprs: Seq[Expression] = Seq()
-          val exprs = children.map(e => {
-            val castExpr = Cast(e, StringType)
-            childExprs = childExprs :+ castExpr
-            exprToProtoInternal(castExpr, inputs)
-          })
-          val optExpr = scalarExprToProto("concat_ws", exprs: _*)
-          optExprWithInfo(optExpr, expr, childExprs: _*)
-
-        case Chr(child) =>
-          val childExpr = exprToProtoInternal(child, inputs)
-          val optExpr = scalarExprToProto("chr", childExpr)
-          optExprWithInfo(optExpr, expr, child)
-
-        case InitCap(child) =>
-          if (CometConf.COMET_EXEC_INITCAP_ENABLED.get()) {
-            val castExpr = Cast(child, StringType)
-            val childExpr = exprToProtoInternal(castExpr, inputs)
-            val optExpr = scalarExprToProto("initcap", childExpr)
-            optExprWithInfo(optExpr, expr, castExpr)
-          } else {
-            withInfo(
-              expr,
-              "Comet initCap is not compatible with Spark yet. " +
-                "See https://github.com/apache/datafusion-comet/issues/1052 ." +
-                s"Set ${CometConf.COMET_EXEC_INITCAP_ENABLED.key}=true to enable it anyway.")
-            None
           }
-
-        case Length(child) =>
+          Some(
+            ExprOuterClass.Expr
+              .newBuilder()
+              .setCaseWhen(builder)
+              .build())
+        } else {
+          withInfo(expr, allBranches: _*)
+          None
+        }
+      case ConcatWs(children) =>
+        var childExprs: Seq[Expression] = Seq()
+        val exprs = children.map(e => {
+          val castExpr = Cast(e, StringType)
+          childExprs = childExprs :+ castExpr
+          exprToProtoInternal(castExpr, input, binding)
+        })
+        val optExpr = scalarExprToProto("concat_ws", exprs: _*)
+        optExprWithInfo(optExpr, expr, childExprs: _*)
+
+      case Chr(child) =>
+        val childExpr = exprToProtoInternal(child, input, binding)
+        val optExpr = scalarExprToProto("chr", childExpr)
+        optExprWithInfo(optExpr, expr, child)
+
+      case InitCap(child) =>
+        if (CometConf.COMET_EXEC_INITCAP_ENABLED.get()) {
           val castExpr = Cast(child, StringType)
-          val childExpr = exprToProtoInternal(castExpr, inputs)
-          val optExpr = scalarExprToProto("length", childExpr)
+          val childExpr = exprToProtoInternal(castExpr, input, binding)
+          val optExpr = scalarExprToProto("initcap", childExpr)
           optExprWithInfo(optExpr, expr, castExpr)
+        } else {
+          withInfo(
+            expr,
+            "Comet initCap is not compatible with Spark yet. " +
+              "See https://github.com/apache/datafusion-comet/issues/1052 ." +
+              s"Set ${CometConf.COMET_EXEC_INITCAP_ENABLED.key}=true to enable 
it anyway.")
+          None
+        }
 
-        case Md5(child) =>
-          val childExpr = exprToProtoInternal(child, inputs)
-          val optExpr = scalarExprToProto("md5", childExpr)
-          optExprWithInfo(optExpr, expr, child)
-
-        case OctetLength(child) =>
+      case Length(child) =>
+        val castExpr = Cast(child, StringType)
+        val childExpr = exprToProtoInternal(castExpr, input, binding)
+        val optExpr = scalarExprToProto("length", childExpr)
+        optExprWithInfo(optExpr, expr, castExpr)
+
+      case Md5(child) =>
+        val childExpr = exprToProtoInternal(child, input, binding)
+        val optExpr = scalarExprToProto("md5", childExpr)
+        optExprWithInfo(optExpr, expr, child)
+
+      case OctetLength(child) =>
+        val castExpr = Cast(child, StringType)
+        val childExpr = exprToProtoInternal(castExpr, input, binding)
+        val optExpr = scalarExprToProto("octet_length", childExpr)
+        optExprWithInfo(optExpr, expr, castExpr)
+
+      case Reverse(child) =>
+        val castExpr = Cast(child, StringType)
+        val childExpr = exprToProtoInternal(castExpr, input, binding)
+        val optExpr = scalarExprToProto("reverse", childExpr)
+        optExprWithInfo(optExpr, expr, castExpr)
+
+      case StringInstr(str, substr) =>
+        val leftCast = Cast(str, StringType)
+        val rightCast = Cast(substr, StringType)
+        val leftExpr = exprToProtoInternal(leftCast, input, binding)
+        val rightExpr = exprToProtoInternal(rightCast, input, binding)
+        val optExpr = scalarExprToProto("strpos", leftExpr, rightExpr)
+        optExprWithInfo(optExpr, expr, leftCast, rightCast)
+
+      case StringRepeat(str, times) =>
+        val leftCast = Cast(str, StringType)
+        val rightCast = Cast(times, LongType)
+        val leftExpr = exprToProtoInternal(leftCast, input, binding)
+        val rightExpr = exprToProtoInternal(rightCast, input, binding)
+        val optExpr = scalarExprToProto("repeat", leftExpr, rightExpr)
+        optExprWithInfo(optExpr, expr, leftCast, rightCast)
+
+      case StringReplace(src, search, replace) =>
+        val srcCast = Cast(src, StringType)
+        val searchCast = Cast(search, StringType)
+        val replaceCast = Cast(replace, StringType)
+        val srcExpr = exprToProtoInternal(srcCast, input, binding)
+        val searchExpr = exprToProtoInternal(searchCast, input, binding)
+        val replaceExpr = exprToProtoInternal(replaceCast, input, binding)
+        val optExpr = scalarExprToProto("replace", srcExpr, searchExpr, 
replaceExpr)
+        optExprWithInfo(optExpr, expr, srcCast, searchCast, replaceCast)
+
+      case StringTranslate(src, matching, replace) =>
+        val srcCast = Cast(src, StringType)
+        val matchingCast = Cast(matching, StringType)
+        val replaceCast = Cast(replace, StringType)
+        val srcExpr = exprToProtoInternal(srcCast, input, binding)
+        val matchingExpr = exprToProtoInternal(matchingCast, input, binding)
+        val replaceExpr = exprToProtoInternal(replaceCast, input, binding)
+        val optExpr = scalarExprToProto("translate", srcExpr, matchingExpr, 
replaceExpr)
+        optExprWithInfo(optExpr, expr, srcCast, matchingCast, replaceCast)
+
+      case StringTrim(srcStr, trimStr) =>
+        trim(expr, srcStr, trimStr, input, binding, "trim")
+
+      case StringTrimLeft(srcStr, trimStr) =>
+        trim(expr, srcStr, trimStr, input, binding, "ltrim")
+
+      case StringTrimRight(srcStr, trimStr) =>
+        trim(expr, srcStr, trimStr, input, binding, "rtrim")
+
+      case StringTrimBoth(srcStr, trimStr, _) =>
+        trim(expr, srcStr, trimStr, input, binding, "btrim")
+
+      case Upper(child) =>
+        if (CometConf.COMET_CASE_CONVERSION_ENABLED.get()) {
           val castExpr = Cast(child, StringType)
-          val childExpr = exprToProtoInternal(castExpr, inputs)
-          val optExpr = scalarExprToProto("octet_length", childExpr)
+          val childExpr = exprToProtoInternal(castExpr, input, binding)
+          val optExpr = scalarExprToProto("upper", childExpr)
           optExprWithInfo(optExpr, expr, castExpr)
+        } else {
+          withInfo(
+            expr,
+            "Comet is not compatible with Spark for case conversion in " +
+              s"locale-specific cases. Set 
${CometConf.COMET_CASE_CONVERSION_ENABLED.key}=true " +
+              "to enable it anyway.")
+          None
+        }
 
-        case Reverse(child) =>
+      case Lower(child) =>
+        if (CometConf.COMET_CASE_CONVERSION_ENABLED.get()) {
           val castExpr = Cast(child, StringType)
-          val childExpr = exprToProtoInternal(castExpr, inputs)
-          val optExpr = scalarExprToProto("reverse", childExpr)
+          val childExpr = exprToProtoInternal(castExpr, input, binding)
+          val optExpr = scalarExprToProto("lower", childExpr)
           optExprWithInfo(optExpr, expr, castExpr)
+        } else {
+          withInfo(
+            expr,
+            "Comet is not compatible with Spark for case conversion in " +
+              s"locale-specific cases. Set 
${CometConf.COMET_CASE_CONVERSION_ENABLED.key}=true " +
+              "to enable it anyway.")
+          None
+        }
 
-        case StringInstr(str, substr) =>
-          val leftCast = Cast(str, StringType)
-          val rightCast = Cast(substr, StringType)
-          val leftExpr = exprToProtoInternal(leftCast, inputs)
-          val rightExpr = exprToProtoInternal(rightCast, inputs)
-          val optExpr = scalarExprToProto("strpos", leftExpr, rightExpr)
-          optExprWithInfo(optExpr, expr, leftCast, rightCast)
-
-        case StringRepeat(str, times) =>
-          val leftCast = Cast(str, StringType)
-          val rightCast = Cast(times, LongType)
-          val leftExpr = exprToProtoInternal(leftCast, inputs)
-          val rightExpr = exprToProtoInternal(rightCast, inputs)
-          val optExpr = scalarExprToProto("repeat", leftExpr, rightExpr)
-          optExprWithInfo(optExpr, expr, leftCast, rightCast)
-
-        case StringReplace(src, search, replace) =>
-          val srcCast = Cast(src, StringType)
-          val searchCast = Cast(search, StringType)
-          val replaceCast = Cast(replace, StringType)
-          val srcExpr = exprToProtoInternal(srcCast, inputs)
-          val searchExpr = exprToProtoInternal(searchCast, inputs)
-          val replaceExpr = exprToProtoInternal(replaceCast, inputs)
-          val optExpr = scalarExprToProto("replace", srcExpr, searchExpr, 
replaceExpr)
-          optExprWithInfo(optExpr, expr, srcCast, searchCast, replaceCast)
-
-        case StringTranslate(src, matching, replace) =>
-          val srcCast = Cast(src, StringType)
-          val matchingCast = Cast(matching, StringType)
-          val replaceCast = Cast(replace, StringType)
-          val srcExpr = exprToProtoInternal(srcCast, inputs)
-          val matchingExpr = exprToProtoInternal(matchingCast, inputs)
-          val replaceExpr = exprToProtoInternal(replaceCast, inputs)
-          val optExpr = scalarExprToProto("translate", srcExpr, matchingExpr, 
replaceExpr)
-          optExprWithInfo(optExpr, expr, srcCast, matchingCast, replaceCast)
-
-        case StringTrim(srcStr, trimStr) =>
-          trim(expr, srcStr, trimStr, inputs, "trim")
-
-        case StringTrimLeft(srcStr, trimStr) =>
-          trim(expr, srcStr, trimStr, inputs, "ltrim")
-
-        case StringTrimRight(srcStr, trimStr) =>
-          trim(expr, srcStr, trimStr, inputs, "rtrim")
-
-        case StringTrimBoth(srcStr, trimStr, _) =>
-          trim(expr, srcStr, trimStr, inputs, "btrim")
-
-        case Upper(child) =>
-          if (CometConf.COMET_CASE_CONVERSION_ENABLED.get()) {
-            val castExpr = Cast(child, StringType)
-            val childExpr = exprToProtoInternal(castExpr, inputs)
-            val optExpr = scalarExprToProto("upper", childExpr)
-            optExprWithInfo(optExpr, expr, castExpr)
-          } else {
-            withInfo(
-              expr,
-              "Comet is not compatible with Spark for case conversion in " +
-                s"locale-specific cases. Set 
${CometConf.COMET_CASE_CONVERSION_ENABLED.key}=true " +
-                "to enable it anyway.")
-            None
-          }
-
-        case Lower(child) =>
-          if (CometConf.COMET_CASE_CONVERSION_ENABLED.get()) {
-            val castExpr = Cast(child, StringType)
-            val childExpr = exprToProtoInternal(castExpr, inputs)
-            val optExpr = scalarExprToProto("lower", childExpr)
-            optExprWithInfo(optExpr, expr, castExpr)
-          } else {
-            withInfo(
-              expr,
-              "Comet is not compatible with Spark for case conversion in " +
-                s"locale-specific cases. Set 
${CometConf.COMET_CASE_CONVERSION_ENABLED.key}=true " +
-                "to enable it anyway.")
-            None
-          }
-
-        case BitwiseAnd(left, right) =>
-          createBinaryExpr(
-            left,
-            right,
-            inputs,
-            (builder, binaryExpr) => builder.setBitwiseAnd(binaryExpr))
-
-        case BitwiseNot(child) =>
-          createUnaryExpr(child, inputs, (builder, unaryExpr) => builder.setBitwiseNot(unaryExpr))
-
-        case BitwiseOr(left, right) =>
-          createBinaryExpr(
-            left,
-            right,
-            inputs,
-            (builder, binaryExpr) => builder.setBitwiseOr(binaryExpr))
-
-        case BitwiseXor(left, right) =>
-          createBinaryExpr(
-            left,
-            right,
-            inputs,
-            (builder, binaryExpr) => builder.setBitwiseXor(binaryExpr))
-
-        case ShiftRight(left, right) =>
-          // DataFusion bitwise shift right expression requires
-          // same data type between left and right side
-          val rightExpression = if (left.dataType == LongType) {
-            Cast(right, LongType)
-          } else {
-            right
-          }
-
-          createBinaryExpr(
-            left,
-            rightExpression,
-            inputs,
-            (builder, binaryExpr) => builder.setBitwiseShiftRight(binaryExpr))
-
-        case ShiftLeft(left, right) =>
-          // DataFusion bitwise shift left expression requires
-          // same data type between left and right side
-          val rightExpression = if (left.dataType == LongType) {
-            Cast(right, LongType)
-          } else {
-            right
-          }
-
-          createBinaryExpr(
-            left,
-            rightExpression,
-            inputs,
-            (builder, binaryExpr) => builder.setBitwiseShiftLeft(binaryExpr))
-        case In(value, list) =>
-          in(expr, value, list, inputs, false)
-
-        case InSet(value, hset) =>
-          val valueDataType = value.dataType
-          val list = hset.map { setVal =>
-            Literal(setVal, valueDataType)
-          }.toSeq
-          // Change `InSet` to `In` expression
-          // We do the Spark `InSet` optimization on the native (DataFusion) side.
-          in(expr, value, list, inputs, false)
-
-        case Not(In(value, list)) =>
-          in(expr, value, list, inputs, true)
-
-        case Not(child) =>
-          createUnaryExpr(child, inputs, (builder, unaryExpr) => builder.setNot(unaryExpr))
-
-        case UnaryMinus(child, failOnError) =>
-          val childExpr = exprToProtoInternal(child, inputs)
-          if (childExpr.isDefined) {
-            val builder = ExprOuterClass.UnaryMinus.newBuilder()
-            builder.setChild(childExpr.get)
-            builder.setFailOnError(failOnError)
-            Some(
-              ExprOuterClass.Expr
-                .newBuilder()
-                .setUnaryMinus(builder)
-                .build())
-          } else {
-            withInfo(expr, child)
-            None
-          }
-
-        case a @ Coalesce(_) =>
-          val exprChildren = a.children.map(exprToProtoInternal(_, inputs))
-          scalarExprToProto("coalesce", exprChildren: _*)
-
-        // With Spark 3.4, CharVarcharCodegenUtils.readSidePadding gets called to pad spaces for
-        // char types.
-        // See https://github.com/apache/spark/pull/38151
-        case s: StaticInvoke
-            if s.staticObject.isInstanceOf[Class[CharVarcharCodegenUtils]] &&
-              s.dataType.isInstanceOf[StringType] &&
-              s.functionName == "readSidePadding" &&
-              s.arguments.size == 2 &&
-              s.propagateNull &&
-              !s.returnNullable &&
-              s.isDeterministic =>
-          val argsExpr = Seq(
-            exprToProtoInternal(Cast(s.arguments(0), StringType), inputs),
-            exprToProtoInternal(s.arguments(1), inputs))
-
-          if (argsExpr.forall(_.isDefined)) {
-            val builder = ExprOuterClass.ScalarFunc.newBuilder()
-            builder.setFunc("read_side_padding")
-            argsExpr.foreach(arg => builder.addArgs(arg.get))
-
-            Some(ExprOuterClass.Expr.newBuilder().setScalarFunc(builder).build())
-          } else {
-            withInfo(expr, s.arguments: _*)
-            None
-          }
-
-        case KnownFloatingPointNormalized(NormalizeNaNAndZero(expr)) =>
-          val dataType = serializeDataType(expr.dataType)
-          if (dataType.isEmpty) {
-            withInfo(expr, s"Unsupported datatype ${expr.dataType}")
-            return None
-          }
-          val ex = exprToProtoInternal(expr, inputs)
-          ex.map { child =>
-            val builder = ExprOuterClass.NormalizeNaNAndZero
+      case BitwiseAnd(left, right) =>
+        createBinaryExpr(
+          expr,
+          left,
+          right,
+          input,
+          binding,
+          (builder, binaryExpr) => builder.setBitwiseAnd(binaryExpr))
+
+      case BitwiseNot(child) =>
+        createUnaryExpr(
+          expr,
+          child,
+          input,
+          binding,
+          (builder, unaryExpr) => builder.setBitwiseNot(unaryExpr))
+
+      case BitwiseOr(left, right) =>
+        createBinaryExpr(
+          expr,
+          left,
+          right,
+          input,
+          binding,
+          (builder, binaryExpr) => builder.setBitwiseOr(binaryExpr))
+
+      case BitwiseXor(left, right) =>
+        createBinaryExpr(
+          expr,
+          left,
+          right,
+          input,
+          binding,
+          (builder, binaryExpr) => builder.setBitwiseXor(binaryExpr))
+
+      case ShiftRight(left, right) =>
+        // DataFusion bitwise shift right expression requires
+        // same data type between left and right side
+        val rightExpression = if (left.dataType == LongType) {
+          Cast(right, LongType)
+        } else {
+          right
+        }
+
+        createBinaryExpr(
+          expr,
+          left,
+          rightExpression,
+          input,
+          binding,
+          (builder, binaryExpr) => builder.setBitwiseShiftRight(binaryExpr))
+
+      case ShiftLeft(left, right) =>
+        // DataFusion bitwise shift left expression requires
+        // same data type between left and right side
+        val rightExpression = if (left.dataType == LongType) {
+          Cast(right, LongType)
+        } else {
+          right
+        }
+
+        createBinaryExpr(
+          expr,
+          left,
+          rightExpression,
+          input,
+          binding,
+          (builder, binaryExpr) => builder.setBitwiseShiftLeft(binaryExpr))
+      case In(value, list) =>
+        in(expr, value, list, input, binding, negate = false)
+
+      case InSet(value, hset) =>
+        val valueDataType = value.dataType
+        val list = hset.map { setVal =>
+          Literal(setVal, valueDataType)
+        }.toSeq
+        // Change `InSet` to `In` expression
+        // We do the Spark `InSet` optimization on the native (DataFusion) side.
+        in(expr, value, list, input, binding, negate = false)
+
+      case Not(In(value, list)) =>
+        in(expr, value, list, input, binding, negate = true)
+
+      case Not(child) =>
+        createUnaryExpr(
+          expr,
+          child,
+          input,
+          binding,
+          (builder, unaryExpr) => builder.setNot(unaryExpr))
+
+      case UnaryMinus(child, failOnError) =>
+        val childExpr = exprToProtoInternal(child, input, binding)
+        if (childExpr.isDefined) {
+          val builder = ExprOuterClass.UnaryMinus.newBuilder()
+          builder.setChild(childExpr.get)
+          builder.setFailOnError(failOnError)
+          Some(
+            ExprOuterClass.Expr
               .newBuilder()
-              .setChild(child)
-              .setDatatype(dataType.get)
-            ExprOuterClass.Expr.newBuilder().setNormalizeNanAndZero(builder).build()
-          }
+              .setUnaryMinus(builder)
+              .build())
+        } else {
+          withInfo(expr, child)
+          None
+        }
 
-        case s @ execution.ScalarSubquery(_, _) if supportedDataType(s.dataType) =>
-          val dataType = serializeDataType(s.dataType)
-          if (dataType.isEmpty) {
-            withInfo(s, s"Scalar subquery returns unsupported datatype 
${s.dataType}")
-            return None
-          }
+      case a @ Coalesce(_) =>
+        val exprChildren = a.children.map(exprToProtoInternal(_, input, binding))
+        scalarExprToProto("coalesce", exprChildren: _*)
+
+      // With Spark 3.4, CharVarcharCodegenUtils.readSidePadding gets called to pad spaces for
+      // char types.
+      // See https://github.com/apache/spark/pull/38151
+      case s: StaticInvoke
+          if s.staticObject.isInstanceOf[Class[CharVarcharCodegenUtils]] &&
+            s.dataType.isInstanceOf[StringType] &&
+            s.functionName == "readSidePadding" &&
+            s.arguments.size == 2 &&
+            s.propagateNull &&
+            !s.returnNullable &&
+            s.isDeterministic =>
+        val argsExpr = Seq(
+          exprToProtoInternal(Cast(s.arguments(0), StringType), input, binding),
+          exprToProtoInternal(s.arguments(1), input, binding))
+
+        if (argsExpr.forall(_.isDefined)) {
+          val builder = ExprOuterClass.ScalarFunc.newBuilder()
+          builder.setFunc("read_side_padding")
+          argsExpr.foreach(arg => builder.addArgs(arg.get))
+
+          Some(ExprOuterClass.Expr.newBuilder().setScalarFunc(builder).build())
+        } else {
+          withInfo(expr, s.arguments: _*)
+          None
+        }
 
-          val builder = ExprOuterClass.Subquery
+      case KnownFloatingPointNormalized(NormalizeNaNAndZero(expr)) =>
+        val dataType = serializeDataType(expr.dataType)
+        if (dataType.isEmpty) {
+          withInfo(expr, s"Unsupported datatype ${expr.dataType}")
+          return None
+        }
+        val ex = exprToProtoInternal(expr, input, binding)
+        ex.map { child =>
+          val builder = ExprOuterClass.NormalizeNaNAndZero
             .newBuilder()
-            .setId(s.exprId.id)
+            .setChild(child)
             .setDatatype(dataType.get)
-          Some(ExprOuterClass.Expr.newBuilder().setSubquery(builder).build())
-
-        case UnscaledValue(child) =>
-          val childExpr = exprToProtoInternal(child, inputs)
-          val optExpr = scalarExprToProtoWithReturnType("unscaled_value", 
LongType, childExpr)
-          optExprWithInfo(optExpr, expr, child)
-
-        case MakeDecimal(child, precision, scale, true) =>
-          val childExpr = exprToProtoInternal(child, inputs)
-          val optExpr = scalarExprToProtoWithReturnType(
-            "make_decimal",
-            DecimalType(precision, scale),
-            childExpr)
-          optExprWithInfo(optExpr, expr, child)
-
-        case b @ BloomFilterMightContain(_, _) =>
-          val bloomFilter = b.left
-          val value = b.right
-          val bloomFilterExpr = exprToProtoInternal(bloomFilter, inputs)
-          val valueExpr = exprToProtoInternal(value, inputs)
-          if (bloomFilterExpr.isDefined && valueExpr.isDefined) {
-            val builder = ExprOuterClass.BloomFilterMightContain.newBuilder()
-            builder.setBloomFilter(bloomFilterExpr.get)
-            builder.setValue(valueExpr.get)
-            Some(
-              ExprOuterClass.Expr
-                .newBuilder()
-                .setBloomFilterMightContain(builder)
-                .build())
-          } else {
-            withInfo(expr, bloomFilter, value)
-            None
-          }
+          ExprOuterClass.Expr.newBuilder().setNormalizeNanAndZero(builder).build()
+        }
 
-        case Murmur3Hash(children, seed) =>
-          val firstUnSupportedInput = children.find(c => !supportedDataType(c.dataType))
-          if (firstUnSupportedInput.isDefined) {
-            withInfo(expr, s"Unsupported datatype 
${firstUnSupportedInput.get.dataType}")
-            return None
-          }
-          val exprs = children.map(exprToProtoInternal(_, inputs))
-          val seedBuilder = ExprOuterClass.Literal
-            .newBuilder()
-            .setDatatype(serializeDataType(IntegerType).get)
-            .setIntVal(seed)
-          val seedExpr = Some(ExprOuterClass.Expr.newBuilder().setLiteral(seedBuilder).build())
-          // the seed is put at the end of the arguments
-          scalarExprToProtoWithReturnType("murmur3_hash", IntegerType, exprs 
:+ seedExpr: _*)
-
-        case XxHash64(children, seed) =>
-          val firstUnSupportedInput = children.find(c => !supportedDataType(c.dataType))
-          if (firstUnSupportedInput.isDefined) {
-            withInfo(expr, s"Unsupported datatype 
${firstUnSupportedInput.get.dataType}")
-            return None
-          }
-          val exprs = children.map(exprToProtoInternal(_, inputs))
-          val seedBuilder = ExprOuterClass.Literal
-            .newBuilder()
-            .setDatatype(serializeDataType(LongType).get)
-            .setLongVal(seed)
-          val seedExpr = Some(ExprOuterClass.Expr.newBuilder().setLiteral(seedBuilder).build())
-          // the seed is put at the end of the arguments
-          scalarExprToProtoWithReturnType("xxhash64", LongType, exprs :+ 
seedExpr: _*)
-
-        case Sha2(left, numBits) =>
-          if (!numBits.foldable) {
-            withInfo(expr, "non literal numBits is not supported")
-            return None
-          }
-          // it's possible for Spark to dynamically compute the number of bits from the input
-          // expression; however, DataFusion does not support that yet.
-          val childExpr = exprToProtoInternal(left, inputs)
-          val bits = numBits.eval().asInstanceOf[Int]
-          val algorithm = bits match {
-            case 224 => "sha224"
-            case 256 | 0 => "sha256"
-            case 384 => "sha384"
-            case 512 => "sha512"
-            case _ =>
-              null
-          }
-          if (algorithm == null) {
-            exprToProtoInternal(Literal(null, StringType), inputs)
-          } else {
-            scalarExprToProtoWithReturnType(algorithm, StringType, childExpr)
-          }
+      case s @ execution.ScalarSubquery(_, _) if supportedDataType(s.dataType) =>
+        val dataType = serializeDataType(s.dataType)
+        if (dataType.isEmpty) {
+          withInfo(s, s"Scalar subquery returns unsupported datatype 
${s.dataType}")
+          return None
+        }
 
-        case struct @ CreateNamedStruct(_) =>
-          if (struct.names.length != struct.names.distinct.length) {
-            withInfo(expr, "CreateNamedStruct with duplicate field names are 
not supported")
-            return None
-          }
+        val builder = ExprOuterClass.Subquery
+          .newBuilder()
+          .setId(s.exprId.id)
+          .setDatatype(dataType.get)
+        Some(ExprOuterClass.Expr.newBuilder().setSubquery(builder).build())
+
+      case UnscaledValue(child) =>
+        val childExpr = exprToProtoInternal(child, input, binding)
+        val optExpr = scalarExprToProtoWithReturnType("unscaled_value", 
LongType, childExpr)
+        optExprWithInfo(optExpr, expr, child)
+
+      case MakeDecimal(child, precision, scale, true) =>
+        val childExpr = exprToProtoInternal(child, input, binding)
+        val optExpr = scalarExprToProtoWithReturnType(
+          "make_decimal",
+          DecimalType(precision, scale),
+          childExpr)
+        optExprWithInfo(optExpr, expr, child)
+
+      case b @ BloomFilterMightContain(_, _) =>
+        val bloomFilter = b.left
+        val value = b.right
+        val bloomFilterExpr = exprToProtoInternal(bloomFilter, input, binding)
+        val valueExpr = exprToProtoInternal(value, input, binding)
+        if (bloomFilterExpr.isDefined && valueExpr.isDefined) {
+          val builder = ExprOuterClass.BloomFilterMightContain.newBuilder()
+          builder.setBloomFilter(bloomFilterExpr.get)
+          builder.setValue(valueExpr.get)
+          Some(
+            ExprOuterClass.Expr
+              .newBuilder()
+              .setBloomFilterMightContain(builder)
+              .build())
+        } else {
+          withInfo(expr, bloomFilter, value)
+          None
+        }
 
-          val valExprs = struct.valExprs.map(exprToProto(_, inputs, binding))
+      case Murmur3Hash(children, seed) =>
+        val firstUnSupportedInput = children.find(c => !supportedDataType(c.dataType))
+        if (firstUnSupportedInput.isDefined) {
+          withInfo(expr, s"Unsupported datatype 
${firstUnSupportedInput.get.dataType}")
+          return None
+        }
+        val exprs = children.map(exprToProtoInternal(_, input, binding))
+        val seedBuilder = ExprOuterClass.Literal
+          .newBuilder()
+          .setDatatype(serializeDataType(IntegerType).get)
+          .setIntVal(seed)
+        val seedExpr = Some(ExprOuterClass.Expr.newBuilder().setLiteral(seedBuilder).build())
+        // the seed is put at the end of the arguments
+        scalarExprToProtoWithReturnType("murmur3_hash", IntegerType, exprs :+ 
seedExpr: _*)
+
+      case XxHash64(children, seed) =>
+        val firstUnSupportedInput = children.find(c => !supportedDataType(c.dataType))
+        if (firstUnSupportedInput.isDefined) {
+          withInfo(expr, s"Unsupported datatype 
${firstUnSupportedInput.get.dataType}")
+          return None
+        }
+        val exprs = children.map(exprToProtoInternal(_, input, binding))
+        val seedBuilder = ExprOuterClass.Literal
+          .newBuilder()
+          .setDatatype(serializeDataType(LongType).get)
+          .setLongVal(seed)
+        val seedExpr = Some(ExprOuterClass.Expr.newBuilder().setLiteral(seedBuilder).build())
+        // the seed is put at the end of the arguments
+        scalarExprToProtoWithReturnType("xxhash64", LongType, exprs :+ 
seedExpr: _*)
+
+      case Sha2(left, numBits) =>
+        if (!numBits.foldable) {
+          withInfo(expr, "non literal numBits is not supported")
+          return None
+        }
+        // it's possible for Spark to dynamically compute the number of bits from the input
+        // expression; however, DataFusion does not support that yet.
+        val childExpr = exprToProtoInternal(left, input, binding)
+        val bits = numBits.eval().asInstanceOf[Int]
+        val algorithm = bits match {
+          case 224 => "sha224"
+          case 256 | 0 => "sha256"
+          case 384 => "sha384"
+          case 512 => "sha512"
+          case _ =>
+            null
+        }
+        if (algorithm == null) {
+          exprToProtoInternal(Literal(null, StringType), input, binding)
+        } else {
+          scalarExprToProtoWithReturnType(algorithm, StringType, childExpr)
+        }
 
-          if (valExprs.forall(_.isDefined)) {
-            val structBuilder = ExprOuterClass.CreateNamedStruct.newBuilder()
-            structBuilder.addAllValues(valExprs.map(_.get).asJava)
-            structBuilder.addAllNames(struct.names.map(_.toString).asJava)
+      case struct @ CreateNamedStruct(_) =>
+        if (struct.names.length != struct.names.distinct.length) {
+          withInfo(expr, "CreateNamedStruct with duplicate field names are not 
supported")
+          return None
+        }
 
-            Some(
-              ExprOuterClass.Expr
-                .newBuilder()
-                .setCreateNamedStruct(structBuilder)
-                .build())
-          } else {
-            withInfo(expr, "unsupported arguments for CreateNamedStruct", 
struct.valExprs: _*)
-            None
-          }
+        val valExprs = struct.valExprs.map(exprToProto(_, input, binding))
 
-        case GetStructField(child, ordinal, _) =>
-          exprToProto(child, inputs, binding).map { childExpr =>
-            val getStructFieldBuilder = ExprOuterClass.GetStructField
-              .newBuilder()
-              .setChild(childExpr)
-              .setOrdinal(ordinal)
+        if (valExprs.forall(_.isDefined)) {
+          val structBuilder = ExprOuterClass.CreateNamedStruct.newBuilder()
+          structBuilder.addAllValues(valExprs.map(_.get).asJava)
+          structBuilder.addAllNames(struct.names.map(_.toString).asJava)
 
+          Some(
             ExprOuterClass.Expr
               .newBuilder()
-              .setGetStructField(getStructFieldBuilder)
-              .build()
-          }
+              .setCreateNamedStruct(structBuilder)
+              .build())
+        } else {
+          withInfo(expr, "unsupported arguments for CreateNamedStruct", 
struct.valExprs: _*)
+          None
+        }
 
-        case CreateArray(children, _) =>
-          val childExprs = children.map(exprToProto(_, inputs, binding))
+      case GetStructField(child, ordinal, _) =>
+        exprToProto(child, input, binding).map { childExpr =>
+          val getStructFieldBuilder = ExprOuterClass.GetStructField
+            .newBuilder()
+            .setChild(childExpr)
+            .setOrdinal(ordinal)
 
-          if (childExprs.forall(_.isDefined)) {
-            scalarExprToProto("make_array", childExprs: _*)
-          } else {
-            withInfo(expr, "unsupported arguments for CreateArray", children: 
_*)
-            None
-          }
+          ExprOuterClass.Expr
+            .newBuilder()
+            .setGetStructField(getStructFieldBuilder)
+            .build()
+        }
 
-        case GetArrayItem(child, ordinal, failOnError) =>
-          val childExpr = exprToProto(child, inputs, binding)
-          val ordinalExpr = exprToProto(ordinal, inputs, binding)
+      case CreateArray(children, _) =>
+        val childExprs = children.map(exprToProto(_, input, binding))
 
-          if (childExpr.isDefined && ordinalExpr.isDefined) {
-            val listExtractBuilder = ExprOuterClass.ListExtract
-              .newBuilder()
-              .setChild(childExpr.get)
-              .setOrdinal(ordinalExpr.get)
-              .setOneBased(false)
-              .setFailOnError(failOnError)
+        if (childExprs.forall(_.isDefined)) {
+          scalarExprToProto("make_array", childExprs: _*)
+        } else {
+          withInfo(expr, "unsupported arguments for CreateArray", children: _*)
+          None
+        }
 
-            Some(
-              ExprOuterClass.Expr
-                .newBuilder()
-                .setListExtract(listExtractBuilder)
-                .build())
-          } else {
-            withInfo(expr, "unsupported arguments for GetArrayItem", child, 
ordinal)
-            None
-          }
+      case GetArrayItem(child, ordinal, failOnError) =>
+        val childExpr = exprToProto(child, input, binding)
+        val ordinalExpr = exprToProto(ordinal, input, binding)
 
-        case expr if expr.prettyName == "array_insert" =>
-          val srcExprProto = exprToProto(expr.children(0), inputs, binding)
-          val posExprProto = exprToProto(expr.children(1), inputs, binding)
-          val itemExprProto = exprToProto(expr.children(2), inputs, binding)
-          val legacyNegativeIndex =
-            SQLConf.get.getConfString("spark.sql.legacy.negativeIndexInArrayInsert").toBoolean
-          if (srcExprProto.isDefined && posExprProto.isDefined && itemExprProto.isDefined) {
-            val arrayInsertBuilder = ExprOuterClass.ArrayInsert
-              .newBuilder()
-              .setSrcArrayExpr(srcExprProto.get)
-              .setPosExpr(posExprProto.get)
-              .setItemExpr(itemExprProto.get)
-              .setLegacyNegativeIndex(legacyNegativeIndex)
+        if (childExpr.isDefined && ordinalExpr.isDefined) {
+          val listExtractBuilder = ExprOuterClass.ListExtract
+            .newBuilder()
+            .setChild(childExpr.get)
+            .setOrdinal(ordinalExpr.get)
+            .setOneBased(false)
+            .setFailOnError(failOnError)
 
-            Some(
-              ExprOuterClass.Expr
-                .newBuilder()
-                .setArrayInsert(arrayInsertBuilder)
-                .build())
-          } else {
-            withInfo(
-              expr,
-              "unsupported arguments for ArrayInsert",
-              expr.children(0),
-              expr.children(1),
-              expr.children(2))
-            None
-          }
+          Some(
+            ExprOuterClass.Expr
+              .newBuilder()
+              .setListExtract(listExtractBuilder)
+              .build())
+        } else {
+          withInfo(expr, "unsupported arguments for GetArrayItem", child, 
ordinal)
+          None
+        }
 
-        case ElementAt(child, ordinal, defaultValue, failOnError)
-            if child.dataType.isInstanceOf[ArrayType] =>
-          val childExpr = exprToProto(child, inputs, binding)
-          val ordinalExpr = exprToProto(ordinal, inputs, binding)
-          val defaultExpr = defaultValue.flatMap(exprToProto(_, inputs, binding))
+      case expr if expr.prettyName == "array_insert" =>
+        val srcExprProto = exprToProto(expr.children(0), input, binding)
+        val posExprProto = exprToProto(expr.children(1), input, binding)
+        val itemExprProto = exprToProto(expr.children(2), input, binding)
+        val legacyNegativeIndex =
+          SQLConf.get.getConfString("spark.sql.legacy.negativeIndexInArrayInsert").toBoolean
+        if (srcExprProto.isDefined && posExprProto.isDefined && itemExprProto.isDefined) {
+          val arrayInsertBuilder = ExprOuterClass.ArrayInsert
+            .newBuilder()
+            .setSrcArrayExpr(srcExprProto.get)
+            .setPosExpr(posExprProto.get)
+            .setItemExpr(itemExprProto.get)
+            .setLegacyNegativeIndex(legacyNegativeIndex)
 
-          if (childExpr.isDefined && ordinalExpr.isDefined &&
-            defaultExpr.isDefined == defaultValue.isDefined) {
-            val arrayExtractBuilder = ExprOuterClass.ListExtract
+          Some(
+            ExprOuterClass.Expr
               .newBuilder()
-              .setChild(childExpr.get)
-              .setOrdinal(ordinalExpr.get)
-              .setOneBased(true)
-              .setFailOnError(failOnError)
+              .setArrayInsert(arrayInsertBuilder)
+              .build())
+        } else {
+          withInfo(
+            expr,
+            "unsupported arguments for ArrayInsert",
+            expr.children(0),
+            expr.children(1),
+            expr.children(2))
+          None
+        }
 
-            defaultExpr.foreach(arrayExtractBuilder.setDefaultValue(_))
+      case ElementAt(child, ordinal, defaultValue, failOnError)
+          if child.dataType.isInstanceOf[ArrayType] =>
+        val childExpr = exprToProto(child, input, binding)
+        val ordinalExpr = exprToProto(ordinal, input, binding)
+        val defaultExpr = defaultValue.flatMap(exprToProto(_, input, binding))
 
-            Some(
-              ExprOuterClass.Expr
-                .newBuilder()
-                .setListExtract(arrayExtractBuilder)
-                .build())
-          } else {
-            withInfo(expr, "unsupported arguments for ElementAt", child, 
ordinal)
-            None
-          }
+        if (childExpr.isDefined && ordinalExpr.isDefined &&
+          defaultExpr.isDefined == defaultValue.isDefined) {
+          val arrayExtractBuilder = ExprOuterClass.ListExtract
+            .newBuilder()
+            .setChild(childExpr.get)
+            .setOrdinal(ordinalExpr.get)
+            .setOneBased(true)
+            .setFailOnError(failOnError)
 
-        case GetArrayStructFields(child, _, ordinal, _, _) =>
-          val childExpr = exprToProto(child, inputs, binding)
+          defaultExpr.foreach(arrayExtractBuilder.setDefaultValue(_))
 
-          if (childExpr.isDefined) {
-            val arrayStructFieldsBuilder = ExprOuterClass.GetArrayStructFields
+          Some(
+            ExprOuterClass.Expr
               .newBuilder()
-              .setChild(childExpr.get)
-              .setOrdinal(ordinal)
-
-            Some(
-              ExprOuterClass.Expr
-                .newBuilder()
-                .setGetArrayStructFields(arrayStructFieldsBuilder)
-                .build())
-          } else {
-            withInfo(expr, "unsupported arguments for GetArrayStructFields", 
child)
-            None
-          }
-        case expr: ArrayRemove =>
-          if (CometArrayRemove.checkSupport(expr)) {
-            createBinaryExpr(
-              expr.children(0),
-              expr.children(1),
-              inputs,
-              (builder, binaryExpr) => builder.setArrayRemove(binaryExpr))
-          } else {
-            None
-          }
-        case expr if expr.prettyName == "array_contains" =>
-          createBinaryExpr(
-            expr.children(0),
-            expr.children(1),
-            inputs,
-            (builder, binaryExpr) => builder.setArrayContains(binaryExpr))
-        case _ if expr.prettyName == "array_append" =>
-          createBinaryExpr(
-            expr.children(0),
-            expr.children(1),
-            inputs,
-            (builder, binaryExpr) => builder.setArrayAppend(binaryExpr))
-        case _ if expr.prettyName == "array_intersect" =>
-          createBinaryExpr(
-            expr.children(0),
-            expr.children(1),
-            inputs,
-            (builder, binaryExpr) => builder.setArrayIntersect(binaryExpr))
-        case _ =>
-          withInfo(expr, s"${expr.prettyName} is not supported", 
expr.children: _*)
+              .setListExtract(arrayExtractBuilder)
+              .build())
+        } else {
+          withInfo(expr, "unsupported arguments for ElementAt", child, ordinal)
           None
-      }
-    }
+        }
 
-    /**
-     * Creates a UnaryExpr by calling exprToProtoInternal for the provided child expression and
-     * then invokes the supplied function to wrap this UnaryExpr in a top-level Expr.
-     *
-     * @param child
-     *   Spark expression
-     * @param inputs
-     *   Inputs to the expression
-     * @param f
-     *   Function that accepts an Expr.Builder and a UnaryExpr and builds the specific top-level
-     *   Expr
-     * @return
-     *   Some(Expr) or None if not supported
-     */
-    def createUnaryExpr(
-        child: Expression,
-        inputs: Seq[Attribute],
-        f: (ExprOuterClass.Expr.Builder, ExprOuterClass.UnaryExpr) => ExprOuterClass.Expr.Builder)
-        : Option[ExprOuterClass.Expr] = {
-      val childExpr = exprToProtoInternal(child, inputs)
-      if (childExpr.isDefined) {
-        // create the generic UnaryExpr message
-        val inner = ExprOuterClass.UnaryExpr
-          .newBuilder()
-          .setChild(childExpr.get)
-          .build()
-        // call the user-supplied function to wrap UnaryExpr in a top-level Expr
-        // such as Expr.IsNull or Expr.IsNotNull
-        Some(
-          f(
+      case GetArrayStructFields(child, _, ordinal, _, _) =>
+        val childExpr = exprToProto(child, input, binding)
+
+        if (childExpr.isDefined) {
+          val arrayStructFieldsBuilder = ExprOuterClass.GetArrayStructFields
+            .newBuilder()
+            .setChild(childExpr.get)
+            .setOrdinal(ordinal)
+
+          Some(
             ExprOuterClass.Expr
-              .newBuilder(),
-            inner).build())
-      } else {
-        withInfo(expr, child)
+              .newBuilder()
+              .setGetArrayStructFields(arrayStructFieldsBuilder)
+              .build())
+        } else {
+          withInfo(expr, "unsupported arguments for GetArrayStructFields", 
child)
+          None
+        }
+      case expr: ArrayRemove => CometArrayRemove.convert(expr, input, binding)

Review Comment:
   Although the diff is big, I suppose it is mostly just code being moved, and this is the only actual change?
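
   For reference, here is a minimal sketch of what the consolidated
   CometArrayRemove.convert is assumed to do at this call site: fold the old
   inline checkSupport guard and the createBinaryExpr call into one method.
   This is illustrative only, not the PR's actual code; the BinaryExpr
   setLeft/setRight field names are assumptions, and exprToProto/checkSupport
   refer to the pre-existing helpers shown in the diff above.

       // Hypothetical sketch, not the actual PR code.
       object CometArrayRemove {
         def convert(
             expr: ArrayRemove,
             input: Seq[Attribute],
             binding: Boolean): Option[ExprOuterClass.Expr] = {
           if (!checkSupport(expr)) {
             // same type/support check that previously lived inline
             None
           } else {
             // serialize both children; None if either is unsupported
             for {
               arrayExpr <- exprToProto(expr.children(0), input, binding)
               itemExpr <- exprToProto(expr.children(1), input, binding)
             } yield ExprOuterClass.Expr
               .newBuilder()
               .setArrayRemove(
                 ExprOuterClass.BinaryExpr
                   .newBuilder()
                   .setLeft(arrayExpr) // field name assumed
                   .setRight(itemExpr) // field name assumed
                   .build())
               .build()
           }
         }
       }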


