davidm-db commented on code in PR #50027: URL: https://github.com/apache/spark/pull/50027#discussion_r1964251043
########## sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala: ########## @@ -599,6 +599,116 @@ class CaseStatementExec( } } +/** + * Executable node for SimpleCaseStatement. + * @param caseVariableExec Statement with which all conditionExpressions will be compared to. + * @param conditionExpressions Collection of expressions which correspond to WHEN clauses. + * @param conditionalBodies Collection of executable bodies that have a corresponding condition, + * in WHEN branches. + * @param elseBody Body that is executed if none of the conditions are met, i.e. ELSE branch. + * @param session Spark session that SQL script is executed within. + * @param context SqlScriptingExecutionContext keeps the execution state of current script. + */ +class SimpleCaseStatementExec( + caseVariableExec: SingleStatementExec, + conditionExpressions: Seq[Expression], + conditionalBodies: Seq[CompoundBodyExec], + elseBody: Option[CompoundBodyExec], + session: SparkSession, + context: SqlScriptingExecutionContext) extends NonLeafStatementExec { + private object CaseState extends Enumeration { + val Condition, Body = Value + } + + private var state = CaseState.Condition + var bodyExec: Option[CompoundBodyExec] = None + + var conditionBodyTupleIterator: Iterator[(SingleStatementExec, CompoundBodyExec)] = _ + private var caseVariableLiteral: Literal = _ + + private var isCacheValid = false + private def validateCache(): Unit = { + if (!isCacheValid) { + val values = caseVariableExec.buildDataFrame(session).collect() + caseVariableExec.isExecuted = true + + caseVariableLiteral = Literal(values.head.get(0)) + conditionBodyTupleIterator = createConditionBodyIterator + isCacheValid = true + } + } + + private def cachedCaseVariableLiteral: Literal = { + validateCache() + caseVariableLiteral + } + + private def cachedConditionBodyIterator: Iterator[(SingleStatementExec, CompoundBodyExec)] = { + validateCache() + conditionBodyTupleIterator + } + + private lazy val treeIterator: Iterator[CompoundStatementExec] = + new Iterator[CompoundStatementExec] { + override def hasNext: Boolean = state match { + case CaseState.Condition => cachedConditionBodyIterator.hasNext || elseBody.isDefined + case CaseState.Body => bodyExec.exists(_.getTreeIterator.hasNext) + } + + override def next(): CompoundStatementExec = state match { + case CaseState.Condition => + cachedConditionBodyIterator.nextOption() + .map { case (condStmt, body) => + if (evaluateBooleanCondition(session, condStmt)) { + bodyExec = Some(body) + state = CaseState.Body + } + condStmt + } + .orElse(elseBody.map { body => { + bodyExec = Some(body) + state = CaseState.Body + next() + }}) + .get + case CaseState.Body => bodyExec.get.getTreeIterator.next() + } + } + + private def createConditionBodyIterator: Iterator[(SingleStatementExec, CompoundBodyExec)] = + conditionExpressions.zip(conditionalBodies) + .iterator + .map { case (expr, body) => + val condition = Project( + Seq(Alias(EqualTo(cachedCaseVariableLiteral, expr), "condition")()), + OneRowRelation() + ) + // We hack the Origin to provide more descriptive error messages. For example, if + // the case variable is 1 and the condition expression it's compared to is 5, we + // will get Origin with text "(1 = 5)". + val conditionText = condition.projectList.head.asInstanceOf[Alias].child.toString + val condStmt = new SingleStatementExec( + condition, + Origin(sqlText = Some(conditionText), + startIndex = Some(0), + stopIndex = Some(conditionText.length - 1)), + Map.empty, + isInternal = true, + context = context + ) + (condStmt, body) + } + + override def getTreeIterator: Iterator[CompoundStatementExec] = treeIterator + + override def reset(): Unit = { + state = CaseState.Condition Review Comment: `caseVariableExec.reset()` should be added as well? I guess it's not important, but it will reset the `isExecuted` flag if it's ever important for anything -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org