davidm-db commented on code in PR #50027:
URL: https://github.com/apache/spark/pull/50027#discussion_r1967153704


##########
sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala:
##########
@@ -599,6 +599,116 @@ class CaseStatementExec(
   }
 }
 
+/**
+ * Executable node for SimpleCaseStatement.
+ * @param caseVariableExec Statement with which all conditionExpressions will 
be compared to.
+ * @param conditionExpressions Collection of expressions which correspond to 
WHEN clauses.
+ * @param conditionalBodies Collection of executable bodies that have a 
corresponding condition,
+ *                 in WHEN branches.
+ * @param elseBody Body that is executed if none of the conditions are met, 
i.e. ELSE branch.
+ * @param session Spark session that SQL script is executed within.
+ * @param context SqlScriptingExecutionContext keeps the execution state of 
current script.
+ */
+class SimpleCaseStatementExec(
+    caseVariableExec: SingleStatementExec,
+    conditionExpressions: Seq[Expression],
+    conditionalBodies: Seq[CompoundBodyExec],
+    elseBody: Option[CompoundBodyExec],
+    session: SparkSession,
+    context: SqlScriptingExecutionContext) extends NonLeafStatementExec {
+  private object CaseState extends Enumeration {
+    val Condition, Body = Value
+  }
+
+  private var state = CaseState.Condition
+  var bodyExec: Option[CompoundBodyExec] = None
+
+  var conditionBodyTupleIterator: Iterator[(SingleStatementExec, 
CompoundBodyExec)] = _
+  private var caseVariableLiteral: Literal = _
+
+  private var isCacheValid = false
+  private def validateCache(): Unit = {
+    if (!isCacheValid) {
+      val values = caseVariableExec.buildDataFrame(session).collect()
+      caseVariableExec.isExecuted = true
+
+      caseVariableLiteral = Literal(values.head.get(0))
+      conditionBodyTupleIterator = createConditionBodyIterator
+      isCacheValid = true
+    }
+  }
+
+  private def cachedCaseVariableLiteral: Literal = {
+    validateCache()
+    caseVariableLiteral
+  }
+
+  private def cachedConditionBodyIterator: Iterator[(SingleStatementExec, 
CompoundBodyExec)] = {
+    validateCache()
+    conditionBodyTupleIterator
+  }
+
+  private lazy val treeIterator: Iterator[CompoundStatementExec] =
+    new Iterator[CompoundStatementExec] {
+      override def hasNext: Boolean = state match {
+        case CaseState.Condition => cachedConditionBodyIterator.hasNext || 
elseBody.isDefined
+        case CaseState.Body => bodyExec.exists(_.getTreeIterator.hasNext)
+      }
+
+      override def next(): CompoundStatementExec = state match {
+        case CaseState.Condition =>
+          cachedConditionBodyIterator.nextOption()
+            .map { case (condStmt, body) =>
+              if (evaluateBooleanCondition(session, condStmt)) {
+                bodyExec = Some(body)
+                state = CaseState.Body
+              }
+              condStmt
+            }
+            .orElse(elseBody.map { body => {
+              bodyExec = Some(body)
+              state = CaseState.Body
+              next()
+            }})
+            .get
+        case CaseState.Body => bodyExec.get.getTreeIterator.next()
+      }
+    }
+
+  private def createConditionBodyIterator: Iterator[(SingleStatementExec, 
CompoundBodyExec)] =
+    conditionExpressions.zip(conditionalBodies)
+      .iterator
+      .map { case (expr, body) =>
+        val condition = Project(
+          Seq(Alias(EqualTo(cachedCaseVariableLiteral, expr), "condition")()),
+          OneRowRelation()
+        )
+        // We hack the Origin to provide more descriptive error messages. For 
example, if
+        // the case variable is 1 and the condition expression it's compared 
to is 5, we
+        // will get Origin with text "(1 = 5)".
+        val conditionText = 
condition.projectList.head.asInstanceOf[Alias].child.toString
+        val condStmt = new SingleStatementExec(
+          condition,
+          Origin(sqlText = Some(conditionText),
+            startIndex = Some(0),
+            stopIndex = Some(conditionText.length - 1)),

Review Comment:
   If it's an easy fix (and it looks like it is), I would definitely go with 
this approach: copy the line number into the hacked Origin. If not, I think 
it's better to keep the line number for now and improve the messaging in a 
follow-up.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to