davidm-db commented on code in PR #49726: URL: https://github.com/apache/spark/pull/49726#discussion_r1946364577
########## sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala: ########## @@ -1020,3 +1023,105 @@ class ExceptionHandlerExec( override def reset(): Unit = body.reset() } + +/** + * Executable node for Signal Statement. + * @param errorCondition Name of the error condition/SQL State for error that will be thrown. + * @param sqlState SQL State of the error that will be thrown. + * @param message Error message (either string or variable name). + * @param msgArguments Error message parameters for builtin conditions. + * @param session Spark session that SQL script is executed within. + */ +class SignalStatementExec( + val errorCondition: Option[String] = None, + val sqlState: Option[String] = None, + val message: Either[String, UnresolvedAttribute], + val msgArguments: Option[UnresolvedAttribute], + val isBuiltinError: Boolean, + val session: SparkSession, + override val origin: Origin) + extends LeafStatementExec + with ColumnResolutionHelper with WithOrigin { + + override def catalogManager: CatalogManager = session.sessionState.catalogManager + override def conf: SQLConf = session.sessionState.conf + + def getMessageArgs: Map[String, String] = { + msgArguments match { + case Some(args) => + val argsReference = getVariableReference(args, args.nameParts) + + if (!argsReference.dataType.sameType(MapType(StringType, StringType))) { + throw SqlScriptingErrors + .invalidSignalStatementVariableType(CurrentOrigin.get, argsReference.dataType) + } + + val argsValue = argsReference.eval(null) + + if (argsValue == null) { + throw SqlScriptingErrors.nullVariableSignalStatement(CurrentOrigin.get, args.name) + } + + val mapData = argsValue.asInstanceOf[UnsafeMapData] + (0 until mapData.numElements()).map { index => + ( + mapData.keyArray().get(index, StringType).toString, + mapData.valueArray().get(index, StringType).toString + ) + }.toMap + case None => + Map.empty + } + } + + def getMessageText: String = { + message match { + case Left(v) => v + case Right(u) => + val varReference = getVariableReference(u, u.nameParts) + + if (!varReference.dataType.sameType(StringType)) { + throw SqlScriptingErrors + .invalidSignalStatementVariableType(CurrentOrigin.get, varReference.dataType) + } + + // Call eval with null value passed instead of a row. + // This is ok as this is variable and invoking eval should + // be independent of row value. + val varReferenceValue = varReference.eval(null) + + if (varReferenceValue == null) { + throw SqlScriptingErrors + .nullVariableSignalStatement(CurrentOrigin.get, u.name) Review Comment: same thing for `CurrentOrigin.get` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org