helioshe4 commented on code in PR #54297:
URL: https://github.com/apache/spark/pull/54297#discussion_r2819796175
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala:
##########
@@ -564,6 +564,81 @@ case class ListAgg(
false
}
+ /**
+ * Determines whether the order mismatch between [[child]] and [[orderExpressions]] is due to
+ * a cast, and if so, whether that cast is safe for DISTINCT deduplication.
+ *
+ * When LISTAGG(DISTINCT) is used with a non-string column, a Cast is applied to the
+ * child expression. The DISTINCT rewrite uses GROUP BY on the cast result, which can produce
+ * incorrect deduplication for types where equal values cast to different strings
+ * (e.g., Float/Double where -0.0 and 0.0 are GROUP BY-equal but cast to different strings).
+ *
+ * Safety is determined by both the source type (via [[isCastSafeForDistinct]]) and the target
+ * type's collation (via [[isCastTargetSafeForDistinct]]).
+ *
+ * @return `Some(Right(()))` if the mismatch is due to a safe cast,
+ * `Some(Left((inputType, castType)))` if the cast is unsafe, carrying the source and
+ * target types for use in the error message,
+ * `None` if the mismatch is not due to a cast at all
+ */
+ def orderMismatchCastSafety: Option[Either[(DataType, DataType), Unit]] = {
+ if (orderExpressions.size != 1) return None
+ child match {
+ case Cast(castChild, castType, _, _)
+ if orderExpressions.head.child.semanticEquals(castChild) =>
+ if (isCastSafeForDistinct(castChild.dataType) && isCastTargetSafeForDistinct(castType)) {
+ Some(Right(()))
+ } else {
+ Some(Left((castChild.dataType, castType)))
+ }
+ case _ => None
+ }
+ }
+
+ /**
+ * Checks whether a source type preserves equality semantics after casting to STRING/BINARY.
+ *
+ * A type is safe if equal values always produce equal string representations and different
+ * string representations always imply different values. Types like Float/Double are unsafe
+ * because IEEE 754 negative zero (-0.0) and positive zero (0.0) are equal but produce
+ * different string representations.
+ *
+ * @param dt the source [[DataType]] before casting
+ * @return true if the cast preserves equality semantics for DISTINCT deduplication
+ * @see [[orderMismatchCastSafety]]
+ */
+ private def isCastSafeForDistinct(dt: DataType): Boolean = dt match {
+ case _: IntegerType | LongType | ShortType | ByteType => true
+ case _: DecimalType => true
+ case _: DateType | TimestampType | TimestampNTZType => true
+ case _: TimeType => true
+ case _: CalendarIntervalType => true
+ case _: YearMonthIntervalType => true
+ case _: DayTimeIntervalType => true
+ case BooleanType => true
+ case BinaryType => true
+ case st: StringType if st.supportsBinaryEquality => true
+ case _: DoubleType | FloatType => false
+ case _ => false
+ }
+
+ /**
+ * Checks whether the cast target type preserves equality semantics for DISTINCT deduplication.
+ *
+ * A non-binary-equality collation on the target [[StringType]] can cause different source values
+ * to become equal after casting (e.g., binary values 0x414243 ("ABC") and 0x616263 ("abc") are
+ * different, but equal under UTF8_LCASE collation after casting to string).
+ *
+ * @param dt the target [[DataType]] of the cast
+ * @return true if the target type's equality semantics are safe for DISTINCT deduplication
+ * @see [[orderMismatchCastSafety]]
+ */
+ private def isCastTargetSafeForDistinct(dt: DataType): Boolean = dt match {
+ case st: StringType => st.supportsBinaryEquality
Review Comment:
Yes, `UTF8_BINARY` is the default (and only possible) collation for implicit
casts. See the `implicitCast` function in `TypeCoercion.scala`:
https://github.com/apache/spark/blob/0ab410731e40ecb83a1c51eedfa905a122bd2f45/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala#L232-L236
(Note that the case on L234 is matched rather than the one on L233, since
ListAgg declares its input type as `StringTypeWithCollation`, which inherits
from `AbstractStringType`.)
`st.defaultConcreteType` is then a `StringType` whose collation is
`UTF8_BINARY_COLLATION_ID`:
https://github.com/apache/spark/blob/0ab410731e40ecb83a1c51eedfa905a122bd2f45/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala#L112-L113
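
For context, here is a rough plain-Scala sketch of the two hazards the checks
above guard against. It does not use any Spark APIs: the object name
`DistinctCastSketch` is hypothetical, and `equalsIgnoreCase` is only an assumed
stand-in for a non-binary collation such as UTF8_LCASE, not how Spark
implements collations.

```scala
// Plain-Scala illustration only; no Spark APIs are used here.
object DistinctCastSketch {
  // Hazard 1 (source type): values that compare equal but print differently.
  val negZero: Double = -0.0
  val posZero: Double = 0.0
  val zerosEqual: Boolean = negZero == posZero // primitive comparison
  val zeroStringsEqual: Boolean = negZero.toString == posZero.toString

  // Hazard 2 (target collation): different values that a case-insensitive
  // comparison treats as equal. equalsIgnoreCase is an assumed stand-in
  // for a non-binary collation such as UTF8_LCASE.
  val upper: String = "ABC"
  val lower: String = "abc"
  val binaryEqual: Boolean = upper == lower
  val lcaseEqual: Boolean = upper.equalsIgnoreCase(lower)

  def main(args: Array[String]): Unit = {
    println(s"-0.0 == 0.0: $zerosEqual, same string: $zeroStringsEqual")
    println(s"binary equal: $binaryEqual, case-insensitive equal: $lcaseEqual")
  }
}
```

Since an implicit cast always targets the binary-equality default, only the
first hazard should be reachable through the implicit-cast path.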
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]