dtenedor commented on code in PR #52334:
URL: https://github.com/apache/spark/pull/52334#discussion_r2415179182
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala:
##########
@@ -2200,7 +2200,8 @@ class AstBuilder extends DataTypeAstBuilder
Limit(expression(ctx.expression), query)
case ctx: SampleByPercentileContext =>
- val fraction = ctx.percentage.getText.toDouble
+ val fraction = if (ctx.DECIMAL_VALUE() != null) {
ctx.DECIMAL_VALUE().getText.toDouble }
Review Comment:
you can do this to skip repeating the token:
```
Option(ctx.DECIMAL_VALUE()).map { v =>
v.getText.toDouble
}.getOrElse {
ctx.integerValue().getText.toDouble
}
```
##########
sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkParserUtils.scala:
##########
@@ -234,4 +276,71 @@ trait SparkParserUtils {
}
}
-object SparkParserUtils extends SparkParserUtils
+object SparkParserUtils extends SparkParserUtils {
+
+ /**
+ * Callback type for parameter substitution integration.
+ *
+ * This callback allows the parameter substitution system to provide origin
adjustment without
+ * creating circular dependencies between modules.
+ *
+ * @param startToken
+ * The start token from substituted SQL
+ * @param stopToken
+ * The stop token from substituted SQL
+ * @param substitutedSql
+ * The substituted SQL text
+ * @param objectType
+ * The object type
+ * @param objectName
+ * The object name
+ * @return
+ * Some(origin) if parameter substitution should be applied, None otherwise
+ */
+ type ParameterSubstitutionCallback =
+ (Token, Token, String, Option[String], Option[String]) => Option[Origin]
+
+ /**
+ * Thread-local callback for parameter substitution integration. This is set
by the parameter
+ * handler when parameter substitution occurs.
+ */
+ private val parameterSubstitutionCallbackStorage =
+ new ThreadLocal[Option[ParameterSubstitutionCallback]]() {
+ override def initialValue(): Option[ParameterSubstitutionCallback] = None
+ }
+
+ /**
+ * Get the current parameter substitution callback.
+ */
+ def parameterSubstitutionCallback: Option[ParameterSubstitutionCallback] = {
+ parameterSubstitutionCallbackStorage.get()
+ }
+
+ /**
+ * Set the parameter substitution callback for the current thread. This
should be called by
+ * parameter handlers before parsing.
+ */
+ def setParameterSubstitutionCallback(callback:
ParameterSubstitutionCallback): Unit = {
+ parameterSubstitutionCallbackStorage.set(Some(callback))
+ }
+
+ /**
+ * Clear the parameter substitution callback for the current thread.
+ */
+ def clearParameterSubstitutionCallback(): Unit = {
+ parameterSubstitutionCallbackStorage.remove()
+ }
+
+ /**
+ * Execute a block with a parameter substitution callback set.
+ */
+ def withParameterSubstitutionCallback[T](
Review Comment:
This is not used anywhere in the PR as of now?
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SubstituteParamsParser.scala:
##########
@@ -0,0 +1,266 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.parser
+
+import org.antlr.v4.runtime.{CharStreams, CommonTokenStream}
+import org.antlr.v4.runtime.atn.PredictionMode
+import org.antlr.v4.runtime.misc.ParseCancellationException
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.trees.SQLQueryContext
+import org.apache.spark.sql.internal.SQLConf
+
+
+/**
+ * A parameter substitution parser that replaces parameter markers in SQL text
with their values.
+ * This parser finds parameter markers and substitutes them with provided
values to produce
+ * a modified SQL string ready for execution.
+ */
+class SubstituteParamsParser extends Logging {
+
+ /**
+ * Substitute parameter markers in SQL text with provided values.
+ * Always uses compoundOrSingleStatement parsing which can handle all SQL
constructs.
+ *
+ * @param sqlText The original SQL text containing parameter markers
+ * @param namedParams Map of named parameter values (paramName -> value)
+ * @param positionalParams List of positional parameter values in order
+ * @return A tuple of (modified SQL string with parameters substituted,
+ * number of consumed positional parameters)
+ */
+ def substitute(
+ sqlText: String,
+ namedParams: Map[String, String] = Map.empty,
+ positionalParams: List[String] = List.empty): (String, Int,
PositionMapper) = {
+
+ // Quick pre-check: if there are no parameter markers in the text, skip
parsing entirely
+ if (!sqlText.contains("?") && !sqlText.contains(":")) {
+ return (sqlText, 0, PositionMapper.identity(sqlText))
+ }
+
+ val lexer = new SqlBaseLexer(new
UpperCaseCharStream(CharStreams.fromString(sqlText)))
+ lexer.removeErrorListeners()
+ lexer.addErrorListener(ParseErrorListener)
+
+ val tokenStream = new CommonTokenStream(lexer)
+ val parser = new SqlBaseParser(tokenStream)
+ // Match main parser configuration for consistent error messages
+ parser.addParseListener(PostProcessor)
+ parser.addParseListener(UnclosedCommentProcessor(sqlText, tokenStream))
+ parser.removeErrorListeners()
+ parser.addErrorListener(ParseErrorListener)
+ parser.legacy_setops_precedence_enabled =
SQLConf.get.setOpsPrecedenceEnforced
+ parser.legacy_exponent_literal_as_decimal_enabled =
SQLConf.get.exponentLiteralAsDecimalEnabled
+ parser.SQL_standard_keyword_behavior = SQLConf.get.enforceReservedKeywords
+ parser.double_quoted_identifiers = SQLConf.get.doubleQuotedIdentifiers
+ parser.parameter_substitution_enabled =
!SQLConf.get.legacyParameterSubstitutionConstantsOnly
+
+ val astBuilder = new SubstituteParmsAstBuilder()
+
+ // Use the same two-stage parsing strategy as the main parser for
consistent error messages
+ val ctx = try {
+ try {
+ // First attempt: SLL mode with bail error strategy
+ parser.setErrorHandler(new SparkParserBailErrorStrategy())
+ parser.getInterpreter.setPredictionMode(PredictionMode.SLL)
+ parser.compoundOrSingleStatement()
+ } catch {
+ case e: ParseCancellationException =>
+ // Second attempt: LL mode with full error strategy
+ tokenStream.seek(0) // rewind input stream
+ parser.reset()
+ parser.setErrorHandler(new SparkParserErrorStrategy())
+ parser.getInterpreter.setPredictionMode(PredictionMode.LL)
+ parser.compoundOrSingleStatement()
+ }
+ } catch {
+ case e: Throwable => throw e
+ }
+ val parameterLocations = astBuilder.extractParameterLocations(ctx)
+
+ // Substitute parameters in the original text
+ val (substitutedSql, appliedSubstitutions) =
substituteAtLocations(sqlText, parameterLocations,
+ namedParams, positionalParams)
+ val consumedPositionalParams =
parameterLocations.positionalParameterLocations.length
+
+ // Create position mapper for error context translation
+ val positionMapper = PositionMapper(sqlText, substitutedSql,
appliedSubstitutions)
+
+ (substitutedSql, consumedPositionalParams, positionMapper)
+ }
+
+ /**
+ * Detects parameter markers in SQL text without performing substitution.
+ * Always uses compoundOrSingleStatement parsing which can handle all SQL
constructs.
+ *
+ * @param sqlText The original SQL text to analyze
+ * @return A tuple of (hasPositionalParameters, hasNamedParameters)
+ */
+ def detectParameters(sqlText: String): (Boolean, Boolean) = {
+ // Quick pre-check: if there are no parameter markers in the text, skip
parsing entirely
+ if (!sqlText.contains("?") && !sqlText.contains(":")) {
+ return (false, false)
+ }
+
+ val lexer = new SqlBaseLexer(new
UpperCaseCharStream(CharStreams.fromString(sqlText)))
+ lexer.removeErrorListeners()
+ lexer.addErrorListener(ParseErrorListener)
+
+ val tokenStream = new CommonTokenStream(lexer)
+ val parser = new SqlBaseParser(tokenStream)
+ // Match main parser configuration for consistent error messages
+ parser.addParseListener(PostProcessor)
+ parser.addParseListener(UnclosedCommentProcessor(sqlText, tokenStream))
+ parser.removeErrorListeners()
+ parser.addErrorListener(ParseErrorListener)
+ parser.legacy_setops_precedence_enabled =
SQLConf.get.setOpsPrecedenceEnforced
+ parser.legacy_exponent_literal_as_decimal_enabled =
SQLConf.get.exponentLiteralAsDecimalEnabled
+ parser.SQL_standard_keyword_behavior = SQLConf.get.enforceReservedKeywords
+ parser.double_quoted_identifiers = SQLConf.get.doubleQuotedIdentifiers
+ parser.parameter_substitution_enabled =
!SQLConf.get.legacyParameterSubstitutionConstantsOnly
+
+ val astBuilder = new SubstituteParmsAstBuilder()
+
+ // Use the same two-stage parsing strategy as the main parser for
consistent error messages
+ val ctx = try {
+ try {
+ // First attempt: SLL mode with bail error strategy
+ parser.setErrorHandler(new SparkParserBailErrorStrategy())
+ parser.getInterpreter.setPredictionMode(PredictionMode.SLL)
+ parser.compoundOrSingleStatement()
+ } catch {
+ case e: ParseCancellationException =>
+ // Second attempt: LL mode with full error strategy
+ tokenStream.seek(0) // rewind input stream
+ parser.reset()
+ parser.setErrorHandler(new SparkParserErrorStrategy())
+ parser.getInterpreter.setPredictionMode(PredictionMode.LL)
+ parser.compoundOrSingleStatement()
+ }
+ } catch {
+ case e: Throwable => throw e
Review Comment:
do we need this try/catch block if it just re-throws every exception unchanged?
We could omit it entirely and the behavior would be the same?
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SubstituteParamsParser.scala:
##########
@@ -0,0 +1,266 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.parser
+
+import org.antlr.v4.runtime.{CharStreams, CommonTokenStream}
+import org.antlr.v4.runtime.atn.PredictionMode
+import org.antlr.v4.runtime.misc.ParseCancellationException
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.trees.SQLQueryContext
+import org.apache.spark.sql.internal.SQLConf
+
+
+/**
+ * A parameter substitution parser that replaces parameter markers in SQL text
with their values.
+ * This parser finds parameter markers and substitutes them with provided
values to produce
+ * a modified SQL string ready for execution.
+ */
+class SubstituteParamsParser extends Logging {
+
+ /**
+ * Substitute parameter markers in SQL text with provided values.
+ * Always uses compoundOrSingleStatement parsing which can handle all SQL
constructs.
+ *
+ * @param sqlText The original SQL text containing parameter markers
+ * @param namedParams Map of named parameter values (paramName -> value)
+ * @param positionalParams List of positional parameter values in order
+ * @return A tuple of (modified SQL string with parameters substituted,
+ * number of consumed positional parameters)
+ */
+ def substitute(
+ sqlText: String,
+ namedParams: Map[String, String] = Map.empty,
+ positionalParams: List[String] = List.empty): (String, Int,
PositionMapper) = {
+
+ // Quick pre-check: if there are no parameter markers in the text, skip
parsing entirely
+ if (!sqlText.contains("?") && !sqlText.contains(":")) {
+ return (sqlText, 0, PositionMapper.identity(sqlText))
+ }
+
+ val lexer = new SqlBaseLexer(new
UpperCaseCharStream(CharStreams.fromString(sqlText)))
+ lexer.removeErrorListeners()
+ lexer.addErrorListener(ParseErrorListener)
+
+ val tokenStream = new CommonTokenStream(lexer)
+ val parser = new SqlBaseParser(tokenStream)
+ // Match main parser configuration for consistent error messages
+ parser.addParseListener(PostProcessor)
+ parser.addParseListener(UnclosedCommentProcessor(sqlText, tokenStream))
+ parser.removeErrorListeners()
+ parser.addErrorListener(ParseErrorListener)
+ parser.legacy_setops_precedence_enabled =
SQLConf.get.setOpsPrecedenceEnforced
+ parser.legacy_exponent_literal_as_decimal_enabled =
SQLConf.get.exponentLiteralAsDecimalEnabled
+ parser.SQL_standard_keyword_behavior = SQLConf.get.enforceReservedKeywords
+ parser.double_quoted_identifiers = SQLConf.get.doubleQuotedIdentifiers
+ parser.parameter_substitution_enabled =
!SQLConf.get.legacyParameterSubstitutionConstantsOnly
+
+ val astBuilder = new SubstituteParmsAstBuilder()
+
+ // Use the same two-stage parsing strategy as the main parser for
consistent error messages
+ val ctx = try {
+ try {
+ // First attempt: SLL mode with bail error strategy
+ parser.setErrorHandler(new SparkParserBailErrorStrategy())
+ parser.getInterpreter.setPredictionMode(PredictionMode.SLL)
+ parser.compoundOrSingleStatement()
+ } catch {
+ case e: ParseCancellationException =>
+ // Second attempt: LL mode with full error strategy
+ tokenStream.seek(0) // rewind input stream
+ parser.reset()
+ parser.setErrorHandler(new SparkParserErrorStrategy())
+ parser.getInterpreter.setPredictionMode(PredictionMode.LL)
+ parser.compoundOrSingleStatement()
+ }
+ } catch {
+ case e: Throwable => throw e
Review Comment:
This just re-throws the exception unchanged; do we need the try/catch block
at all? The behavior could be the same without it?
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SubstituteParamsParser.scala:
##########
@@ -0,0 +1,266 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.parser
+
+import org.antlr.v4.runtime.{CharStreams, CommonTokenStream}
+import org.antlr.v4.runtime.atn.PredictionMode
+import org.antlr.v4.runtime.misc.ParseCancellationException
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.trees.SQLQueryContext
+import org.apache.spark.sql.internal.SQLConf
+
+
+/**
+ * A parameter substitution parser that replaces parameter markers in SQL text
with their values.
+ * This parser finds parameter markers and substitutes them with provided
values to produce
+ * a modified SQL string ready for execution.
+ */
+class SubstituteParamsParser extends Logging {
+
+ /**
+ * Substitute parameter markers in SQL text with provided values.
+ * Always uses compoundOrSingleStatement parsing which can handle all SQL
constructs.
+ *
+ * @param sqlText The original SQL text containing parameter markers
+ * @param namedParams Map of named parameter values (paramName -> value)
+ * @param positionalParams List of positional parameter values in order
+ * @return A tuple of (modified SQL string with parameters substituted,
+ * number of consumed positional parameters)
+ */
+ def substitute(
+ sqlText: String,
Review Comment:
please fix formatting here?
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala:
##########
@@ -2257,7 +2258,7 @@ class AstBuilder extends DataTypeAstBuilder
override def visitVersion(ctx: VersionContext): Option[String] = {
if (ctx != null) {
- if (ctx.INTEGER_VALUE != null) {
+ if (ctx.INTEGER_VALUE() != null) {
Review Comment:
same here, you can do:
```
Option(ctx.INTEGER_VALUE()).map { v =>
Some(v.getText)
}.getOrElse {
Option(string(visitStringLit(ctx.stringLit())))
}
```
##########
sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkParserUtils.scala:
##########
@@ -234,4 +276,71 @@ trait SparkParserUtils {
}
}
-object SparkParserUtils extends SparkParserUtils
+object SparkParserUtils extends SparkParserUtils {
+
+ /**
+ * Callback type for parameter substitution integration.
+ *
+ * This callback allows the parameter substitution system to provide origin
adjustment without
+ * creating circular dependencies between modules.
+ *
+ * @param startToken
+ * The start token from substituted SQL
+ * @param stopToken
+ * The stop token from substituted SQL
+ * @param substitutedSql
+ * The substituted SQL text
+ * @param objectType
+ * The object type
+ * @param objectName
+ * The object name
+ * @return
+ * Some(origin) if parameter substitution should be applied, None otherwise
+ */
+ type ParameterSubstitutionCallback =
+ (Token, Token, String, Option[String], Option[String]) => Option[Origin]
+
+ /**
+ * Thread-local callback for parameter substitution integration. This is set
by the parameter
+ * handler when parameter substitution occurs.
+ */
+ private val parameterSubstitutionCallbackStorage =
+ new ThreadLocal[Option[ParameterSubstitutionCallback]]() {
Review Comment:
I wonder if there is any way we can set the callback using the parsing API
rather than a thread local like this. This could make it more complex to
support parallel parsing later, for example. If the parsing entry point is the
`parsePlan` method of AbstractSqlParser, that introduces a call site with stack
frames creating state, currently starting with the SQL string to parse. Is
there any way we could pass this callback as part of that state into the stack
frames instead of using a thread local?
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AbstractSqlParser.scala:
##########
@@ -90,15 +90,17 @@ abstract class AbstractSqlParser extends AbstractParser
with ParserInterface {
}
/** Creates LogicalPlan for a given SQL string. */
- override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) {
parser =>
- val ctx = parser.compoundOrSingleStatement()
- withErrorHandling(ctx, Some(sqlText)) {
- astBuilder.visitCompoundOrSingleStatement(ctx) match {
- case compoundBody: CompoundPlanStatement => compoundBody
- case plan: LogicalPlan => plan
- case _ =>
- val position = Origin(None, None)
- throw QueryParsingErrors.sqlStatementUnsupportedError(sqlText,
position)
+ override def parsePlan(sqlText: String): LogicalPlan = {
Review Comment:
It looks like this part of the code in the file has not actually changed? If
so, let's revert it to simplify the PR and help make it easier for people's
Spark forks (if any) to merge it later?
##########
sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala:
##########
@@ -36,7 +36,51 @@ case class Origin(
objectType: Option[String] = None,
objectName: Option[String] = None,
stackTrace: Option[Array[StackTraceElement]] = None,
- pysparkErrorContext: Option[(String, String)] = None) {
+ pysparkErrorContext: Option[(String, String)] = None,
+ parameterSubstitutionCallback: Option[Any] = None) { // Store callback to
avoid dependencies.
Review Comment:
I am wondering what dependencies we have to avoid that make it necessary to
pass this as a callback instead of just invoking the necessary code directly?
The parser lives in the Catalyst package, and it should be possible to import
and directly call/use any other code in the Catalyst package. Can we do that
instead of passing a callback around? It would simplify the logic a lot.
##########
sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkParserUtils.scala:
##########
@@ -234,4 +302,15 @@ trait SparkParserUtils {
}
}
-object SparkParserUtils extends SparkParserUtils
+object SparkParserUtils extends SparkParserUtils {
+
+ /**
+ * Type alias for parameter substitution callback function. Takes
(startToken, stopToken,
+ * substitutedSql, objectType, objectName) and returns an optional Origin.
+ */
+ type ParameterSubstitutionCallback =
+ (Token, Token, String, Option[String], Option[String]) => Option[Origin]
+
+ // Note: Thread-local callback mechanism removed - parameter substitution
info
Review Comment:
OK, it is good that we're not using a ThreadLocal to hold the callback
anymore because that's a global variable that is brittle to threading changes.
We can remove this comment now :)
##########
sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkParserUtils.scala:
##########
@@ -187,25 +187,93 @@ trait SparkParserUtils {
* Register the origin of the context. Any TreeNode created in the closure
will be assigned the
* registered origin. This method restores the previously set origin after
completion of the
* closure.
+ *
+ * This method is parameter substitution-aware. If parameter substitution
occurred before
+ * parsing, it will automatically adjust the positions and SQL text to refer
to the original SQL
+ * (before substitution) instead of the substituted SQL.
*/
def withOrigin[T](ctx: ParserRuleContext, sqlText: Option[String] = None)(f:
=> T): T = {
val current = CurrentOrigin.get
val text = sqlText.orElse(current.sqlText)
+
if (text.isEmpty) {
CurrentOrigin.set(position(ctx.getStart))
} else {
- CurrentOrigin.set(
- positionAndText(
- ctx.getStart,
- ctx.getStop,
- text.get,
- current.objectType,
- current.objectName))
+ // Check if parameter substitution occurred and adjust origin
accordingly.
+ val adjustedOrigin = adjustOriginForParameterSubstitution(
+ ctx.getStart,
+ ctx.getStop,
+ text.get,
+ current.objectType,
+ current.objectName)
+
+ // Preserve any existing callback when setting the new origin.
+ val finalOrigin = if (current.parameterSubstitutionCallback.isDefined) {
+ adjustedOrigin.copy(parameterSubstitutionCallback =
current.parameterSubstitutionCallback)
+ } else {
+ adjustedOrigin
+ }
+
+ CurrentOrigin.set(finalOrigin)
}
try {
f
} finally {
- CurrentOrigin.set(current)
+ // When restoring origin, preserve any callback that was added during
parsing.
+ val currentAfterParsing = CurrentOrigin.get
+ val originToRestore =
+ if (currentAfterParsing.parameterSubstitutionCallback.isDefined ||
+ current.parameterSubstitutionCallback.isDefined) {
+ // Either the current or the original has a callback - preserve it.
+ val callbackToPreserve =
currentAfterParsing.parameterSubstitutionCallback
+ .orElse(current.parameterSubstitutionCallback)
+ current.copy(parameterSubstitutionCallback = callbackToPreserve)
+ } else {
+ // Neither has a callback - restore as normal.
+ current
+ }
+
+ CurrentOrigin.set(originToRestore)
+ }
+ }
+
+ /**
+ * Adjust origin information to account for parameter substitution.
+ *
+ * If parameter substitution occurred, this method maps positions from the
substituted SQL back
+ * to the original SQL and uses the original SQL text in the origin.
+ *
+ * @param startToken
+ * The start token from the substituted SQL
+ * @param stopToken
+ * The stop token from the substituted SQL
+ * @param substitutedSql
+ * The SQL text after substitution
+ * @param objectType
+ * The object type for the origin
+ * @param objectName
+ * The object name for the origin
+ * @return
+ * Origin with positions and text adjusted for parameter substitution
+ */
+ private def adjustOriginForParameterSubstitution(
+ startToken: Token,
+ stopToken: Token,
+ substitutedSql: String,
+ objectType: Option[String],
+ objectName: Option[String]): Origin = {
+
+ // Try to get parameter substitution callback from CurrentOrigin.
+ CurrentOrigin.get.parameterSubstitutionCallback match {
+ case Some(callback) =>
+ // Cast the callback from Any to the proper type.
+ val typedCallback =
callback.asInstanceOf[SparkParserUtils.ParameterSubstitutionCallback]
Review Comment:
can we just store it as an
`Option[SparkParserUtils.ParameterSubstitutionCallback]` instead of
`Option[Any]` for better type safety?
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala:
##########
@@ -49,10 +53,132 @@ class SparkSqlParser extends AbstractSqlParser {
val astBuilder = new SparkSqlAstBuilder()
private val substitutor = new VariableSubstitution()
+ private[execution] val parameterHandler = new ParameterHandler()
+
+ // Thread-local flag to track whether we're in a top-level parse operation
+ // This is used to prevent parameter substitution during identifier/data
type parsing
+ private val isTopLevelParse = new ThreadLocal[Boolean] {
Review Comment:
This is also introducing a new ThreadLocal. Can we avoid it?
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/LiteralToSqlConverter.scala:
##########
@@ -0,0 +1,211 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.util
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.catalyst.{InternalRow}
+import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
+import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.types._
+
+/**
+ * Utility for converting Catalyst literal expressions to their SQL string
representation.
+ *
+ * This object provides a specialized implementation for converting Spark SQL
literal
+ * expressions to their equivalent SQL text representation. It is used by the
parameter
+ * substitution system for EXECUTE IMMEDIATE and other parameterized queries.
+ *
+ * Key features:
+ * - Handles all Spark SQL data types for literal values
+ * - Supports both Scala collections and Spark internal data structures
+ * - Proper SQL escaping and formatting
+ * - Optimized for literal expressions only
+ *
+ * Supported data types:
+ * - Primitives: String, Integer, Long, Float, Double, Boolean, Decimal
+ * - Temporal: Date, Timestamp, TimestampNTZ, Interval
+ * - Collections: Array, Map (including nested structures)
+ * - Special: Null values, Binary data
+ * - Complex: Nested arrays, maps of arrays, arrays of maps
+ *
+ * @example Basic usage:
+ * {{{
+ * val result1 = LiteralToSqlConverter.convert(Literal(42))
+ * // result1: "42"
+ *
+ * val result2 = LiteralToSqlConverter.convert(Literal("hello"))
+ * // result2: "'hello'"
+ *
+ * val arrayLit = Literal.create(Array(1, 2, 3), ArrayType(IntegerType))
+ * val result3 = LiteralToSqlConverter.convert(arrayLit)
+ * // result3: "ARRAY(1, 2, 3)"
+ * }}}
+ *
+ * @example Complex types:
+ * {{{
+ * val mapLit = Literal.create(Map("key" -> "value"), MapType(StringType,
StringType))
+ * val result = LiteralToSqlConverter.convert(mapLit)
+ * // result: "MAP('key', 'value')"
+ * }}}
+ *
+ * @note This utility is thread-safe and can be used concurrently.
+ * @note Only supports Literal expressions - all parameter values must be
pre-evaluated.
+ * @see [[ParameterHandler]] for the main parameter handling entry point
+ * @since 4.0.0
+ */
+object LiteralToSqlConverter {
+
+ /**
+ * Convert an expression to its SQL string representation.
+ *
+ * This method handles both simple literals and complex expressions that
result from
+ * parameter evaluation. For complex types like arrays and maps, the
expressions are
+ * evaluated to internal data structures that need to be converted back to
SQL constructors.
+ *
+ * @param expr The expression to convert (typically a Literal, but may be
other expressions
+ * for complex types)
+ * @return SQL string representation of the expression value
+ */
+ def convert(expr: Expression): String = expr match {
+ case lit: Literal => convertLiteral(lit)
+
+ // Special handling for UnresolvedFunction expressions that don't
naturally evaluate
+ // Only handle functions that are whitelisted in legacy mode but don't
eval() naturally
+ case UnresolvedFunction(name, children, _, _, _, _, _) =>
+ val functionName = name.mkString(".")
+ functionName.toLowerCase(java.util.Locale.ROOT) match {
+ case "array" | "map" | "struct" | "map_from_arrays" |
"map_from_entries" =>
+ // Convert whitelisted functions to SQL function call syntax
+ val childrenSql = children.map(convert).mkString(", ")
+ s"${functionName.toUpperCase(java.util.Locale.ROOT)}($childrenSql)"
+ case _ =>
+ // Non-whitelisted function - not supported in parameter substitution
+ throw QueryCompilationErrors.unsupportedParameterExpression(expr)
+ }
+
+ case _ =>
+ // For non-literal expressions, they should be resolved before reaching
this converter
+ // If we get an unresolved expression, it indicates a problem in the
calling code
+ if (!expr.resolved) {
+ throw SparkException.internalError(
+ s"LiteralToSqlConverter received unresolved expression: " +
+ s"${expr.getClass.getSimpleName}. All expressions should be resolved
before " +
+ s"parameter conversion.")
+ }
+ if (expr.foldable) {
+ val value = expr.eval()
+ val dataType = expr.dataType
+ convertLiteral(Literal.create(value, dataType))
+ } else {
+ throw SparkException.internalError(
+ s"LiteralToSqlConverter cannot convert non-foldable expression: " +
+ s"${expr.getClass.getSimpleName}. All parameter values should be
evaluable to " +
+ s"literals before conversion.")
+ }
+ }
+
+ private def convertLiteral(lit: Literal): String = {
+ // For simple cases, delegate to the existing Literal.sql method
+ // which already has the correct logic for most data types
+ try {
+ lit.sql
+ } catch {
+ case _: Exception =>
Review Comment:
This catches all possible exceptions and throws away their contents. It
could be dangerous if it runs into a new exception type we don't expect in the
future and we lose the information. Could we make this catch more specific
subclasses of exceptions instead for better safety?
##########
sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala:
##########
@@ -57,6 +101,44 @@ case class Origin(
*/
trait WithOrigin {
def origin: Origin
+
+ /**
+ * Update query contexts in this object with translated positions. Uses
reflection to
+ * generically update any object with query context.
+ */
+ def updateQueryContext(translator: Array[QueryContext] =>
Array[QueryContext]): WithOrigin = {
+ try {
+ val thisClass = this.getClass
+
+ // Try to find query context using common method names
+ val contextMethodNames = Seq("getQueryContext", "context")
+
+ for (methodName <- contextMethodNames) {
+ try {
+ val getMethod = thisClass.getMethod(methodName)
+ val currentContexts =
getMethod.invoke(this).asInstanceOf[Array[QueryContext]]
+ val translatedContexts = translator(currentContexts)
+
+ // Try to update the field in-place
+ val fieldName = if (methodName == "getQueryContext") "queryContext"
else "context"
+ try {
+ val field = thisClass.getDeclaredField(fieldName)
+ field.setAccessible(true)
+ field.set(this, translatedContexts)
+ return this // Successfully updated in-place!
+ } catch {
+ case _: Exception => // Field update failed, continue to try other
methods
+ }
+ } catch {
+ case _: Exception => // Method doesn't exist, try next
+ }
+ }
+
+ this // No update possible, return unchanged
+ } catch {
+ case _: Exception => this // Return unchanged if anything fails
Review Comment:
This catches all possible exceptions and drops their contents. That could be
dangerous: if something unexpected happens, we lose the exception details
needed to know what it was. Could we catch more specific exception subclasses
instead, letting everything not explicitly enumerated propagate, for better
safety?
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParameterHandler.scala:
##########
@@ -0,0 +1,352 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.parser
+
+import scala.util.{Failure, Success, Try}
+
+import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
+import org.apache.spark.sql.catalyst.parser.ParseException
+import org.apache.spark.sql.catalyst.util.LiteralToSqlConverter
+import org.apache.spark.sql.errors.QueryCompilationErrors
+
+/**
+ * Handler for parameter substitution across different Spark SQL contexts.
+ *
+ * This class consolidates the common parameter handling logic used by
SparkSqlParser,
+ * SparkConnectPlanner, and ExecuteImmediate. It provides a single, consistent
API
+ * for all parameter substitution operations in Spark SQL.
+ *
+ * Key features:
+ * - Automatic parameter type detection (named vs positional)
+ * - Uses CompoundOrSingleStatement parsing for all SQL constructs
+ * - Consistent error handling and validation
+ * - Support for complex data types (arrays, maps, nested structures)
+ * - Thread-safe operations with position-aware error context
+ *
+ * The handler integrates with the parser through callback mechanisms stored in
+ * CurrentOrigin to ensure error positions are correctly mapped back to the
original SQL text.
+ *
+ * @example Basic usage:
+ * {{{
+ * val handler = new ParameterHandler()
+ * val context = NamedParameterContext(Map("param1" -> Literal(42)))
+ * val result = handler.substituteParameters("SELECT :param1", context)
+ * // result: "SELECT 42"
+ * }}}
+ *
+ * @example Optional context:
+ * {{{
+ * val handler = new ParameterHandler()
+ * val context = Some(NamedParameterContext(Map("param1" -> Literal(42))))
+ * val result = handler.substituteParametersIfNeeded("SELECT :param1", context)
+ * // result: "SELECT 42"
+ * }}}
+ *
+ * @see [[SubstituteParamsParser]] for the underlying parameter substitution
logic
+ */
+class ParameterHandler {
+
+ // Compiled regex pattern for efficient parameter marker detection.
+ private val parameterMarkerPattern = java.util.regex.Pattern.compile("[?:]")
+
+
+ /**
+ * Helper method to perform parameter substitution and store position mapper.
+ *
+ * @param sqlText The SQL text containing parameter markers
+ * @param namedParams Optional named parameters map
+ * @param positionalParams Optional positional parameters list
+ * @return The SQL text with parameters substituted
+ */
+ private def performSubstitution(
+ sqlText: String,
+ namedParams: Map[String, String] = Map.empty,
+ positionalParams: List[String] = List.empty): String = {
+
+ // Quick pre-check: if there are no parameter markers in the text, skip
parsing entirely.
+ if (!parameterMarkerPattern.matcher(sqlText).find()) {
+ val identityMapper = PositionMapper.identity(sqlText)
+ setupSubstitutionContext(sqlText, sqlText, identityMapper, isIdentity =
true)
+ return sqlText
+ }
+
+ val substitutor = new SubstituteParamsParser()
+ val (substituted, _, positionMapper) = substitutor.substitute(sqlText,
+ namedParams = namedParams, positionalParams = positionalParams)
+
+ setupSubstitutionContext(sqlText, substituted, positionMapper, isIdentity
= false)
+ substituted
+ }
+
+ /**
+ * Set up position mapping context for error reporting.
+ * Creates a callback function and stores it in CurrentOrigin for position
translation
+ * between original and substituted SQL text.
+ *
+ * @param originalSql The original SQL text before parameter substitution
+ * @param substitutedSql The SQL text after parameter substitution
+ * @param positionMapper The position mapper for translating error positions
+ * @param isIdentity Whether this is an identity mapping (no substitution
occurred)
+ */
+ private[sql] def setupSubstitutionContext(
+ originalSql: String,
+ substitutedSql: String,
+ positionMapper: PositionMapper,
+ isIdentity: Boolean): Unit = {
+
+ // Create callback function for position mapping.
+ // The positionMapper is captured in the closure for efficient position
translation.
+ val callback: (org.antlr.v4.runtime.Token, org.antlr.v4.runtime.Token,
String,
Review Comment:
If I understand correctly, this is the only actual parameter substitution
callback we're proposing to pass in. This code is in the Catalyst package;
could we just call it directly instead? That could simplify a lot of things.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]