dusantism-db commented on code in PR #50027: URL: https://github.com/apache/spark/pull/50027#discussion_r1966944493
########## sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala: ########## @@ -599,6 +599,116 @@ class CaseStatementExec( } } +/** + * Executable node for SimpleCaseStatement. + * @param caseVariableExec Statement with which all conditionExpressions will be compared to. + * @param conditionExpressions Collection of expressions which correspond to WHEN clauses. + * @param conditionalBodies Collection of executable bodies that have a corresponding condition, + * in WHEN branches. + * @param elseBody Body that is executed if none of the conditions are met, i.e. ELSE branch. + * @param session Spark session that SQL script is executed within. + * @param context SqlScriptingExecutionContext keeps the execution state of current script. + */ +class SimpleCaseStatementExec( + caseVariableExec: SingleStatementExec, + conditionExpressions: Seq[Expression], + conditionalBodies: Seq[CompoundBodyExec], + elseBody: Option[CompoundBodyExec], + session: SparkSession, + context: SqlScriptingExecutionContext) extends NonLeafStatementExec { + private object CaseState extends Enumeration { + val Condition, Body = Value + } + + private var state = CaseState.Condition + var bodyExec: Option[CompoundBodyExec] = None + + var conditionBodyTupleIterator: Iterator[(SingleStatementExec, CompoundBodyExec)] = _ + private var caseVariableLiteral: Literal = _ + + private var isCacheValid = false + private def validateCache(): Unit = { + if (!isCacheValid) { + val values = caseVariableExec.buildDataFrame(session).collect() + caseVariableExec.isExecuted = true + + caseVariableLiteral = Literal(values.head.get(0)) + conditionBodyTupleIterator = createConditionBodyIterator + isCacheValid = true + } + } + + private def cachedCaseVariableLiteral: Literal = { + validateCache() + caseVariableLiteral + } + + private def cachedConditionBodyIterator: Iterator[(SingleStatementExec, CompoundBodyExec)] = { + validateCache() + conditionBodyTupleIterator + } + + private lazy 
val treeIterator: Iterator[CompoundStatementExec] = + new Iterator[CompoundStatementExec] { + override def hasNext: Boolean = state match { + case CaseState.Condition => cachedConditionBodyIterator.hasNext || elseBody.isDefined + case CaseState.Body => bodyExec.exists(_.getTreeIterator.hasNext) + } + + override def next(): CompoundStatementExec = state match { + case CaseState.Condition => + cachedConditionBodyIterator.nextOption() + .map { case (condStmt, body) => + if (evaluateBooleanCondition(session, condStmt)) { + bodyExec = Some(body) + state = CaseState.Body + } + condStmt + } + .orElse(elseBody.map { body => { + bodyExec = Some(body) + state = CaseState.Body + next() + }}) + .get + case CaseState.Body => bodyExec.get.getTreeIterator.next() + } + } + + private def createConditionBodyIterator: Iterator[(SingleStatementExec, CompoundBodyExec)] = + conditionExpressions.zip(conditionalBodies) + .iterator + .map { case (expr, body) => + val condition = Project( + Seq(Alias(EqualTo(cachedCaseVariableLiteral, expr), "condition")()), + OneRowRelation() + ) + // We hack the Origin to provide more descriptive error messages. For example, if + // the case variable is 1 and the condition expression it's compared to is 5, we + // will get Origin with text "(1 = 5)". + val conditionText = condition.projectList.head.asInstanceOf[Alias].child.toString + val condStmt = new SingleStatementExec( + condition, + Origin(sqlText = Some(conditionText), + startIndex = Some(0), + stopIndex = Some(conditionText.length - 1)), Review Comment: We have tests for these exceptions, they are in SqlScriptingInterpreterSuite. You're right it will screw up line numbers, perhaps it would be better to copy the origin from the original caseVariable expression. This will keep the proper line number, however would just point to one part of the equality expression. 
For example, if we have this script: ``` BEGIN CASE 1 WHEN NULL THEN SELECT 41; ELSE SELECT 43; END CASE; END ``` The error we would get if we kept the Origin from the case variable expression: `{LINE:3} [BOOLEAN_STATEMENT_WITH_EMPTY_ROW] Boolean statement 1 is invalid. Expected single row with a value of the BOOLEAN type, but got an empty row. SQLSTATE: 21000` The error with the hacked origin: `[BOOLEAN_STATEMENT_WITH_EMPTY_ROW] Boolean statement (1 = NULL) is invalid. Expected single row with a value of the BOOLEAN type, but got an empty row. SQLSTATE: 21000` I'm not sure which is better. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org