mikhailnik-db commented on code in PR #54297:
URL: https://github.com/apache/spark/pull/54297#discussion_r2816415672


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala:
##########
@@ -564,6 +564,81 @@ case class ListAgg(
     false
   }
 
+  /**
+   * Determines whether the order mismatch between [[child]] and [[orderExpressions]] is due to
+   * a cast, and if so, whether that cast is safe for DISTINCT deduplication.
+   *
+   * When LISTAGG(DISTINCT) is used with a non-string column, a Cast is applied to the
+   * child expression. The DISTINCT rewrite uses GROUP BY on the cast result, which can produce
+   * incorrect deduplication for types where equal values cast to different strings
+   * (e.g., Float/Double where -0.0 and 0.0 are GROUP BY-equal but cast to different strings).
+   *
+   * Safety is determined by both the source type (via [[isCastSafeForDistinct]]) and the target
+   * type's collation (via [[isCastTargetSafeForDistinct]]).
+   *
+   * @return `Some(Right(()))` if the mismatch is due to a safe cast,
+   *         `Some(Left((inputType, castType)))` if the cast is unsafe, carrying the source and
+   *         target types for use in the error message,
+   *         `None` if the mismatch is not due to a cast at all
+   */
+  def orderMismatchCastSafety: Option[Either[(DataType, DataType), Unit]] = {

Review Comment:
   `Option[Either[(DataType, DataType), Unit]]` is a somewhat awkward type. In Scala, it would be more idiomatic to define a dedicated `sealed trait` with three inheritors: two `object`s and one `case class`.
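
   A minimal sketch of that shape, for illustration only: the names `OrderMismatchCastSafety`, `NotCausedByCast`, `SafeCast`, and `UnsafeCast` are hypothetical, and a tiny stand-in `DataType` replaces Spark's so the snippet compiles on its own. `fromLegacy` maps the current `Option[Either[...]]` encoding onto the trait to show that both carry the same information.

   ```scala
   // Hypothetical stand-in for Spark's DataType, only so the sketch is self-contained.
   sealed trait DataType
   case object FloatType extends DataType
   case object StringType extends DataType

   // The three outcomes currently encoded as Option[Either[(DataType, DataType), Unit]].
   sealed trait OrderMismatchCastSafety
   object OrderMismatchCastSafety {
     // The mismatch is not caused by a cast (currently: None).
     case object NotCausedByCast extends OrderMismatchCastSafety
     // The mismatch is caused by a cast that is safe for DISTINCT (currently: Some(Right(()))).
     case object SafeCast extends OrderMismatchCastSafety
     // The mismatch is caused by an unsafe cast; keeps both types for the error message
     // (currently: Some(Left((inputType, castType)))).
     final case class UnsafeCast(inputType: DataType, castType: DataType)
       extends OrderMismatchCastSafety
   }

   object CastSafetyExample {
     import OrderMismatchCastSafety._

     // Converts the current encoding into the proposed ADT.
     def fromLegacy(legacy: Option[Either[(DataType, DataType), Unit]]): OrderMismatchCastSafety =
       legacy match {
         case None => NotCausedByCast
         case Some(Right(_)) => SafeCast
         case Some(Left((in, out))) => UnsafeCast(in, out)
       }

     // Call sites read as plain pattern matches over named outcomes.
     def describe(result: OrderMismatchCastSafety): String = result match {
       case NotCausedByCast => "mismatch not caused by a cast: reject"
       case SafeCast => "safe cast: allow"
       case UnsafeCast(in, out) => s"unsafe cast from $in to $out: reject with types"
     }
   }
   ```

   A side benefit is that matches over a `sealed trait` are checked for exhaustiveness, so the compiler warns if a new outcome is added but not handled at a call site.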



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala:
##########
@@ -564,6 +564,81 @@ case class ListAgg(
     false
   }
 
+  /**
+   * Determines whether the order mismatch between [[child]] and [[orderExpressions]] is due to
+   * a cast, and if so, whether that cast is safe for DISTINCT deduplication.
+   *
+   * When LISTAGG(DISTINCT) is used with a non-string column, a Cast is applied to the
+   * child expression. The DISTINCT rewrite uses GROUP BY on the cast result, which can produce
+   * incorrect deduplication for types where equal values cast to different strings
+   * (e.g., Float/Double where -0.0 and 0.0 are GROUP BY-equal but cast to different strings).
+   *
+   * Safety is determined by both the source type (via [[isCastSafeForDistinct]]) and the target
+   * type's collation (via [[isCastTargetSafeForDistinct]]).
+   *
+   * @return `Some(Right(()))` if the mismatch is due to a safe cast,
+   *         `Some(Left((inputType, castType)))` if the cast is unsafe, carrying the source and
+   *         target types for use in the error message,
+   *         `None` if the mismatch is not due to a cast at all
+   */
+  def orderMismatchCastSafety: Option[Either[(DataType, DataType), Unit]] = {
+    if (orderExpressions.size != 1) return None
+    child match {
+      case Cast(castChild, castType, _, _)
+        if orderExpressions.head.child.semanticEquals(castChild) =>
+          if (isCastSafeForDistinct(castChild.dataType) && isCastTargetSafeForDistinct(castType)) {
+            Some(Right(()))
+          } else {
+            Some(Left((castChild.dataType, castType)))
+          }
+      case _ => None
+    }
+  }
+
+  /**
+   * Checks whether a source type preserves equality semantics after casting to STRING/BINARY.
+   *
+   * A type is safe if equal values always produce equal string representations and different
+   * string representations always imply different values. Types like Float/Double are unsafe
+   * because IEEE 754 negative zero (-0.0) and positive zero (0.0) are equal but produce
+   * different string representations.
+   *
+   * @param dt the source [[DataType]] before casting
+   * @return true if the cast preserves equality semantics for DISTINCT deduplication
+   * @see [[orderMismatchCastSafety]]
+   */
+  private def isCastSafeForDistinct(dt: DataType): Boolean = dt match {
+    case _: IntegerType | LongType | ShortType | ByteType => true
+    case _: DecimalType => true
+    case _: DateType | TimestampType | TimestampNTZType => true
+    case _: TimeType => true

Review Comment:
   Just to double-check: do any of these types store a time zone, and if so, how is it represented in the string after the cast?
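
   To make the concern concrete, a minimal sketch with plain `java.time` (not Spark's cast code; whether any of the types above carries a per-value zone is exactly what is being asked): if the string form embeds a zone or offset, two values that are equal as instants can still render differently, while rendering every value in one fixed zone maps equal instants to equal strings.

   ```scala
   import java.time.{OffsetDateTime, ZoneId, ZoneOffset}
   import java.time.format.DateTimeFormatter

   object TimeZoneCastExample {
     // The same instant (2024-01-01T00:00:00Z) expressed with two different offsets.
     val utc: OffsetDateTime = OffsetDateTime.of(2024, 1, 1, 0, 0, 0, 0, ZoneOffset.UTC)
     val plusTwo: OffsetDateTime = utc.withOffsetSameInstant(ZoneOffset.ofHours(2))

     // Equal as points in time...
     def equalAsInstants: Boolean = utc.isEqual(plusTwo)

     // ...but the default string form keeps the offset, so the strings differ.
     def stringsWithOffsetDiffer: Boolean = utc.toString != plusTwo.toString

     // Rendering every value in one fixed zone makes equal instants produce equal strings.
     private val fixedZone: DateTimeFormatter =
       DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss").withZone(ZoneId.of("UTC"))

     def stringsInFixedZoneMatch: Boolean =
       fixedZone.format(utc.toInstant) == fixedZone.format(plusTwo.toInstant)
   }
   ```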



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala:
##########
@@ -564,6 +564,81 @@ case class ListAgg(
     false
   }
 
+  /**
+   * Determines whether the order mismatch between [[child]] and [[orderExpressions]] is due to
+   * a cast, and if so, whether that cast is safe for DISTINCT deduplication.
+   *
+   * When LISTAGG(DISTINCT) is used with a non-string column, a Cast is applied to the
+   * child expression. The DISTINCT rewrite uses GROUP BY on the cast result, which can produce
+   * incorrect deduplication for types where equal values cast to different strings
+   * (e.g., Float/Double where -0.0 and 0.0 are GROUP BY-equal but cast to different strings).
+   *
+   * Safety is determined by both the source type (via [[isCastSafeForDistinct]]) and the target
+   * type's collation (via [[isCastTargetSafeForDistinct]]).
+   *
+   * @return `Some(Right(()))` if the mismatch is due to a safe cast,
+   *         `Some(Left((inputType, castType)))` if the cast is unsafe, carrying the source and
+   *         target types for use in the error message,
+   *         `None` if the mismatch is not due to a cast at all
+   */
+  def orderMismatchCastSafety: Option[Either[(DataType, DataType), Unit]] = {
+    if (orderExpressions.size != 1) return None
+    child match {
+      case Cast(castChild, castType, _, _)
+        if orderExpressions.head.child.semanticEquals(castChild) =>
+          if (isCastSafeForDistinct(castChild.dataType) && isCastTargetSafeForDistinct(castType)) {
+            Some(Right(()))
+          } else {
+            Some(Left((castChild.dataType, castType)))
+          }
+      case _ => None
+    }
+  }
+
+  /**
+   * Checks whether a source type preserves equality semantics after casting to STRING/BINARY.
+   *
+   * A type is safe if equal values always produce equal string representations and different
+   * string representations always imply different values. Types like Float/Double are unsafe
+   * because IEEE 754 negative zero (-0.0) and positive zero (0.0) are equal but produce
+   * different string representations.
+   *
+   * @param dt the source [[DataType]] before casting
+   * @return true if the cast preserves equality semantics for DISTINCT deduplication
+   * @see [[orderMismatchCastSafety]]
+   */
+  private def isCastSafeForDistinct(dt: DataType): Boolean = dt match {
+    case _: IntegerType | LongType | ShortType | ByteType => true
+    case _: DecimalType => true
+    case _: DateType | TimestampType | TimestampNTZType => true
+    case _: TimeType => true
+    case _: CalendarIntervalType => true
+    case _: YearMonthIntervalType => true
+    case _: DayTimeIntervalType => true
+    case BooleanType => true
+    case BinaryType => true
+    case st: StringType if st.supportsBinaryEquality => true

Review Comment:
   > I'm taking a conservative approach where we can only explicitly cast FROM StringType with UTF8_binary and cast TO StringType with UTF8_binary.
   
   It's not obvious why `st.supportsBinaryEquality` is equivalent to `UTF8_binary`. Even if that holds today, there is no guarantee that other binary-equality collations won't be added in the future. Let's validate the collation explicitly.
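
   A minimal sketch of the difference between the two checks, using a hypothetical stand-in `Collation` class rather than Spark's `StringType`/collation API; `FUTURE_BINARY_LIKE` is a made-up collation. A capability check (`supportsBinaryEquality`) silently admits any future binary-equality collation, while an identity check admits only UTF8_BINARY.

   ```scala
   // Hypothetical stand-in for a collation; not Spark's real class.
   final case class Collation(name: String, supportsBinaryEquality: Boolean)

   object CollationCheckExample {
     val Utf8Binary: Collation = Collation("UTF8_BINARY", supportsBinaryEquality = true)
     val Utf8Lcase: Collation = Collation("UTF8_LCASE", supportsBinaryEquality = false)
     // Made-up future collation that also has binary equality but is not UTF8_BINARY.
     val FutureBinaryLike: Collation =
       Collation("FUTURE_BINARY_LIKE", supportsBinaryEquality = true)

     // Capability check: accepts any collation that reports binary equality.
     def safeByCapability(c: Collation): Boolean = c.supportsBinaryEquality

     // Identity check: accepts exactly the collation the conservative approach targets.
     def safeByIdentity(c: Collation): Boolean = c == Utf8Binary
   }
   ```

   In the PR itself, the identity check would compare the target's collation against UTF8_BINARY directly; which accessor to use for that is left to the author.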



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala:
##########
@@ -564,6 +564,81 @@ case class ListAgg(
     false
   }
 
+  /**
+   * Determines whether the order mismatch between [[child]] and [[orderExpressions]] is due to
+   * a cast, and if so, whether that cast is safe for DISTINCT deduplication.
+   *
+   * When LISTAGG(DISTINCT) is used with a non-string column, a Cast is applied to the
+   * child expression. The DISTINCT rewrite uses GROUP BY on the cast result, which can produce
+   * incorrect deduplication for types where equal values cast to different strings
+   * (e.g., Float/Double where -0.0 and 0.0 are GROUP BY-equal but cast to different strings).
+   *
+   * Safety is determined by both the source type (via [[isCastSafeForDistinct]]) and the target
+   * type's collation (via [[isCastTargetSafeForDistinct]]).
+   *
+   * @return `Some(Right(()))` if the mismatch is due to a safe cast,
+   *         `Some(Left((inputType, castType)))` if the cast is unsafe, carrying the source and
+   *         target types for use in the error message,
+   *         `None` if the mismatch is not due to a cast at all
+   */
+  def orderMismatchCastSafety: Option[Either[(DataType, DataType), Unit]] = {
+    if (orderExpressions.size != 1) return None
+    child match {
+      case Cast(castChild, castType, _, _)
+        if orderExpressions.head.child.semanticEquals(castChild) =>
+          if (isCastSafeForDistinct(castChild.dataType) && isCastTargetSafeForDistinct(castType)) {
+            Some(Right(()))
+          } else {
+            Some(Left((castChild.dataType, castType)))
+          }
+      case _ => None
+    }
+  }
+
+  /**
+   * Checks whether a source type preserves equality semantics after casting to STRING/BINARY.
+   *
+   * A type is safe if equal values always produce equal string representations and different
+   * string representations always imply different values. Types like Float/Double are unsafe
+   * because IEEE 754 negative zero (-0.0) and positive zero (0.0) are equal but produce
+   * different string representations.
+   *
+   * @param dt the source [[DataType]] before casting
+   * @return true if the cast preserves equality semantics for DISTINCT deduplication
+   * @see [[orderMismatchCastSafety]]
+   */
+  private def isCastSafeForDistinct(dt: DataType): Boolean = dt match {
+    case _: IntegerType | LongType | ShortType | ByteType => true
+    case _: DecimalType => true
+    case _: DateType | TimestampType | TimestampNTZType => true
+    case _: TimeType => true
+    case _: CalendarIntervalType => true
+    case _: YearMonthIntervalType => true
+    case _: DayTimeIntervalType => true
+    case BooleanType => true
+    case BinaryType => true
+    case st: StringType if st.supportsBinaryEquality => true
+    case _: DoubleType | FloatType => false
+    case _ => false
+  }
+
+  /**
+   * Checks whether the cast target type preserves equality semantics for DISTINCT deduplication.
+   *
+   * A non-binary-equality collation on the target [[StringType]] can cause different source values
+   * to become equal after casting (e.g., binary values 0x414243 ("ABC") and 0x616263 ("abc") are
+   * different, but equal under UTF8_LCASE collation after casting to string).
+   *
+   * @param dt the target [[DataType]] of the cast
+   * @return true if the target type's equality semantics are safe for DISTINCT deduplication
+   * @see [[orderMismatchCastSafety]]
+   */
+  private def isCastTargetSafeForDistinct(dt: DataType): Boolean = dt match {
+    case st: StringType => st.supportsBinaryEquality

Review Comment:
   > implicit casting doesn't change collation, but we block explicit casting with collation.
   
   I think that, at this point, we cannot tell whether the child's cast was explicit or implicit. Since this check applies to both, is it true that an implicit cast always uses `UTF8_binary` as the default collation? Otherwise, we could accidentally block some implicit casts, such as `int -> string(UTF8_LCASE_COLLATION_ID)`.



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala:
##########
@@ -450,8 +450,17 @@ trait CheckAnalysis extends LookupCatalog with QueryErrorsBase with PlanToString
 
           case agg @ AggregateExpression(listAgg: ListAgg, _, _, _, _)
             if agg.isDistinct && listAgg.needSaveOrderValue =>
-            throw QueryCompilationErrors.functionAndOrderExpressionMismatchError(
-              listAgg.prettyName, listAgg.child, listAgg.orderExpressions)
+              // Allow when the mismatch is only because child was cast
+              val mismatchDueToCast = listAgg.orderExpressions.size == 1 &&
+                (listAgg.child match {
+                  case Cast(castChild, _, _, _) =>
+                    listAgg.orderExpressions.head.child.semanticEquals(castChild)
+                  case _ => false
+                })
+              if (!mismatchDueToCast) {
+                throw QueryCompilationErrors.functionAndOrderExpressionMismatchError(
+                  listAgg.prettyName, listAgg.child, listAgg.orderExpressions)
+              }

Review Comment:
   Actually, I think it would be even better to extract everything, including `listAgg.needSaveOrderValue`, into a method like `validateOrderingForDistinctFunction` that throws when needed:
   ```
   case agg @ AggregateExpression(listAgg: ListAgg, _, _, _, _)
       if agg.isDistinct =>
     listAgg.validateOrderingForDistinctFunction()
   ```
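
   A rough sketch of what such a method could look like, on a heavily simplified, hypothetical model: `Expr`, `Column`, `CastTo`, and `ListAggModel` are stand-ins rather than Spark classes, and `needSaveOrderValue` is simplified to "some order expression differs from the aggregated value". All DISTINCT-ordering checks live in one method, so the `CheckAnalysis` case only has to call it.

   ```scala
   // Hypothetical, heavily simplified model of the pieces involved; not Spark's classes.
   sealed trait Expr
   final case class Column(name: String) extends Expr
   final case class CastTo(child: Expr, targetType: String) extends Expr

   final case class ListAggModel(child: Expr, orderExpressions: Seq[Expr]) {
     // Simplified stand-in for needSaveOrderValue.
     def needSaveOrderValue: Boolean = orderExpressions.exists(_ != child)

     // True when the only difference is that the aggregated value is a cast of the order key.
     private def mismatchDueToCast: Boolean = (child, orderExpressions) match {
       case (CastTo(castChild, _), Seq(order)) => order == castChild
       case _ => false
     }

     // Single entry point for CheckAnalysis: throws if the ordering is not allowed.
     def validateOrderingForDistinctFunction(): Unit =
       if (needSaveOrderValue && !mismatchDueToCast) {
         throw new IllegalArgumentException(
           s"DISTINCT ordering mismatch: child=$child, order=$orderExpressions")
       }
   }
   ```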
   



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala:
##########
@@ -450,8 +450,17 @@ trait CheckAnalysis extends LookupCatalog with QueryErrorsBase with PlanToString
 
           case agg @ AggregateExpression(listAgg: ListAgg, _, _, _, _)
             if agg.isDistinct && listAgg.needSaveOrderValue =>
-            throw QueryCompilationErrors.functionAndOrderExpressionMismatchError(
-              listAgg.prettyName, listAgg.child, listAgg.orderExpressions)
+              // Allow when the mismatch is only because child was cast
+              val mismatchDueToCast = listAgg.orderExpressions.size == 1 &&
+                (listAgg.child match {
+                  case Cast(castChild, _, _, _) =>
+                    listAgg.orderExpressions.head.child.semanticEquals(castChild)
+                  case _ => false
+                })
+              if (!mismatchDueToCast) {
+                throw QueryCompilationErrors.functionAndOrderExpressionMismatchError(
+                  listAgg.prettyName, listAgg.child, listAgg.orderExpressions)
+              }

Review Comment:
   Moreover, the logic in this PR is not trivial, so it makes sense to guard these changes with a feature flag. It would be convenient to put the `if (flag)` branching inside a ListAgg method.
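
   One possible shape for this, as a hedged sketch: the flag `allowCastMismatchForDistinct` and the `Conf` stand-in are hypothetical, not an existing Spark config entry, and the method returns the error instead of throwing so the branches are easy to see. With the flag off, it falls back to the legacy behavior of rejecting every mismatch.

   ```scala
   object FlaggedValidationExample {
     // Hypothetical stand-in for a SQLConf entry; this flag does not exist in Spark.
     final case class Conf(allowCastMismatchForDistinct: Boolean)

     // Returns None if the query is accepted, or Some(error) if it is rejected.
     def validate(
         conf: Conf,
         needSaveOrderValue: Boolean,
         mismatchIsSafeCast: Boolean): Option[String] = {
       if (!needSaveOrderValue) {
         None // ordering matches the aggregated value: nothing to check
       } else if (conf.allowCastMismatchForDistinct && mismatchIsSafeCast) {
         None // new behavior, reachable only when the flag is on
       } else {
         Some("functionAndOrderExpressionMismatchError") // legacy behavior
       }
     }
   }
   ```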



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

