dtenedor commented on code in PR #49571:
URL: https://github.com/apache/spark/pull/49571#discussion_r1931294394


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala:
##########
@@ -848,6 +848,15 @@ object LimitPushDown extends Rule[LogicalPlan] {
     case LocalLimit(exp, u: Union) =>
       LocalLimit(exp, u.copy(children = u.children.map(maybePushLocalLimit(exp, _))))
+    // If limit node is present, we should propagate it down to UnionLoop, so that it is later
+    // propagated to UnionLoopExec.
+    // Limit node is constructed by placing GlobalLimit over LocalLimit (look at Limit apply method)
+    // that is the reason why we match it this way.
+    case g @ GlobalLimit(IntegerLiteral(limit), l @ LocalLimit(_, p @ Project(_, ul: UnionLoop))) =>

Review Comment:
Should we add helpers like `LimitAndOffset` [1] to match against global + local limits like this?

[1] https://github.com/apache/spark/blob/fef1b2375c3074cb3b53d5c29df1aa27c269469c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala#L1605
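For reference, a minimal sketch of what such a helper could look like, assuming a hypothetical `GlobalAndLocalLimit` extractor (the name and matching conditions are illustrative, not part of this PR):

```scala
import org.apache.spark.sql.catalyst.expressions.IntegerLiteral
import org.apache.spark.sql.catalyst.plans.logical.{GlobalLimit, LocalLimit, LogicalPlan}

// Hypothetical helper: matches the GlobalLimit-over-LocalLimit pair that
// Limit.apply produces, yielding the limit value and the underlying child.
object GlobalAndLocalLimit {
  def unapply(plan: LogicalPlan): Option[(Int, LogicalPlan)] = plan match {
    case GlobalLimit(IntegerLiteral(limit), LocalLimit(IntegerLiteral(local), child))
        if limit == local =>
      Some((limit, child))
    case _ => None
  }
}
```

The new rule could then match on `GlobalAndLocalLimit(limit, Project(_, ul: UnionLoop))`, mirroring how `LimitAndOffset` is used elsewhere.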
##########
common/utils/src/main/resources/error/error-conditions.json:
##########
@@ -5059,6 +5059,12 @@
     ],
     "sqlState" : "42846"
   },
+  "UNION_NOT_SUPPORTED_IN_RECURSIVE_CTE" : {
+    "message" : [
+      "UNION operator not yet supported in recursive CTEs. Use UNION ALL."

Review Comment:
```suggestion
      "The UNION operator is not yet supported within recursive common table expressions (WITH clauses that refer to themselves, directly or indirectly). Please use UNION ALL instead."
```


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala:
##########
@@ -714,6 +717,133 @@ case class UnionExec(children: Seq[SparkPlan]) extends SparkPlan {
     copy(children = newChildren)
 }
 
+/**
+ * The physical node for recursion. Currently only UNION ALL case is supported.
+ * In the first iteration, anchor term is executed.
+ * Then, in each following iteration, the UnionLoopRef node is substituted with the plan from the
+ * previous iteration, and such plan is executed.
+ * After every iteration, the dataframe is repartitioned.
+ * The recursion stops when the generated dataframe is empty, or either the limit or
+ * the specified maximum depth from the config is reached.
+ *
+ * @param loopId The id of the loop.

Review Comment:
also mention what this is used for, and who sets it and references it?
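To make this concrete, here is a minimal sketch of a query that would exercise this node, assuming the `WITH RECURSIVE` syntax this feature targets and a SparkSession named `spark` (the example is illustrative, not taken from the PR):

```scala
// The anchor term (SELECT 1) seeds the loop; the self-reference to `t` in the
// recursive term is what becomes the UnionLoopRef under the UnionLoop node.
spark.sql("""
  WITH RECURSIVE t(n) AS (
    SELECT 1
    UNION ALL
    SELECT n + 1 FROM t WHERE n < 5
  )
  SELECT * FROM t
""").show()
```

Each iteration re-runs the recursive term against the previous iteration's rows, so this yields n = 1 through 5 and stops once the recursive term produces no new rows.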
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala:
##########
@@ -714,6 +717,133 @@ case class UnionExec(children: Seq[SparkPlan]) extends SparkPlan {
     copy(children = newChildren)
 }
 
+/**
+ * The physical node for recursion. Currently only UNION ALL case is supported.
+ * In the first iteration, anchor term is executed.
+ * Then, in each following iteration, the UnionLoopRef node is substituted with the plan from the
+ * previous iteration, and such plan is executed.
+ * After every iteration, the dataframe is repartitioned.
+ * The recursion stops when the generated dataframe is empty, or either the limit or
+ * the specified maximum depth from the config is reached.
+ *
+ * @param loopId The id of the loop.
+ * @param anchor The logical plan of the initial element of the loop.
+ * @param recursion The logical plan that describes the recursion with an [[UnionLoopRef]] node.
+ * @param output The output attributes of this loop.
+ * @param limit In case we have a plan with the limit node, it is pushed down to UnionLoop and then
+ * transferred to UnionLoopExec, to stop the recursion after specific amount of rows
+ * is generated.
+ * Note here: limit can be applied in the main query calling the recursive CTE, and not
+ * inside the recursive term of recursive CTE.
+ */
+case class UnionLoopExec(
+    loopId: Long,
+    @transient anchor: LogicalPlan,

Review Comment:
why is this transient?


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala:
##########
@@ -1031,6 +1038,9 @@ object ColumnPruning extends Rule[LogicalPlan] {
       } else {
         p
       }
+    // TODO: Pruning `UnionLoop`s needs to take into account both the outer `Project` and the inner
+    // `UnionLoopRef` nodes.

Review Comment:
Why do we drop the `UnionLoop` here?


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala:
##########
@@ -714,6 +717,133 @@ case class UnionExec(children: Seq[SparkPlan]) extends SparkPlan {
     copy(children = newChildren)
 }
 
+/**
+ * The physical node for recursion. Currently only UNION ALL case is supported.
+ * In the first iteration, anchor term is executed.
+ * Then, in each following iteration, the UnionLoopRef node is substituted with the plan from the
+ * previous iteration, and such plan is executed.
+ * After every iteration, the dataframe is repartitioned.
+ * The recursion stops when the generated dataframe is empty, or either the limit or
+ * the specified maximum depth from the config is reached.
+ *
+ * @param loopId The id of the loop.
+ * @param anchor The logical plan of the initial element of the loop.
+ * @param recursion The logical plan that describes the recursion with an [[UnionLoopRef]] node.
+ * @param output The output attributes of this loop.
+ * @param limit In case we have a plan with the limit node, it is pushed down to UnionLoop and then
+ * transferred to UnionLoopExec, to stop the recursion after specific amount of rows
+ * is generated.
+ * Note here: limit can be applied in the main query calling the recursive CTE, and not
+ * inside the recursive term of recursive CTE.
+ */
+case class UnionLoopExec(
+    loopId: Long,
+    @transient anchor: LogicalPlan,
+    @transient recursion: LogicalPlan,
+    override val output: Seq[Attribute],
+    limit: Option[Int] = None) extends LeafExecNode {
+
+  override def innerChildren: Seq[QueryPlan[_]] = Seq(anchor, recursion)
+
+  override lazy val metrics = Map(
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))

Review Comment:
is it possible to include the number of recursive loops in here as well?
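A sketch of how that might look, assuming a hypothetical `numIterations` metric (name and description are illustrative, not part of this PR):

```scala
// Hypothetical extension of the metrics map; doExecute's loop would then bump it
// once per iteration, e.g. with longMetric("numIterations") += 1.
override lazy val metrics = Map(
  "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
  "numIterations" -> SQLMetrics.createMetric(sparkContext, "number of recursive iterations"))
```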
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala:
##########
@@ -714,6 +717,133 @@ case class UnionExec(children: Seq[SparkPlan]) extends SparkPlan {
     copy(children = newChildren)
 }
 
+/**
+ * The physical node for recursion. Currently only UNION ALL case is supported.
+ * In the first iteration, anchor term is executed.
+ * Then, in each following iteration, the UnionLoopRef node is substituted with the plan from the
+ * previous iteration, and such plan is executed.
+ * After every iteration, the dataframe is repartitioned.
+ * The recursion stops when the generated dataframe is empty, or either the limit or
+ * the specified maximum depth from the config is reached.
+ *
+ * @param loopId The id of the loop.
+ * @param anchor The logical plan of the initial element of the loop.
+ * @param recursion The logical plan that describes the recursion with an [[UnionLoopRef]] node.
+ * @param output The output attributes of this loop.
+ * @param limit In case we have a plan with the limit node, it is pushed down to UnionLoop and then
+ * transferred to UnionLoopExec, to stop the recursion after specific amount of rows
+ * is generated.
+ * Note here: limit can be applied in the main query calling the recursive CTE, and not
+ * inside the recursive term of recursive CTE.
+ */
+case class UnionLoopExec(
+    loopId: Long,
+    @transient anchor: LogicalPlan,
+    @transient recursion: LogicalPlan,
+    override val output: Seq[Attribute],
+    limit: Option[Int] = None) extends LeafExecNode {
+
+  override def innerChildren: Seq[QueryPlan[_]] = Seq(anchor, recursion)
+
+  override lazy val metrics = Map(
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
+
+  /**
+   * This function executes the plan (optionally with appended limit node) and caches the result,
+   * with the caching mode specified in config.
+   */
+  private def executeAndCacheAndCount(
+      plan: LogicalPlan, currentLimit: Int) = {
+    // In case limit is defined, we create a (global) limit node above the plan and execute
+    // the newly created plan.
+    // Note here: global limit requires coordination (shuffle) between partitions.
+    val planOrLimitedPlan = if (limit.isDefined) {
+      Limit(Literal(currentLimit), plan)
+    } else {
+      plan
+    }
+    val df = Dataset.ofRows(session, planOrLimitedPlan)
+    val cachedDF = df.repartition()
+    val count = cachedDF.count()
+    (cachedDF, count)
+  }
+
+  override protected def doExecute(): RDD[InternalRow] = {
+    val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
+    val numOutputRows = longMetric("numOutputRows")
+    val levelLimit = conf.getConf(SQLConf.CTE_RECURSION_LEVEL_LIMIT)
+
+    // currentLimit is initialized from the limit argument, and in each step it is decreased by
+    // the number of rows generated in that step.
+    // If limit is not passed down, currentLimit is set to be zero and won't be considered in the
+    // condition of while loop down (limit.isEmpty will be true).
+    var currentLimit = limit.getOrElse(0)
+    val unionChildren = mutable.ArrayBuffer.empty[LogicalRDD]
+
+    var (prevDF, prevCount) = executeAndCacheAndCount(anchor, currentLimit)
+
+    var currentLevel = 1
+
+    // Main loop for obtaining the result of the recursive query.
+    while (prevCount > 0 && (limit.isEmpty || currentLimit > 0)) {
+
+      if (levelLimit != -1 && currentLevel > levelLimit) {
+        throw new SparkException(s"Recursion level limit ${levelLimit} reached but query has not "

Review Comment:
please add an error class for this
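A minimal sketch of how this could go through the error framework, assuming a hypothetical `RECURSION_LEVEL_LIMIT_EXCEEDED` condition (the name and parameters are illustrative; a matching message entry would also be needed in error-conditions.json):

```scala
// Hypothetical replacement for the plain SparkException above.
throw new SparkException(
  errorClass = "RECURSION_LEVEL_LIMIT_EXCEEDED",
  messageParameters = Map("levelLimit" -> levelLimit.toString),
  cause = null)
```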
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala:
##########
@@ -714,6 +717,133 @@ case class UnionExec(children: Seq[SparkPlan]) extends SparkPlan {
     copy(children = newChildren)
 }
 
+/**
+ * The physical node for recursion. Currently only UNION ALL case is supported.
+ * In the first iteration, anchor term is executed.
+ * Then, in each following iteration, the UnionLoopRef node is substituted with the plan from the
+ * previous iteration, and such plan is executed.
+ * After every iteration, the dataframe is repartitioned.
+ * The recursion stops when the generated dataframe is empty, or either the limit or
+ * the specified maximum depth from the config is reached.
+ *
+ * @param loopId The id of the loop.
+ * @param anchor The logical plan of the initial element of the loop.
+ * @param recursion The logical plan that describes the recursion with an [[UnionLoopRef]] node.
+ * @param output The output attributes of this loop.
+ * @param limit In case we have a plan with the limit node, it is pushed down to UnionLoop and then

Review Comment:
in other words, this represents a limit on the total number of rows output by this operator, correct? if so, let's mention that first as it's the high-order bit of information here.


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala:
##########
@@ -714,6 +717,133 @@ case class UnionExec(children: Seq[SparkPlan]) extends SparkPlan {
     copy(children = newChildren)
 }
 
+/**
+ * The physical node for recursion. Currently only UNION ALL case is supported.
+ * In the first iteration, anchor term is executed.
+ * Then, in each following iteration, the UnionLoopRef node is substituted with the plan from the
+ * previous iteration, and such plan is executed.
+ * After every iteration, the dataframe is repartitioned.
+ * The recursion stops when the generated dataframe is empty, or either the limit or
+ * the specified maximum depth from the config is reached.
+ *
+ * @param loopId The id of the loop.
+ * @param anchor The logical plan of the initial element of the loop.
+ * @param recursion The logical plan that describes the recursion with an [[UnionLoopRef]] node.

Review Comment:
also reiterate here what the `[[UnionLoopRef]]` represents? This could be improved with an example above the @params with a short SQL snippet and corresponding logical plan.
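Something along these lines might serve as the requested doc example (a rough sketch; the exact plan rendering is illustrative):

```scala
// For a query such as
//   WITH RECURSIVE t(n) AS (SELECT 1 UNION ALL SELECT n + 1 FROM t WHERE n < 5)
//   SELECT * FROM t
// the logical plan can be pictured roughly as
//   UnionLoop id
//   :- Project [1 AS n]          <- anchor, executed once
//   +- Project [(n + 1) AS n]    <- recursive term, re-executed each iteration
//      +- Filter (n < 5)
//         +- UnionLoopRef id     <- placeholder replaced by the previous iteration's result
```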
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala:
##########
@@ -714,6 +717,133 @@ case class UnionExec(children: Seq[SparkPlan]) extends SparkPlan {
     copy(children = newChildren)
 }
 
+/**
+ * The physical node for recursion. Currently only UNION ALL case is supported.
+ * In the first iteration, anchor term is executed.
+ * Then, in each following iteration, the UnionLoopRef node is substituted with the plan from the
+ * previous iteration, and such plan is executed.
+ * After every iteration, the dataframe is repartitioned.
+ * The recursion stops when the generated dataframe is empty, or either the limit or
+ * the specified maximum depth from the config is reached.
+ *
+ * @param loopId The id of the loop.
+ * @param anchor The logical plan of the initial element of the loop.
+ * @param recursion The logical plan that describes the recursion with an [[UnionLoopRef]] node.
+ * @param output The output attributes of this loop.
+ * @param limit In case we have a plan with the limit node, it is pushed down to UnionLoop and then
+ * transferred to UnionLoopExec, to stop the recursion after specific amount of rows
+ * is generated.
+ * Note here: limit can be applied in the main query calling the recursive CTE, and not
+ * inside the recursive term of recursive CTE.
+ */
+case class UnionLoopExec(
+    loopId: Long,
+    @transient anchor: LogicalPlan,
+    @transient recursion: LogicalPlan,
+    override val output: Seq[Attribute],
+    limit: Option[Int] = None) extends LeafExecNode {
+
+  override def innerChildren: Seq[QueryPlan[_]] = Seq(anchor, recursion)
+
+  override lazy val metrics = Map(
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
+
+  /**
+   * This function executes the plan (optionally with appended limit node) and caches the result,
+   * with the caching mode specified in config.
+   */
+  private def executeAndCacheAndCount(
+      plan: LogicalPlan, currentLimit: Int) = {
+    // In case limit is defined, we create a (global) limit node above the plan and execute
+    // the newly created plan.
+    // Note here: global limit requires coordination (shuffle) between partitions.
+    val planOrLimitedPlan = if (limit.isDefined) {
+      Limit(Literal(currentLimit), plan)
+    } else {
+      plan
+    }
+    val df = Dataset.ofRows(session, planOrLimitedPlan)
+    val cachedDF = df.repartition()
+    val count = cachedDF.count()
+    (cachedDF, count)
+  }
+
+  override protected def doExecute(): RDD[InternalRow] = {

Review Comment:
could we please have a descriptive comment here that says what the general steps are of the execution implementation?
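For example, a step summary along these lines could work (a sketch paraphrasing the loop in this diff and the class doc, not authoritative):

```scala
// General steps of the execution:
// 1. Execute and cache the anchor term, counting the rows it produces.
// 2. While the previous iteration produced rows and the optional row limit is not yet
//    exhausted: fail if the configured recursion level limit is exceeded; otherwise
//    substitute the previous iteration's result for the UnionLoopRef in the recursive
//    term, execute and cache it, and append the result to the accumulated children.
// 3. Union the per-iteration results into a single RDD and update numOutputRows.
```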