dtenedor commented on code in PR #49518:
URL: https://github.com/apache/spark/pull/49518#discussion_r1920578691
##########
common/utils/src/main/resources/error/error-conditions.json:
##########

@@ -3099,6 +3099,29 @@
     ],
     "sqlState" : "42602"
   },
+  "INVALID_RECURSIVE_REFERENCE" : {
+    "message" : [
+      "Invalid recursive reference found."

Review Comment:
   This doesn't mention that it's about the WITH clause (or the similar DataFrame API). Can we mention these specifically here, so the user knows what part of the query this refers to?

##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala:
##########

@@ -1043,6 +1044,75 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
     if (Utils.isTesting) scrubOutIds(result) else result
   }

+  /**
+   * Recursion, according to SQL standard, comes with several limitations:
+   * 1. Recursive term can contain one recursive reference only.
+   * 2. Recursive reference can't be used in some kinds of joins and aggregations.
+   * This rule checks that these restrictions are not violated.
+   */
+  private def checkRecursion(
+      plan: LogicalPlan,
+      references: mutable.Map[Long, (Int, Seq[DataType])] = mutable.Map.empty): Unit = {
+    plan match {
+      // The map is filled with UnionLoop id as key and 0 (number of Ref occasions) and datatype
+      // as value
+      case UnionLoop(id, anchor, recursion, _) =>
+        checkRecursion(anchor, references)
+        checkRecursion(recursion, references += id -> (0, anchor.output.map(_.dataType)))
+        references -= id
+      case r @ UnionLoopRef(loopId, output, false) =>
+        // If we encounter a recursive reference, it has to be present in the map
+        if (!references.contains(loopId)) {
+          r.failAnalysis(
+            errorClass = "INVALID_RECURSIVE_REFERENCE.PLACE",
+            messageParameters = Map.empty
+          )
+        }
+        val (count, dataType) = references(loopId)
+        if (count > 0) {
+          r.failAnalysis(
+            errorClass = "INVALID_RECURSIVE_REFERENCE.NUMBER",
+            messageParameters = Map.empty
+          )
+        }
+        val originalDataType = r.output.map(_.dataType)
+        if (!originalDataType.zip(dataType).forall {
+          case (odt, dt) =>
+            DataType.equalsStructurally(odt, dt, true)

Review Comment:
   Can you add an implementation comment for this check? It seems non-trivial. Why are we using this type of check for the data types?

##########
common/utils/src/main/resources/error/error-conditions.json:
##########

@@ -3099,6 +3099,29 @@
     ],
     "sqlState" : "42602"
   },
+  "INVALID_RECURSIVE_REFERENCE" : {
+    "message" : [
+      "Invalid recursive reference found."
+    ],
+    "subClass" : {
+      "DATA_TYPE" : {
+        "message" : [
+          "The data type of recursive references cannot change during resolution. Originally it was <fromDataType> but after resolution is <toDataType>."

Review Comment:
   Can you also mention what the user should do to modify the query so that it succeeds on a subsequent attempt? Same below.

##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala:
##########

@@ -1043,6 +1044,75 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
     if (Utils.isTesting) scrubOutIds(result) else result
   }

+  /**
+   * Recursion, according to SQL standard, comes with several limitations:
+   * 1. Recursive term can contain one recursive reference only.
+   * 2. Recursive reference can't be used in some kinds of joins and aggregations.
+   * This rule checks that these restrictions are not violated.
+   */
+  private def checkRecursion(
+      plan: LogicalPlan,
+      references: mutable.Map[Long, (Int, Seq[DataType])] = mutable.Map.empty): Unit = {
+    plan match {
+      // The map is filled with UnionLoop id as key and 0 (number of Ref occasions) and datatype
+      // as value
+      case UnionLoop(id, anchor, recursion, _) =>
+        checkRecursion(anchor, references)
+        checkRecursion(recursion, references += id -> (0, anchor.output.map(_.dataType)))
+        references -= id
+      case r @ UnionLoopRef(loopId, output, false) =>
+        // If we encounter a recursive reference, it has to be present in the map
+        if (!references.contains(loopId)) {
+          r.failAnalysis(
+            errorClass = "INVALID_RECURSIVE_REFERENCE.PLACE",
+            messageParameters = Map.empty
+          )
+        }
+        val (count, dataType) = references(loopId)
+        if (count > 0) {
+          r.failAnalysis(
+            errorClass = "INVALID_RECURSIVE_REFERENCE.NUMBER",
+            messageParameters = Map.empty
+          )
+        }
+        val originalDataType = r.output.map(_.dataType)
+        if (!originalDataType.zip(dataType).forall {
+          case (odt, dt) => DataType.equalsStructurally(odt, dt, true)
+        }) {
+          r.failAnalysis(
+            errorClass = "INVALID_RECURSIVE_REFERENCE.DATA_TYPE",
+            messageParameters = Map(
+              "fromDataType" -> originalDataType.map(toSQLType).mkString(", "),
+              "toDataType" -> dataType.map(toSQLType).mkString(", ")
+            )
+          )
+        }
+        references(loopId) = (count + 1, dataType)
+      case Join(left, right, Inner, _, _) =>
+        checkRecursion(left, references)

Review Comment:
   This algorithm is going to create a lot of stack frames. Could you please convert it to a loop instead, starting with the initial operator to check, and using a queue to add new nodes to check, popping them off after checking them? In this way, we can improve performance and memory usage.

-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
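The queue-based traversal the reviewer asks for can be sketched as follows. This is a minimal, hypothetical illustration, not the actual Spark change: the `Node`/`Leaf`/`Branch` tree stands in for `LogicalPlan`, and the per-node validation is passed in as a function. Note one caveat for the real conversion: `checkRecursion` mutates the `references` map with scoping tied to each `UnionLoop` subtree, so an iterative version would need to enqueue that state (or a snapshot of it) alongside each node rather than share one global map.

```scala
// Hypothetical simplified plan tree; Spark's LogicalPlan is far richer.
sealed trait Node { def children: Seq[Node] }
case class Leaf(value: Int) extends Node { def children: Seq[Node] = Nil }
case class Branch(children: Node*) extends Node

// Breadth-first visit with an explicit queue: stack depth stays constant
// no matter how deep the tree is, unlike the recursive version.
def visitAll(root: Node)(check: Node => Unit): Unit = {
  val queue = scala.collection.mutable.Queue[Node](root)
  while (queue.nonEmpty) {
    val node = queue.dequeue()
    check(node)                      // run the per-node validation
    queue.enqueueAll(node.children)  // schedule children instead of recursing
  }
}
```

With a FIFO queue the visit order is breadth-first; swapping the queue for a stack (or prepending children) yields a depth-first order, which is closer to what the recursive code does and may matter if the checks are order-sensitive.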
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For queries about this service, please contact Infrastructure at: us...@infra.apache.org