nemanjapetr-db commented on code in PR #49351:
URL: https://github.com/apache/spark/pull/49351#discussion_r1903977030


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala:
##########
@@ -539,48 +582,58 @@ case class Union(
     children.length > 1 && !(byName || allowMissingCol) && childrenResolved && 
allChildrenCompatible
   }
 
-  private lazy val lazyOutput: Seq[Attribute] = computeOutput()
-
-  private def computeOutput(): Seq[Attribute] = 
Union.mergeChildOutputs(children.map(_.output))
-
-  /**
-   * Maps the constraints containing a given (original) sequence of attributes 
to those with a
-   * given (reference) sequence of attributes. Given the nature of union, we 
expect that the
-   * mapping between the original and reference sequences are symmetric.
-   */
-  private def rewriteConstraints(
-      reference: Seq[Attribute],
-      original: Seq[Attribute],
-      constraints: ExpressionSet): ExpressionSet = {
-    require(reference.size == original.size)
-    val attributeRewrites = AttributeMap(original.zip(reference))
-    constraints.map(_ transform {
-      case a: Attribute => attributeRewrites(a)
-    })
-  }
+  override protected def withNewChildrenInternal(newChildren: 
IndexedSeq[LogicalPlan]): Union =
+    copy(children = newChildren)
+}
 
-  private def merge(a: ExpressionSet, b: ExpressionSet): ExpressionSet = {
-    val common = a.intersect(b)
-    // The constraint with only one reference could be easily inferred as 
predicate
-    // Grouping the constraints by it's references so we can combine the 
constraints with same
-    // reference together
-    val othera = a.diff(common).filter(_.references.size == 
1).groupBy(_.references.head)
-    val otherb = b.diff(common).filter(_.references.size == 
1).groupBy(_.references.head)
-    // loose the constraints by: A1 && B1 || A2 && B2  ->  (A1 || A2) && (B1 
|| B2)
-    val others = (othera.keySet intersect otherb.keySet).map { attr =>
-      Or(othera(attr).reduceLeft(And), otherb(attr).reduceLeft(And))
-    }
-    common ++ others
-  }
+/**
+ * The logical node for recursion, that contains a initial (anchor) and a 
recursion describing term,
+ * that contains an [[UnionLoopRef]] node.
+ * The node is very similar to [[Union]] because the initial and "generated" 
children are union-ed
+ * and it is also similar to a loop because the recursion continues until the 
last generated child
+ * is not empty.
+ *
+ * @param id The id of the loop, inherited from [[CTERelationDef]]
+ * @param anchor The plan of the initial element of the loop.
+ * @param recursion The plan that describes the recursion with an 
[[UnionLoopRef]] node.
+ * @param limit An optional limit that can be pushed down to the node to stop 
the loop earlier.
+ */
+case class UnionLoop(
+                      id: Long,

Review Comment:
   Done.



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala:
##########
@@ -462,6 +462,59 @@ object Union {
   }
 }
 
