cloud-fan commented on code in PR #49955:
URL: https://github.com/apache/spark/pull/49955#discussion_r1980973366


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/RecursiveCTEExecution.scala:
##########
@@ -0,0 +1,229 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import scala.collection.mutable
+
+import org.apache.spark.SparkException
+import org.apache.spark.rdd.{EmptyRDD, RDD}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Literal}
+import org.apache.spark.sql.catalyst.plans.QueryPlan
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalLimit, LogicalPlan, Project,
+  Union, UnionLoopRef}
+import org.apache.spark.sql.classic.Dataset
+import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.internal.SQLConf
+
+
+/**
+ * The physical node for recursion. Currently only the UNION ALL case is supported.
+ * For details about the execution, look at the comment above the doExecute function.
+ *
+ * A simple recursive query:
+ * {{{
+ * WITH RECURSIVE t(n) AS (
+ *     SELECT 1
+ *     UNION ALL
+ *     SELECT n+1 FROM t WHERE n < 5)
+ * SELECT * FROM t;
+ * }}}
+ * Corresponding logical plan for the recursive query above:
+ * {{{
+ * WithCTE
+ * :- CTERelationDef 0, false
+ * :  +- SubqueryAlias t
+ * :     +- Project [1#0 AS n#3]
+ * :        +- UnionLoop 0
+ * :           :- Project [1 AS 1#0]
+ * :           :  +- OneRowRelation
+ * :           +- Project [(n#1 + 1) AS (n + 1)#2]
+ * :              +- Filter (n#1 < 5)
+ * :                 +- SubqueryAlias t
+ * :                    +- Project [1#0 AS n#1]
+ * :                       +- UnionLoopRef 0, [1#0], false
+ * +- Project [n#3]
+ *    +- SubqueryAlias t
+ *       +- CTERelationRef 0, true, [n#3], false, false
+ * }}}
+ *
+ * @param loopId The id of the CTERelationDef containing the recursive query. Its value is
+ *               first passed down to UnionLoop when creating it, and then to UnionLoopExec in
+ *               SparkStrategies.
+ * @param anchor The logical plan of the initial element of the loop.
+ * @param recursion The logical plan that describes the recursion with a [[UnionLoopRef]] node.
+ *                  CTERelationRef, which is marked as recursive, gets substituted with
+ *                  [[UnionLoopRef]] in ResolveWithCTE.
+ *                  Both anchor and recursion are marked with the @transient annotation, so that
+ *                  they are not serialized.
+ * @param output The output attributes of this loop.
+ * @param limit If defined, the total number of rows output by this operator will be bounded by
+ *              limit.
+ *              Its value is pushed down to UnionLoop in the Optimizer in case a Limit node is
+ *              present in the logical plan and then transferred to UnionLoopExec in
+ *              SparkStrategies.
+ *              Note here: limit can be applied in the main query calling the recursive CTE, not
+ *              inside the recursive term of the recursive CTE.
+ */
+case class UnionLoopExec(
+    loopId: Long,
+    @transient anchor: LogicalPlan,
+    @transient recursion: LogicalPlan,
+    override val output: Seq[Attribute],
+    localLimit: Option[Int] = None,
+    globalLimit: Option[Int] = None) extends LeafExecNode {
+
+  override def innerChildren: Seq[QueryPlan[_]] = Seq(anchor, recursion)
+
+  private val numPartitions: Int = conf.defaultNumShufflePartitions
+
+  override lazy val metrics = Map(
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output 
rows"),
+    "numRecursiveLoops" -> SQLMetrics.createMetric(sparkContext, "number of 
recursive loops"))
+
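+  // True when the recursive term is just a Project/Filter chain directly on top of the
+  // UnionLoopRef; in that case executeAndCacheAndCount skips repartitioning the
+  // per-iteration result.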
+  private val simpleRecursion = {
+    recursion match {
+      case Project(_, Filter(_, Project(_, UnionLoopRef(_, _, _)))) =>
+        true
+      case Filter(_, Project(_, UnionLoopRef(_, _, _))) =>
+        true
+      case Project(_, Filter(_, UnionLoopRef(_, _, _))) =>
+        true
+      case Project(_, UnionLoopRef(_, _, _)) =>
+        true
+      case _ =>
+        false
+    }
+  }
+  /**
+   * This function executes the plan (optionally with appended limit node) and caches the result,
+   * with the caching mode specified in config.
+   */
+  private def executeAndCacheAndCount(plan: LogicalPlan, currentLimit: Int) = {
+    // In case limit is defined, we create a (local) limit node above the plan and execute
+    // the newly created plan.
+    val planOrLimitedPlan = if (globalLimit.isDefined || localLimit.isDefined) {
+      LocalLimit(Literal(currentLimit), plan)
+    } else {
+      plan
+    }
+    val df = Dataset.ofRows(session, planOrLimitedPlan)
+    val newDF = {
+      if (!simpleRecursion) {
+        df.repartition(numPartitions)
+      } else {
+        df
+      }
+    }
+    val count = newDF.count()
+    (newDF, count)
+  }
+
+  /**
+   * In the first iteration, the anchor term is executed.
+   * Then, in each following iteration, the UnionLoopRef node is substituted with the plan from
+   * the previous iteration, and the resulting plan is executed.
+   * After every iteration, the dataframe is repartitioned (unless the recursion is a simple
+   * Project/Filter chain over the loop reference).
+   * The recursion stops when the generated dataframe is empty, or when either the limit or the
+   * maximum recursion depth from the config is reached.
+   */
+  override protected def doExecute(): RDD[InternalRow] = {
+    val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
+    val numOutputRows = longMetric("numOutputRows")
+    val numRecursiveLoops = longMetric("numRecursiveLoops")
+    val levelLimit = conf.getConf(SQLConf.CTE_RECURSION_LEVEL_LIMIT)
+
+    // currentLimit is initialized from the limit argument, and in each step it is decreased by
+    // the number of rows generated in that step.
+    // If limit is not passed down, currentLimit is set to zero and won't be considered in the
+    // condition of the while loop below (limit.isEmpty will be true).
+    var globalLimitNum = globalLimit.getOrElse(0)
+    var localLimitNum = localLimit.getOrElse(0)
+    var currentLimit = Math.max(globalLimitNum, localLimitNum * numPartitions)

Review Comment:
   I don't think `localLimitNum * numPartitions` can be treated as a global limit. Think about this case:
   1. The query result has 2 partitions. In each iteration, the first partition produces 1 million rows, and the second partition produces 1 row.
   2. Let's say the local limit is 100. We need to iterate 100 times so that the second partition produces enough data, but the current code stops at the first iteration because 1 million rows already exceed 100 * 2 (see the sketch below).
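
   A minimal, hypothetical sketch of this scenario (the per-iteration row counts of 1,000,000 and 1, and the limit of 100, are illustrative numbers, not taken from the PR). It contrasts the number of iterations actually needed for every partition to satisfy LocalLimit(100) with the flawed `localLimit * numPartitions` cap, which is already exceeded after the first iteration:

   ```scala
   // Hypothetical per-iteration output: partition 0 emits 1,000,000 rows, partition 1 emits 1 row.
   object LocalLimitCapSketch {
     def main(args: Array[String]): Unit = {
       val rowsPerIteration = Seq(1000000L, 1L)
       val localLimit = 100L
       val numPartitions = rowsPerIteration.size

       // Stopping condition a LocalLimit actually requires: every partition reaches the limit.
       var accumulated = Seq.fill(numPartitions)(0L)
       var iterations = 0
       while (accumulated.exists(_ < localLimit)) {
         accumulated = accumulated.zip(rowsPerIteration).map { case (acc, step) => acc + step }
         iterations += 1
       }
       println(s"Iterations needed so every partition has $localLimit rows: $iterations") // 100

       // The cap used as a global limit in the current code: localLimit * numPartitions = 200.
       val flawedCap = localLimit * numPartitions
       val totalAfterOneIteration = rowsPerIteration.sum // 1,000,001
       println(s"Total rows after one iteration ($totalAfterOneIteration) already exceed the " +
         s"cap of $flawedCap, so the loop would stop while partition 1 has only 1 row")
     }
   }
   ```

   In other words, a LocalLimit can only be verified per partition; a single global row count cannot tell whether the slowest partition has produced enough rows yet.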



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

