dusantism-db commented on code in PR #50027:
URL: https://github.com/apache/spark/pull/50027#discussion_r1966944493


##########
sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala:
##########
@@ -599,6 +599,116 @@ class CaseStatementExec(
   }
 }
 
+/**
+ * Executable node for SimpleCaseStatement.
+ * @param caseVariableExec Statement to which all conditionExpressions will 
be compared.
+ * @param conditionExpressions Collection of expressions which correspond to 
WHEN clauses.
+ * @param conditionalBodies Collection of executable bodies that have a 
corresponding condition,
+ *                 in WHEN branches.
+ * @param elseBody Body that is executed if none of the conditions are met, 
i.e. ELSE branch.
+ * @param session Spark session that SQL script is executed within.
+ * @param context SqlScriptingExecutionContext keeps the execution state of 
current script.
+ */
+class SimpleCaseStatementExec(
+    caseVariableExec: SingleStatementExec,
+    conditionExpressions: Seq[Expression],
+    conditionalBodies: Seq[CompoundBodyExec],
+    elseBody: Option[CompoundBodyExec],
+    session: SparkSession,
+    context: SqlScriptingExecutionContext) extends NonLeafStatementExec {
+  private object CaseState extends Enumeration {
+    val Condition, Body = Value
+  }
+
+  private var state = CaseState.Condition
+  var bodyExec: Option[CompoundBodyExec] = None
+
+  var conditionBodyTupleIterator: Iterator[(SingleStatementExec, 
CompoundBodyExec)] = _
+  private var caseVariableLiteral: Literal = _
+
+  private var isCacheValid = false
+  private def validateCache(): Unit = {
+    if (!isCacheValid) {
+      val values = caseVariableExec.buildDataFrame(session).collect()
+      caseVariableExec.isExecuted = true
+
+      caseVariableLiteral = Literal(values.head.get(0))
+      conditionBodyTupleIterator = createConditionBodyIterator
+      isCacheValid = true
+    }
+  }
+
+  private def cachedCaseVariableLiteral: Literal = {
+    validateCache()
+    caseVariableLiteral
+  }
+
+  private def cachedConditionBodyIterator: Iterator[(SingleStatementExec, 
CompoundBodyExec)] = {
+    validateCache()
+    conditionBodyTupleIterator
+  }
+
+  private lazy val treeIterator: Iterator[CompoundStatementExec] =
+    new Iterator[CompoundStatementExec] {
+      override def hasNext: Boolean = state match {
+        case CaseState.Condition => cachedConditionBodyIterator.hasNext || 
elseBody.isDefined
+        case CaseState.Body => bodyExec.exists(_.getTreeIterator.hasNext)
+      }
+
+      override def next(): CompoundStatementExec = state match {
+        case CaseState.Condition =>
+          cachedConditionBodyIterator.nextOption()
+            .map { case (condStmt, body) =>
+              if (evaluateBooleanCondition(session, condStmt)) {
+                bodyExec = Some(body)
+                state = CaseState.Body
+              }
+              condStmt
+            }
+            .orElse(elseBody.map { body => {
+              bodyExec = Some(body)
+              state = CaseState.Body
+              next()
+            }})
+            .get
+        case CaseState.Body => bodyExec.get.getTreeIterator.next()
+      }
+    }
+
+  private def createConditionBodyIterator: Iterator[(SingleStatementExec, 
CompoundBodyExec)] =
+    conditionExpressions.zip(conditionalBodies)
+      .iterator
+      .map { case (expr, body) =>
+        val condition = Project(
+          Seq(Alias(EqualTo(cachedCaseVariableLiteral, expr), "condition")()),
+          OneRowRelation()
+        )
+        // We hack the Origin to provide more descriptive error messages. For 
example, if
+        // the case variable is 1 and the condition expression it's compared 
to is 5, we
+        // will get Origin with text "(1 = 5)".
+        val conditionText = 
condition.projectList.head.asInstanceOf[Alias].child.toString
+        val condStmt = new SingleStatementExec(
+          condition,
+          Origin(sqlText = Some(conditionText),
+            startIndex = Some(0),
+            stopIndex = Some(conditionText.length - 1)),

Review Comment:
   We have tests for these exceptions; they are in SqlScriptingInterpreterSuite.
   You're right that it will screw up line numbers. Perhaps it would be better to 
copy the origin from the original caseVariable expression. This would keep the 
proper line number, but it would point to only one part of the equality 
expression.
   
   For example, if we have this script:
   ```
   BEGIN
     CASE 1
       WHEN NULL THEN
         SELECT 41;
       ELSE
         SELECT 43;
     END CASE;
   END
   ```
   The error we would have if we kept the Origin from the case variable expression:
   `{LINE:3} [BOOLEAN_STATEMENT_WITH_EMPTY_ROW] Boolean statement 1 is invalid. 
Expected single row with a value of the BOOLEAN type, but got an empty row. 
SQLSTATE: 21000`
   
   The error with the hacked origin:
   `[BOOLEAN_STATEMENT_WITH_EMPTY_ROW] Boolean statement (1 = NULL) is invalid. 
Expected single row with a value of the BOOLEAN type, but got an empty row. 
SQLSTATE: 21000`
   
   Not sure which is better.
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to