cloud-fan commented on code in PR #49414: URL: https://github.com/apache/spark/pull/49414#discussion_r1907041713
########## sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SQLFunctionExpression.scala: ########## @@ -39,3 +40,60 @@ case class SQLFunctionExpression( newChildren: IndexedSeq[Expression]): SQLFunctionExpression = copy(inputs = newChildren) final override val nodePatterns: Seq[TreePattern] = Seq(SQL_FUNCTION_EXPRESSION) } + +/** + * A wrapper for a SQL scalar function expression. This is used for check permissions and + * will be removed in the beginning of the optimization stage. + */ +case class SQLScalarFunction(function: SQLFunction, inputs: Seq[Expression], child: Expression) + extends UnaryExpression with UnaryLike[Expression] with Unevaluable { + override def dataType: DataType = child.dataType + override def toString: String = s"${function.name}(${inputs.mkString(", ")})" + override def sql: String = s"${function.name}(${inputs.map(_.sql).mkString(", ")})" + override protected def withNewChildInternal(newChild: Expression): SQLScalarFunction = { + copy(child = newChild) + } + final override val nodePatterns: Seq[TreePattern] = Seq(SQL_SCALAR_FUNCTION) + // The `inputs` is for display only and does not matter in execution. + override lazy val canonicalized: Expression = copy(inputs = Nil, child = child.canonicalized) + override lazy val deterministic: Boolean = { + function.deterministic.getOrElse(true) && children.forall(_.deterministic) + } +} + +/** + * Provide a way to keep state during analysis for resolving nested SQL functions. + * + * @param nestedSQLFunctionDepth The nested depth in the SQL function resolution. A SQL function + * expression should only be expanded as a [[SQLScalarFunction]] if + * the nested depth is 0. + */ +case class SQLFunctionContext(nestedSQLFunctionDepth: Int = 0) + +object SQLFunctionContext { + + private val value = new ThreadLocal[SQLFunctionContext]() { + override def initialValue: SQLFunctionContext = SQLFunctionContext() + } + + def get: SQLFunctionContext = value.get() + + def reset(): Unit = value.remove() + + private def set(context: SQLFunctionContext): Unit = value.set(context) + + def withSQLFunction[A](f: => A): A = { + val originContext = value.get() + val context = originContext.copy( + nestedSQLFunctionDepth = originContext.nestedSQLFunctionDepth + 1) + set(context) + try f finally { set(originContext) } + } + + def withNewContext[A](f: => A): A = { Review Comment: where do we call it? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org