allisonwang-db commented on code in PR #49414:
URL: https://github.com/apache/spark/pull/49414#discussion_r1913629542


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala:
##########
@@ -2363,6 +2364,278 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
     }
   }
 
+  /**
+   * This rule resolves SQL function expressions. It pulls out function inputs and places them
+   * in a separate [[Project]] node below the operator and replaces the SQL function with its
+   * actual function body. SQL function expressions in [[Aggregate]] are handled in a special
+   * way. Non-aggregated SQL functions in the aggregate expressions of an Aggregate need to be
+   * pulled out into a Project above the Aggregate before replacing the SQL function expressions
+   * with actual function bodies. For example:
+   *
+   * Before:
+   *   Aggregate [c1] [foo(c1), foo(max(c2)), sum(foo(c2)) AS sum]
+   *   +- Relation [c1, c2]
+   *
+   * After:
+   *   Project [foo(c1), foo(max_c2), sum]
+   *   +- Aggregate [c1] [c1, max(c2) AS max_c2, sum(foo(c2)) AS sum]
+   *      +- Relation [c1, c2]
+   */
+  object ResolveSQLFunctions extends Rule[LogicalPlan] {
+
+    /**
+     * Returns true if any expression in `exprs` contains a [[SQLFunctionExpression]],
+     * either as the expression itself or as one of its descendants.
+     */
+    private def hasSQLFunctionExpression(exprs: Seq[Expression]): Boolean = {
+      exprs.exists(_.find(_.isInstanceOf[SQLFunctionExpression]).nonEmpty)
+    }
+
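+    /**
+     * Check that the function inputs do not contain aggregate expressions, and throw an
+     * [[AnalysisException]] with error class NESTED_AGGREGATE_FUNCTION if they do.
+     */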
+    private def checkFunctionInput(f: SQLFunctionExpression): Unit = {
+      if (f.inputs.exists(AggregateExpression.containsAggregate)) {
+        // The input of a SQL function should not contain aggregate functions after
+        // `extractAndRewrite`. If there are aggregate functions, it means they are
+        // nested in another aggregate function, which is not allowed.
+        // For example: SELECT sum(foo(sum(c1))) FROM t
+        // We have to throw the error here because otherwise the query plan after
+        // resolving the SQL function will not be valid.
+        throw new AnalysisException(
+          errorClass = "NESTED_AGGREGATE_FUNCTION",
+          messageParameters = Map.empty)
+      }
+    }
+
+    /**
+     * Resolve a SQL function expression as a logical plan and check that it can be analyzed.
+     */
+    private def resolve(f: SQLFunctionExpression): LogicalPlan = {
+      // Validate the SQL function input.
+      checkFunctionInput(f)
+      // Build a logical plan for the function from its definition and the call-site inputs.
+      val plan = v1SessionCatalog.makeSQLFunctionPlan(f.name, f.function, f.inputs)
+      val resolved = SQLFunctionContext.withSQLFunction {
+        // Resolve the SQL function plan using its context.
+        // A fresh SQLConf holding only the configs from `f.function.getSQLConfigs` is used
+        // while the plan is analyzed, instead of the caller's conf.
+        val conf = new SQLConf()
+        f.function.getSQLConfigs.foreach { case (k, v) => conf.settings.put(k, v) }
+        SQLConf.withExistingConf(conf) {
+          executeSameContext(plan)
+        }
+      }
+      // Fail the analysis eagerly if a SQL function cannot be resolved using its input.
+      SimpleAnalyzer.checkAnalysis(resolved)
+      resolved
+    }
+
+    /**
+     * Rewrite SQL function expressions into actual resolved function bodies and extract
+     * function inputs into the given project list.
+     */
+    private def rewriteSQLFunctions[E <: Expression](
+        expression: E,
+        projectList: ArrayBuffer[NamedExpression]): E = {
+      val newExpr = expression match {
+        // Only rewrite a SQL function whose inputs contain no other SQL function. Otherwise
+        // it falls through to the last case, which rewrites the nested calls in its children
+        // and keeps the outer call (NOTE(review): presumably rewritten on a later iteration
+        // of this rule - confirm).
+        case f: SQLFunctionExpression if !hasSQLFunctionExpression(f.inputs) &&
+          // Make sure LateralColumnAliasReference in parameters is resolved and eliminated first.
+          // Otherwise, the projectList can contain the LateralColumnAliasReference, which will be
+          // pushed down to a Project without the 'referenced' alias by LCA present, leaving it
+          // unresolved.
+          !f.inputs.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) =>
+          withPosition(f) {
+            val plan = resolve(f)
+            // Extract the function input project list from the SQL function plan and
+            // inline the SQL function expression.
+            // The resolved plan must have the shape
+            //   Project [body]
+            //   +- Project [input aliases]
+            //      +- OneRowRelation
+            // Any other shape is rejected with INVALID_SQL_FUNCTION_PLAN_STRUCTURE.
+            plan match {
+              case Project(body :: Nil, Project(aliases, _: OneRowRelation)) =>
+                // The input aliases (with OuterReference wrappers stripped) go into the
+                // caller's project list; the SQLScalarFunction receives their attributes.
+                val inputs = aliases.map(stripOuterReference)
+                projectList ++= inputs
+                SQLScalarFunction(f.function, inputs.map(_.toAttribute), body)
+              case o =>
+                throw new AnalysisException(
+                  errorClass = "INVALID_SQL_FUNCTION_PLAN_STRUCTURE",
+                  messageParameters = Map("plan" -> o.toString))
+            }
+          }
+        case o => o.mapChildren(rewriteSQLFunctions(_, projectList))
+      }
+      // Unchecked cast: the type parameter E is erased at runtime.
+      newExpr.asInstanceOf[E]
+    }
+
+    /**
+     * Check if the given expression contains expressions that should be extracted,
+     * i.e. SQL functions that are not nested inside an aggregate expression.
+     */
+    private def shouldExtract(e: Expression): Boolean = e match {
+      // Return false if the expression is already an aggregate expression.
+      case _: AggregateExpression => false
+      // A SQL function that is not wrapped in an aggregate expression.
+      case _: SQLFunctionExpression => true
+      // Leaf expressions have no children, so they cannot contain a SQL function.
+      case _: LeafExpression => false
+      // Otherwise, check whether any child contains a non-aggregated SQL function.
+      case o => o.children.exists(shouldExtract)
+    }
+
+    /**
+     * Extract the maximal sub-expressions that do not contain non-aggregated SQL functions
+     * (e.g. aggregate expressions and plain column references) into `extractedExprs` and
+     * replace them with attribute references.
+     * Example:
+     *   Before: foo(c1) + foo(max(c2)) + max(foo(c2))
+     *   After: foo(c1) + foo(max_c2) + max_foo_c2
+     *   Extracted expressions: [c1, max(c2) AS max_c2, max(foo(c2)) AS max_foo_c2]
+     */
+    private def extractAndRewrite[T <: Expression](
+        expression: T,
+        extractedExprs: ArrayBuffer[NamedExpression]): T = {
+      val newExpr = expression match {
+        // No non-aggregated SQL function in this sub-tree: extract it as a whole and refer
+        // to it through its attribute.
+        case e if !shouldExtract(e) =>
+          val exprToAdd: NamedExpression = e match {
+            // Name an outer reference after the expression it wraps.
+            case o: OuterReference => Alias(o, toPrettySQL(o.e))()
+            case ne: NamedExpression => ne
+            case o => Alias(o, toPrettySQL(o))()
+          }
+          extractedExprs += exprToAdd
+          exprToAdd.toAttribute
+        // Keep the SQL function in place and rewrite its inputs.
+        case f: SQLFunctionExpression =>
+          val newInputs = f.inputs.map(extractAndRewrite(_, extractedExprs))
+          f.copy(inputs = newInputs)
+        // Otherwise, recurse into the children.
+        case o => o.mapChildren(extractAndRewrite(_, extractedExprs))
+      }
+      // Unchecked cast: the type parameter T is erased at runtime.
+      newExpr.asInstanceOf[T]
+    }
+
+    /**
+     * Replace each [[SQLFunctionExpression]] whose canonicalized form is a key of `aliasMap`
+     * with the attribute of the corresponding alias.
+     */
+    private def replaceSQLFunctionWithAttr[T <: Expression](
+        expr: T,
+        aliasMap: mutable.HashMap[Expression, Alias]): T = {
+      expr.transform {
+        // aliasMap is keyed by canonicalized expressions, so calls with the same canonical
+        // form map to the same attribute. Calls not present in the map are left unchanged.
+        case f: SQLFunctionExpression if aliasMap.contains(f.canonicalized) =>
+          aliasMap(f.canonicalized).toAttribute
+      }.asInstanceOf[T]
+    }
+
+    private def rewrite(plan: LogicalPlan): LogicalPlan = plan match {
+      // Return if a sub-tree does not contain SQLFunctionExpression.
+      case p: LogicalPlan if !p.containsPattern(SQL_FUNCTION_EXPRESSION) => p
+
+      case f @ Filter(cond, a: Aggregate)
+        if !f.resolved || AggregateExpression.containsAggregate(cond) ||
+          ResolveGroupingAnalytics.hasGroupingFunction(cond) ||
+          cond.containsPattern(TEMP_RESOLVED_COLUMN) =>
+        // If the filter's condition contains aggregate expressions or 
grouping expressions or temp
+        // resolved column, we cannot rewrite both the filter and the 
aggregate until they are
+        // resolved by ResolveAggregateFunctions or ResolveGroupingAnalytics, 
because rewriting SQL
+        // functions in aggregate can add an additional project on top of the 
aggregate
+        // which breaks the pattern matching in those rules.
+        f.copy(child = a.copy(child = rewrite(a.child)))
+
+      case h @ UnresolvedHaving(_, a: Aggregate) =>
+        // Similarly UnresolvedHaving should be resolved by 
ResolveAggregateFunctions first
+        // before rewriting aggregate.
+        h.copy(child = a.copy(child = rewrite(a.child)))
+
+      case a: Aggregate if a.resolved && 
hasSQLFunctionExpression(a.expressions) =>
+        val child = rewrite(a.child)
+        // Extract SQL functions in the grouping expressions and place them in 
a project list
+        // below the current aggregate. Also update their appearances in the 
aggregate expressions.
+        val bottomProjectList = ArrayBuffer.empty[NamedExpression]
+        val aliasMap = mutable.HashMap.empty[Expression, Alias]
+        val newGrouping = a.groupingExpressions.map { expr =>
+          expr.transformDown {
+            case f: SQLFunctionExpression =>
+              val alias = aliasMap.getOrElseUpdate(f.canonicalized, Alias(f, 
f.name)())
+              bottomProjectList += alias
+              alias.toAttribute
+          }
+        }
+        val aggregateExpressions = a.aggregateExpressions.map(
+          replaceSQLFunctionWithAttr(_, aliasMap))
+
+        // Rewrite SQL functions in the aggregate expressions that are not 
wrapped in
+        // aggregate functions. They need to be extracted into a project list 
above the
+        // current aggregate.
+        val aggExprs = ArrayBuffer.empty[NamedExpression]
+        val topProjectList = aggregateExpressions.map(extractAndRewrite(_, 
aggExprs))

Review Comment:
   Sounds good. We can explore this once we have more test coverage.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to