cloud-fan commented on code in PR #49414: URL: https://github.com/apache/spark/pull/49414#discussion_r1907010675
########## sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala: ########## @@ -2363,6 +2364,285 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor } } + /** + * This rule resolves SQL function expressions. It pulls out function inputs and place them + * in a separate [[Project]] node below the operator and replace the SQL function with its + * actual function body. SQL function expressions in [[Aggregate]] are handled in a special + * way. Non-aggregated SQL functions in the aggregate expressions of an Aggregate need to be + * pulled out into a Project above the Aggregate before replacing the SQL function expressions + * with actual function bodies. For example: + * + * Before: + * Aggregate [c1] [foo(c1), foo(max(c2)), sum(foo(c2)) AS sum] + * +- Relation [c1, c2] + * + * After: + * Project [foo(c1), foo(max_c2), sum] + * +- Aggregate [c1] [c1, max(c2) AS max_c2, sum(foo(c2)) AS sum] + * +- Relation [c1, c2] + */ + object ResolveSQLFunctions extends Rule[LogicalPlan] { + + private def hasSQLFunctionExpression(exprs: Seq[Expression]): Boolean = { + exprs.exists(_.find(_.isInstanceOf[SQLFunctionExpression]).nonEmpty) + } + + /** + * Check if the function input contains aggregate expressions. + */ + private def checkFunctionInput(f: SQLFunctionExpression): Unit = { + if (f.inputs.exists(AggregateExpression.containsAggregate)) { + // The input of a SQL function should not contain aggregate functions after + // `extractAndRewrite`. If there are aggregate functions, it means they are + // nested in another aggregate function, which is not allowed. + // For example: SELECT sum(foo(sum(c1))) FROM t + // We have to throw the error here because otherwise the query plan after + // resolving the SQL function will not be valid. 
+ throw new AnalysisException( + errorClass = "NESTED_AGGREGATE_FUNCTION", + messageParameters = Map.empty) + } + } + + /** + * Resolve a SQL function expression as a logical plan check if it can be analyzed. + */ + private def resolve(f: SQLFunctionExpression): LogicalPlan = { + // Validate the SQL function input. + checkFunctionInput(f) + val plan = v1SessionCatalog.makeSQLFunctionPlan(f.name, f.function, f.inputs) + val resolved = SQLFunctionContext.withSQLFunction { + // Resolve the SQL function plan using its context. + val conf = new SQLConf() + f.function.getSQLConfigs.foreach { case (k, v) => conf.settings.put(k, v) } + SQLConf.withExistingConf(conf) { + executeSameContext(plan) + } + } + // Fail the analysis eagerly if a SQL function cannot be resolved using its input. + SimpleAnalyzer.checkAnalysis(resolved) + resolved + } + + /** + * Rewrite SQL function expressions into actual resolved function bodies and extract + * function inputs into the given project list. + */ + private def rewriteSQLFunctions[E <: Expression]( + expression: E, + projectList: ArrayBuffer[NamedExpression]): E = { + val newExpr = expression match { + case f: SQLFunctionExpression if !hasSQLFunctionExpression(f.inputs) && + // Make sure LateralColumnAliasReference in parameters is resolved and eliminated first. + // Otherwise, the projectList can contain the LateralColumnAliasReference, which will be + // pushed down to a Project without the 'referenced' alias by LCA present, leaving it + // unresolved. + !f.inputs.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) => + withPosition(f) { + val plan = resolve(f) + // Extract the function input project list from the SQL function plan and + // inline the SQL function expression. 
+ plan match { + case Project(body :: Nil, Project(aliases, _: OneRowRelation)) => + val inputs = aliases.map(stripOuterReference) + projectList ++= inputs + SQLScalarFunction(f.function, inputs.map(_.toAttribute), body) + case o => + throw new AnalysisException( + errorClass = "INVALID_SQL_FUNCTION_PLAN_STRUCTURE", + messageParameters = Map("plan" -> o.toString)) + } + } + case o => o.mapChildren(rewriteSQLFunctions(_, projectList)) + } + newExpr.asInstanceOf[E] + } + + /** + * Check if the given expression contains expressions that should be extracted, + * i.e. non-aggregated SQL functions with non-foldable inputs. + */ + private def shouldExtract(e: Expression): Boolean = e match { + // Return false if the expression is already an aggregate expression. + case _: AggregateExpression => false + case _: WindowExpression => false + case _: SQLFunctionExpression => true + case _: LeafExpression => false + case o => o.children.exists(shouldExtract) + } + + /** + * Extract aggregate expressions and window expressions from the given expression and replace + * them with attribute references. + * Example (aggregate): + * Before: foo(c1) + foo(max(c2)) + max(foo(c2)) + * After: foo(c1) + foo(max_c2) + max_foo_c2 + * Extracted expressions: [c1, max(c2) AS max_c2, max(foo(c2)) AS max_foo_c2] Review Comment: This reminds me of `Aggregate` normalization we did in `RewriteWithExpression`, which moves the result projection from `Aggregate` and puts it in a new `Project` node above `Aggregate`. It's no harm to do this normalization but for safety we only do it when we have to, like the `With` expression and SQL UDF. ########## sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala: ########## @@ -2363,6 +2364,285 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor } } + /** + * This rule resolves SQL function expressions. 
It pulls out function inputs and place them + * in a separate [[Project]] node below the operator and replace the SQL function with its + * actual function body. SQL function expressions in [[Aggregate]] are handled in a special + * way. Non-aggregated SQL functions in the aggregate expressions of an Aggregate need to be + * pulled out into a Project above the Aggregate before replacing the SQL function expressions + * with actual function bodies. For example: + * + * Before: + * Aggregate [c1] [foo(c1), foo(max(c2)), sum(foo(c2)) AS sum] + * +- Relation [c1, c2] + * + * After: + * Project [foo(c1), foo(max_c2), sum] + * +- Aggregate [c1] [c1, max(c2) AS max_c2, sum(foo(c2)) AS sum] + * +- Relation [c1, c2] + */ + object ResolveSQLFunctions extends Rule[LogicalPlan] { + + private def hasSQLFunctionExpression(exprs: Seq[Expression]): Boolean = { + exprs.exists(_.find(_.isInstanceOf[SQLFunctionExpression]).nonEmpty) + } + + /** + * Check if the function input contains aggregate expressions. + */ + private def checkFunctionInput(f: SQLFunctionExpression): Unit = { + if (f.inputs.exists(AggregateExpression.containsAggregate)) { + // The input of a SQL function should not contain aggregate functions after + // `extractAndRewrite`. If there are aggregate functions, it means they are + // nested in another aggregate function, which is not allowed. + // For example: SELECT sum(foo(sum(c1))) FROM t + // We have to throw the error here because otherwise the query plan after + // resolving the SQL function will not be valid. + throw new AnalysisException( + errorClass = "NESTED_AGGREGATE_FUNCTION", + messageParameters = Map.empty) + } + } + + /** + * Resolve a SQL function expression as a logical plan check if it can be analyzed. + */ + private def resolve(f: SQLFunctionExpression): LogicalPlan = { + // Validate the SQL function input. 
+ checkFunctionInput(f) + val plan = v1SessionCatalog.makeSQLFunctionPlan(f.name, f.function, f.inputs) + val resolved = SQLFunctionContext.withSQLFunction { + // Resolve the SQL function plan using its context. + val conf = new SQLConf() + f.function.getSQLConfigs.foreach { case (k, v) => conf.settings.put(k, v) } + SQLConf.withExistingConf(conf) { + executeSameContext(plan) + } + } + // Fail the analysis eagerly if a SQL function cannot be resolved using its input. + SimpleAnalyzer.checkAnalysis(resolved) + resolved + } + + /** + * Rewrite SQL function expressions into actual resolved function bodies and extract + * function inputs into the given project list. + */ + private def rewriteSQLFunctions[E <: Expression]( + expression: E, + projectList: ArrayBuffer[NamedExpression]): E = { + val newExpr = expression match { + case f: SQLFunctionExpression if !hasSQLFunctionExpression(f.inputs) && + // Make sure LateralColumnAliasReference in parameters is resolved and eliminated first. + // Otherwise, the projectList can contain the LateralColumnAliasReference, which will be + // pushed down to a Project without the 'referenced' alias by LCA present, leaving it + // unresolved. + !f.inputs.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) => + withPosition(f) { + val plan = resolve(f) + // Extract the function input project list from the SQL function plan and + // inline the SQL function expression. + plan match { + case Project(body :: Nil, Project(aliases, _: OneRowRelation)) => + val inputs = aliases.map(stripOuterReference) + projectList ++= inputs + SQLScalarFunction(f.function, inputs.map(_.toAttribute), body) + case o => + throw new AnalysisException( + errorClass = "INVALID_SQL_FUNCTION_PLAN_STRUCTURE", + messageParameters = Map("plan" -> o.toString)) + } + } + case o => o.mapChildren(rewriteSQLFunctions(_, projectList)) + } + newExpr.asInstanceOf[E] + } + + /** + * Check if the given expression contains expressions that should be extracted, + * i.e. 
non-aggregated SQL functions with non-foldable inputs. + */ + private def shouldExtract(e: Expression): Boolean = e match { + // Return false if the expression is already an aggregate expression. + case _: AggregateExpression => false + case _: WindowExpression => false + case _: SQLFunctionExpression => true + case _: LeafExpression => false + case o => o.children.exists(shouldExtract) + } + + /** + * Extract aggregate expressions and window expressions from the given expression and replace + * them with attribute references. + * Example (aggregate): + * Before: foo(c1) + foo(max(c2)) + max(foo(c2)) + * After: foo(c1) + foo(max_c2) + max_foo_c2 + * Extracted expressions: [c1, max(c2) AS max_c2, max(foo(c2)) AS max_foo_c2] Review Comment: This reminds me of `Aggregate` normalization we did in `RewriteWithExpression`, which moves the result projection from `Aggregate` and puts it in a new `Project` node above `Aggregate`. It's no harm to do this normalization but for safety we only do it when we have to, like the `With` expression and SQL UDF. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org