+abstract class UnionBase extends LogicalPlan {
+  // updating nullability to make all the children consistent
+  override def output: Seq[Attribute] = {
+    if (conf.getConf(SQLConf.LAZY_SET_OPERATOR_OUTPUT)) {
+      lazyOutput
+    } else {
+      computeOutput()
+    }
+  }
+
+  override def metadataOutput: Seq[Attribute] = Nil
+
+  private lazy val lazyOutput: Seq[Attribute] = computeOutput()
+
+  private def computeOutput(): Seq[Attribute] = 
Union.mergeChildOutputs(children.map(_.output))
+
+  /**
+   * Maps the constraints containing a given (original) sequence of attributes 
to those with a
+   * given (reference) sequence of attributes. Given the nature of union, we 
expect that the
+   * mapping between the original and reference sequences are symmetric.
+   */
+  private def rewriteConstraints(
+                                  reference: Seq[Attribute],

Review Comment:
   Done.



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala:
##########
@@ -539,48 +582,58 @@ case class Union(
     children.length > 1 && !(byName || allowMissingCol) && childrenResolved && 
allChildrenCompatible
   }
 
-  private lazy val lazyOutput: Seq[Attribute] = computeOutput()
-
-  private def computeOutput(): Seq[Attribute] = 
Union.mergeChildOutputs(children.map(_.output))
-
-  /**
-   * Maps the constraints containing a given (original) sequence of attributes 
to those with a
-   * given (reference) sequence of attributes. Given the nature of union, we 
expect that the
-   * mapping between the original and reference sequences are symmetric.
-   */
-  private def rewriteConstraints(
-      reference: Seq[Attribute],
-      original: Seq[Attribute],
-      constraints: ExpressionSet): ExpressionSet = {
-    require(reference.size == original.size)
-    val attributeRewrites = AttributeMap(original.zip(reference))
-    constraints.map(_ transform {
-      case a: Attribute => attributeRewrites(a)
-    })
-  }
+  override protected def withNewChildrenInternal(newChildren: 
IndexedSeq[LogicalPlan]): Union =
+    copy(children = newChildren)
+}
 
-  private def merge(a: ExpressionSet, b: ExpressionSet): ExpressionSet = {
-    val common = a.intersect(b)
-    // The constraint with only one reference could be easily inferred as 
predicate
-    // Grouping the constraints by it's references so we can combine the 
constraints with same
-    // reference together
-    val othera = a.diff(common).filter(_.references.size == 
1).groupBy(_.references.head)
-    val otherb = b.diff(common).filter(_.references.size == 
1).groupBy(_.references.head)
-    // loose the constraints by: A1 && B1 || A2 && B2  ->  (A1 || A2) && (B1 
|| B2)
-    val others = (othera.keySet intersect otherb.keySet).map { attr =>
-      Or(othera(attr).reduceLeft(And), otherb(attr).reduceLeft(And))
-    }
-    common ++ others
-  }
+/**
+ * The logical node for recursion, that contains a initial (anchor) and a 
recursion describing term,
+ * that contains an [[UnionLoopRef]] node.
+ * The node is very similar to [[Union]] because the initial and "generated" 
children are union-ed
+ * and it is also similar to a loop because the recursion continues until the 
last generated child
+ * is not empty.
+ *
+ * @param id The id of the loop, inherited from [[CTERelationDef]]
+ * @param anchor The plan of the initial element of the loop.
+ * @param recursion The plan that describes the recursion with an 
[[UnionLoopRef]] node.
+ * @param limit An optional limit that can be pushed down to the node to stop 
the loop earlier.
+ */
+case class UnionLoop(
+                      id: Long,
+                      anchor: LogicalPlan,
+                      recursion: LogicalPlan,
+                      limit: Option[Int] = None) extends UnionBase {
+  override def children: Seq[LogicalPlan] = Seq(anchor, recursion)
+
+  override protected def withNewChildrenInternal(newChildren: 
IndexedSeq[LogicalPlan]): UnionLoop =
+    copy(anchor = newChildren(0), recursion = newChildren(1))
+}
 
-  override protected lazy val validConstraints: ExpressionSet = {
-    children
-      .map(child => rewriteConstraints(children.head.output, child.output, 
child.constraints))
-      .reduce(merge(_, _))
+/**
+ * The recursive reference in the recursive term of an [[UnionLoop]] node.
+ *
+ * @param loopId The id of the loop, inherited from [[CTERelationRef]]
+ * @param output The output attributes of this recursive reference.
+ * @param accumulated If false the the reference stands for the result of the 
previous iteration.
+ *                    If it is true then then it stands for the union of all 
previous iteration
+ *                    results.
+ */
+case class UnionLoopRef(

Review Comment:
   This will be part of 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/InsertLoops.scala,
 which will come in a follow-up PR and will be invoked from 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala.
   
   I wanted to keep this PR lean.



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveWithCTE.scala:
##########
@@ -37,21 +38,89 @@ object ResolveWithCTE extends Rule[LogicalPlan] {
     }
   }
 
+  private def updateRecursiveAnchor(cteDef: CTERelationDef): CTERelationDef = {
+    cteDef.child match {
+      case SubqueryAlias(_, u: Union) =>
+        if (u.children.head.resolved) {
+          cteDef.copy(recursionAnchor = Some(u.children.head))
+        } else {
+          cteDef
+        }
+      case SubqueryAlias(_, d @ Distinct(u: Union)) =>
+        if (u.children.head.resolved) {
+          cteDef.copy(recursionAnchor = Some(d.copy(child = u.children.head)))
+        } else {
+          cteDef
+        }
+      case SubqueryAlias(_, a @ UnresolvedSubqueryColumnAliases(_, u: Union)) 
=>
+        if (u.children.head.resolved) {
+          cteDef.copy(recursionAnchor = Some(a.copy(child = u.children.head)))
+        } else {
+          cteDef
+        }
+      case SubqueryAlias(_, a @ UnresolvedSubqueryColumnAliases(_, d @ 
Distinct(u: Union))) =>
+        if (u.children.head.resolved) {
+          cteDef.copy(recursionAnchor = Some(a.copy(child = d.copy(child = 
u.children.head))))
+        } else {
+          cteDef
+        }
+      case _ =>
+        cteDef.failAnalysis(
+          errorClass = "INVALID_RECURSIVE_CTE",
+          messageParameters = Map.empty)
+    }
+  }
+
   private def resolveWithCTE(
       plan: LogicalPlan,
       cteDefMap: mutable.HashMap[Long, CTERelationDef]): LogicalPlan = {
     plan.resolveOperatorsDownWithPruning(_.containsAllPatterns(CTE)) {
       case w @ WithCTE(_, cteDefs) =>
-        cteDefs.foreach { cteDef =>
-          if (cteDef.resolved) {
-            cteDefMap.put(cteDef.id, cteDef)
+        val newCTEDefs = cteDefs.map { cteDef =>
+          // If a recursive CTE definition is not yet resolved then extract 
the anchor term to the
+          // definition, but if it is resolved then the extracted anchor term 
is no longer needed
+          // and can be removed.
+          val newCTEDef = if (cteDef.recursive) {
+            if (!cteDef.resolved) {
+              if (cteDef.recursionAnchor.isEmpty) {
+                updateRecursiveAnchor(cteDef)
+              } else {
+                cteDef
+              }
+            } else {
+              if (cteDef.recursionAnchor.nonEmpty) {

Review Comment:
   Why do you believe it is non-recursive? It is within the 
`if (cteDef.recursive)` block.



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala:
##########
@@ -539,48 +582,58 @@ case class Union(
     children.length > 1 && !(byName || allowMissingCol) && childrenResolved && 
allChildrenCompatible
   }
 
-  private lazy val lazyOutput: Seq[Attribute] = computeOutput()
-
-  private def computeOutput(): Seq[Attribute] = 
Union.mergeChildOutputs(children.map(_.output))
-
-  /**
-   * Maps the constraints containing a given (original) sequence of attributes 
to those with a
-   * given (reference) sequence of attributes. Given the nature of union, we 
expect that the
-   * mapping between the original and reference sequences are symmetric.
-   */
-  private def rewriteConstraints(
-      reference: Seq[Attribute],
-      original: Seq[Attribute],
-      constraints: ExpressionSet): ExpressionSet = {
-    require(reference.size == original.size)
-    val attributeRewrites = AttributeMap(original.zip(reference))
-    constraints.map(_ transform {
-      case a: Attribute => attributeRewrites(a)
-    })
-  }
+  override protected def withNewChildrenInternal(newChildren: 
IndexedSeq[LogicalPlan]): Union =
+    copy(children = newChildren)
+}
 
-  private def merge(a: ExpressionSet, b: ExpressionSet): ExpressionSet = {
-    val common = a.intersect(b)
-    // The constraint with only one reference could be easily inferred as 
predicate
-    // Grouping the constraints by it's references so we can combine the 
constraints with same
-    // reference together
-    val othera = a.diff(common).filter(_.references.size == 
1).groupBy(_.references.head)
-    val otherb = b.diff(common).filter(_.references.size == 
1).groupBy(_.references.head)
-    // loose the constraints by: A1 && B1 || A2 && B2  ->  (A1 || A2) && (B1 
|| B2)
-    val others = (othera.keySet intersect otherb.keySet).map { attr =>
-      Or(othera(attr).reduceLeft(And), otherb(attr).reduceLeft(And))
-    }
-    common ++ others
-  }
+/**
+ * The logical node for recursion, that contains a initial (anchor) and a 
recursion describing term,
+ * that contains an [[UnionLoopRef]] node.
+ * The node is very similar to [[Union]] because the initial and "generated" 
children are union-ed
+ * and it is also similar to a loop because the recursion continues until the 
last generated child
+ * is not empty.
+ *
+ * @param id The id of the loop, inherited from [[CTERelationDef]]
+ * @param anchor The plan of the initial element of the loop.
+ * @param recursion The plan that describes the recursion with an 
[[UnionLoopRef]] node.
+ * @param limit An optional limit that can be pushed down to the node to stop 
the loop earlier.
+ */
+case class UnionLoop(
+                      id: Long,
+                      anchor: LogicalPlan,
+                      recursion: LogicalPlan,
+                      limit: Option[Int] = None) extends UnionBase {
+  override def children: Seq[LogicalPlan] = Seq(anchor, recursion)
+
+  override protected def withNewChildrenInternal(newChildren: 
IndexedSeq[LogicalPlan]): UnionLoop =
+    copy(anchor = newChildren(0), recursion = newChildren(1))
+}
 
-  override protected lazy val validConstraints: ExpressionSet = {
-    children
-      .map(child => rewriteConstraints(children.head.output, child.output, 
child.constraints))
-      .reduce(merge(_, _))
+/**
+ * The recursive reference in the recursive term of an [[UnionLoop]] node.
+ *
+ * @param loopId The id of the loop, inherited from [[CTERelationRef]]
+ * @param output The output attributes of this recursive reference.
+ * @param accumulated If false the the reference stands for the result of the 
previous iteration.
+ *                    If it is true then then it stands for the union of all 
previous iteration
+ *                    results.
+ */
+case class UnionLoopRef(
+                         loopId: Long,

Review Comment:
   Done.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to