mikhailnik-db commented on code in PR #54297:
URL: https://github.com/apache/spark/pull/54297#discussion_r2816415672
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala:
##########
@@ -564,6 +564,81 @@ case class ListAgg(
false
}
+ /**
+ * Determines whether the order mismatch between [[child]] and [[orderExpressions]] is due to
+ * a cast, and if so, whether that cast is safe for DISTINCT deduplication.
+ *
+ * When LISTAGG(DISTINCT) is used with a non-string column, a Cast is applied to the
+ * child expression. The DISTINCT rewrite uses GROUP BY on the cast result, which can produce
+ * incorrect deduplication for types where equal values cast to different strings
+ * (e.g., Float/Double where -0.0 and 0.0 are GROUP BY-equal but cast to different strings).
+ *
+ * Safety is determined by both the source type (via [[isCastSafeForDistinct]]) and the target
+ * type's collation (via [[isCastTargetSafeForDistinct]]).
+ *
+ * @return `Some(Right(()))` if the mismatch is due to a safe cast,
+ * `Some(Left((inputType, castType)))` if the cast is unsafe, carrying the source and
+ * target types for use in the error message,
+ * `None` if the mismatch is not due to a cast at all
+ */
+ def orderMismatchCastSafety: Option[Either[(DataType, DataType), Unit]] = {
Review Comment:
`Option[Either[(DataType, DataType), Unit]]` is a somewhat awkward type.
In Scala it would be more idiomatic to create a dedicated `sealed trait` with three
subtypes: two `object`s and one `case class`.
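A minimal sketch of what the suggested ADT could look like. The names `CastMismatchSafety`, `NotDueToCast`, `SafeCast`, and `UnsafeCast` are hypothetical, and a stand-in `DataType` alias replaces Spark's `org.apache.spark.sql.types.DataType` so the snippet compiles on its own:

```scala
object CastSafetyAdt {
  // Stand-in for Spark's DataType so the sketch is self-contained (assumption).
  type DataType = String

  // Hypothetical replacement for Option[Either[(DataType, DataType), Unit]].
  sealed trait CastMismatchSafety
  // The mismatch is not caused by a Cast at all (was: None).
  case object NotDueToCast extends CastMismatchSafety
  // The mismatch is caused by a Cast that is safe for DISTINCT (was: Some(Right(()))).
  case object SafeCast extends CastMismatchSafety
  // The mismatch is caused by an unsafe Cast; keep both types for the error message
  // (was: Some(Left((inputType, castType)))).
  final case class UnsafeCast(inputType: DataType, castType: DataType)
    extends CastMismatchSafety

  // Call sites can match exhaustively instead of unpacking Option and Either.
  def describe(s: CastMismatchSafety): String = s match {
    case NotDueToCast         => "not a cast"
    case SafeCast             => "safe cast"
    case UnsafeCast(from, to) => s"unsafe cast from $from to $to"
  }
}
```

Because the trait is sealed, the compiler warns if a caller forgets one of the three cases, which the `Option[Either[...]]` encoding cannot express as clearly.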
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala:
##########
@@ -564,6 +564,81 @@ case class ListAgg(
+ def orderMismatchCastSafety: Option[Either[(DataType, DataType), Unit]] = {
+ if (orderExpressions.size != 1) return None
+ child match {
+ case Cast(castChild, castType, _, _)
+ if orderExpressions.head.child.semanticEquals(castChild) =>
+ if (isCastSafeForDistinct(castChild.dataType) && isCastTargetSafeForDistinct(castType)) {
+ Some(Right(()))
+ } else {
+ Some(Left((castChild.dataType, castType)))
+ }
+ case _ => None
+ }
+ }
+
+ /**
+ * Checks whether a source type preserves equality semantics after casting to STRING/BINARY.
+ *
+ * A type is safe if equal values always produce equal string representations and different
+ * string representations always imply different values. Types like Float/Double are unsafe
+ * because IEEE 754 negative zero (-0.0) and positive zero (0.0) are equal but produce
+ * different string representations.
+ *
+ * @param dt the source [[DataType]] before casting
+ * @return true if the cast preserves equality semantics for DISTINCT deduplication
+ * @see [[orderMismatchCastSafety]]
+ */
+ private def isCastSafeForDistinct(dt: DataType): Boolean = dt match {
+ case _: IntegerType | LongType | ShortType | ByteType => true
+ case _: DecimalType => true
+ case _: DateType | TimestampType | TimestampNTZType => true
+ case _: TimeType => true
Review Comment:
Just to double check: is there a timezone stored in any types, and if yes,
how is it represented in a string after cast?
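A partial answer, hedged: as far as we know, `TimestampType` stores an instant (microseconds since the epoch) with no zone attached, and casting it to STRING renders the local wall-clock time in the session time zone, whereas `DateType` and `TimestampNTZType` are rendered without zone conversion. Under a DST-observing zone that rendering is not injective. The sketch uses plain `java.time` as an approximation of Spark's own formatter, with `America/Los_Angeles` as an assumed session zone:

```scala
import java.time.{Instant, ZoneId}
import java.time.format.DateTimeFormatter

object TimestampCastProbe {
  // Assumed session time zone; any DST-observing zone would do.
  val sessionZone: ZoneId = ZoneId.of("America/Los_Angeles")

  // Approximates TIMESTAMP -> STRING: render the instant as local wall-clock time.
  private val formatter: DateTimeFormatter =
    DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss").withZone(sessionZone)

  def render(instant: Instant): String = formatter.format(instant)

  // On 2021-11-07 US Pacific time falls back from PDT (UTC-7) to PST (UTC-8),
  // so the local time 01:30 occurs twice: at 08:30Z and again at 09:30Z.
  val beforeFallBack: Instant = Instant.parse("2021-11-07T08:30:00Z")
  val afterFallBack: Instant = Instant.parse("2021-11-07T09:30:00Z")

  def main(args: Array[String]): Unit = {
    println(render(beforeFallBack))          // 2021-11-07 01:30:00
    println(render(afterFallBack))           // 2021-11-07 01:30:00
    println(beforeFallBack == afterFallBack) // false
  }
}
```

Two distinct TIMESTAMP values collapse to one string here. Whether that affects the DISTINCT rewrite depends on whether grouping happens on the source value or on the cast string, which would be worth confirming in the PR.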
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala:
##########
@@ -564,6 +564,81 @@ case class ListAgg(
+ private def isCastSafeForDistinct(dt: DataType): Boolean = dt match {
+ case _: IntegerType | LongType | ShortType | ByteType => true
+ case _: DecimalType => true
+ case _: DateType | TimestampType | TimestampNTZType => true
+ case _: TimeType => true
+ case _: CalendarIntervalType => true
+ case _: YearMonthIntervalType => true
+ case _: DayTimeIntervalType => true
+ case BooleanType => true
+ case BinaryType => true
+ case st: StringType if st.supportsBinaryEquality => true
Review Comment:
> I'm taking a conservative approach where we can only explicitly cast FROM StringType with UTF8_binary and cast TO StringType with UTF8_binary.

It's not obvious why `st.supportsBinaryEquality` is equivalent to
`UTF8_binary`. Even if it's true now, there is no guarantee that other
binary-compatible collations won't be added in the future. Let's validate the
collation explicitly.
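A minimal sketch of the difference between trusting the capability flag and checking the collation explicitly. It models `StringType` with a hypothetical `StringTypeModel`, and the collation IDs are made up; in Spark the explicit comparison would presumably be against `CollationFactory.UTF8_BINARY_COLLATION_ID` (assumed constant name, worth verifying against the codebase):

```scala
object ExplicitCollationCheck {
  // Hypothetical collation IDs; the real values live in Spark's CollationFactory.
  val Utf8BinaryId = 0
  val Utf8LcaseId = 1
  val FutureBinaryCompatibleId = 42 // an imagined future collation with binary equality

  // Simplified stand-in for Spark's StringType.
  final case class StringTypeModel(collationId: Int, supportsBinaryEquality: Boolean)

  // Approach in the PR: trust the capability flag.
  def safeByCapability(st: StringTypeModel): Boolean = st.supportsBinaryEquality

  // Suggested approach: accept only the one collation that has actually been reasoned about.
  def safeByExplicitId(st: StringTypeModel): Boolean = st.collationId == Utf8BinaryId
}
```

If a new binary-equality collation is ever added, the capability check silently starts accepting it, while the explicit check keeps rejecting it until someone reviews it.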
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala:
##########
@@ -564,6 +564,81 @@ case class ListAgg(
+ private def isCastSafeForDistinct(dt: DataType): Boolean = dt match {
+ case _: IntegerType | LongType | ShortType | ByteType => true
+ case _: DecimalType => true
+ case _: DateType | TimestampType | TimestampNTZType => true
+ case _: TimeType => true
+ case _: CalendarIntervalType => true
+ case _: YearMonthIntervalType => true
+ case _: DayTimeIntervalType => true
+ case BooleanType => true
+ case BinaryType => true
+ case st: StringType if st.supportsBinaryEquality => true
+ case _: DoubleType | FloatType => false
+ case _ => false
+ }
+
+ /**
+ * Checks whether the cast target type preserves equality semantics for DISTINCT deduplication.
+ *
+ * A non-binary-equality collation on the target [[StringType]] can cause different source values
+ * to become equal after casting (e.g., binary values 0x414243 ("ABC") and 0x616263 ("abc") are
+ * different, but equal under UTF8_LCASE collation after casting to string).
+ *
+ * @param dt the target [[DataType]] of the cast
+ * @return true if the target type's equality semantics are safe for DISTINCT deduplication
+ * @see [[orderMismatchCastSafety]]
+ */
+ private def isCastTargetSafeForDistinct(dt: DataType): Boolean = dt match {
+ case st: StringType => st.supportsBinaryEquality
Review Comment:
> implicit casting doesn't change collation, but we block explicit casting with collation.

I don't think we can tell at this point whether the child's cast was explicit
or implicit. So, if we apply this check to both, is it true that an implicit
cast always uses `UTF8_binary` as the default collation? Otherwise, we could
accidentally block some implicit casts, such as
`int -> string(UTF8_LCASE_COLLATION_ID)`.
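One way to see why blocking `int -> string(UTF8_LCASE)` may be overly conservative: the decimal rendering of an integer contains only digits and an optional `-`, none of which have case, so a case-insensitive comparison cannot merge two distinct integers. The sketch approximates UTF8_LCASE equality with `toLowerCase(Locale.ROOT)` (an assumption; Spark's implementation may differ in details) and contrasts it with the binary example from the doc comment:

```scala
import java.nio.charset.StandardCharsets
import java.util.Locale

object LcaseCollisionProbe {
  // Rough model of UTF8_LCASE equality: compare lowercase forms (assumption).
  def lcaseEqual(a: String, b: String): Boolean =
    a.toLowerCase(Locale.ROOT) == b.toLowerCase(Locale.ROOT)

  // Decode raw bytes as UTF-8, mimicking a BINARY -> STRING cast.
  private def utf8(bytes: Int*): String =
    new String(bytes.map(_.toByte).toArray, StandardCharsets.UTF_8)

  // Binary source from the doc comment: 0x414243 ("ABC") and 0x616263 ("abc")
  // are different values, yet their strings are equal under the LCASE model.
  val binaryCollides: Boolean = lcaseEqual(utf8(0x41, 0x42, 0x43), utf8(0x61, 0x62, 0x63))

  // Integer source: true if lowercasing never maps two distinct ints to the same string.
  def intsStayDistinct(values: Seq[Int]): Boolean = {
    val strs = values.distinct.map(_.toString)
    strs.map(_.toLowerCase(Locale.ROOT)).distinct.size == strs.size
  }
}
```

This argument only covers sources whose string form has no cased characters; it says nothing about BINARY or STRING sources, where the collision in the doc comment is real.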
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala:
##########
@@ -450,8 +450,17 @@ trait CheckAnalysis extends LookupCatalog with QueryErrorsBase with PlanToString
case agg @ AggregateExpression(listAgg: ListAgg, _, _, _, _)
if agg.isDistinct && listAgg.needSaveOrderValue =>
- throw QueryCompilationErrors.functionAndOrderExpressionMismatchError(
- listAgg.prettyName, listAgg.child, listAgg.orderExpressions)
+ // Allow when the mismatch is only because child was cast
+ val mismatchDueToCast = listAgg.orderExpressions.size == 1 &&
+ (listAgg.child match {
+ case Cast(castChild, _, _, _) =>
+ listAgg.orderExpressions.head.child.semanticEquals(castChild)
+ case _ => false
+ })
+ if (!mismatchDueToCast) {
+ throw QueryCompilationErrors.functionAndOrderExpressionMismatchError(
+ listAgg.prettyName, listAgg.child, listAgg.orderExpressions)
+ }
Review Comment:
Actually, I think it'd be even better to extract everything, including
`listAgg.needSaveOrderValue`, into a method like
`validateOrderingForDistinctFunction` that throws when needed.
```
case agg @ AggregateExpression(listAgg: ListAgg, _, _, _, _)
    if agg.isDistinct =>
  listAgg.validateOrderingForDistinctFunction()
```
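A self-contained model of what `validateOrderingForDistinctFunction` could encapsulate. `ListAggModel`, its two Boolean fields, and `MismatchError` are hypothetical stand-ins for `ListAgg`, `needSaveOrderValue`, the cast-safety analysis, and `QueryCompilationErrors.functionAndOrderExpressionMismatchError`; the point is only that `CheckAnalysis` no longer needs to know about casts:

```scala
object ValidateOrderingSketch {
  // Hypothetical stand-in for the error raised by QueryCompilationErrors.
  final class MismatchError(msg: String) extends IllegalArgumentException(msg)

  // Hypothetical, heavily simplified stand-in for ListAgg.
  final case class ListAggModel(
      needSaveOrderValue: Boolean, // child and ORDER BY expressions differ
      mismatchIsSafeCast: Boolean  // result of the cast-safety analysis
  ) {
    // All DISTINCT-ordering rules live here; CheckAnalysis only calls this.
    def validateOrderingForDistinctFunction(): Unit = {
      if (needSaveOrderValue && !mismatchIsSafeCast) {
        throw new MismatchError("function and ORDER BY expressions mismatch")
      }
    }
  }
}
```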
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala:
##########
@@ -450,8 +450,17 @@ trait CheckAnalysis extends LookupCatalog with QueryErrorsBase with PlanToString
+ if (!mismatchDueToCast) {
+ throw QueryCompilationErrors.functionAndOrderExpressionMismatchError(
+ listAgg.prettyName, listAgg.child, listAgg.orderExpressions)
+ }
Review Comment:
Moreover, the logic in this PR is not trivial, so it makes sense to guard its
changes with a feature flag. It would be convenient to put the `if (flag)`
branching inside a `ListAgg` method.
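A minimal sketch of keeping the flag branch inside the `ListAgg`-side method, so the `CheckAnalysis` call site is identical whichever way the flag is set. The flag is passed as a plain `Boolean` named `castAwareDistinctEnabled`; in Spark it would come from a new `SQLConf` entry whose name is not decided (hypothetical), and `MismatchError` stands in for the real compilation error:

```scala
object FlagGatedValidation {
  // Hypothetical stand-in for QueryCompilationErrors.functionAndOrderExpressionMismatchError.
  final class MismatchError(msg: String) extends IllegalArgumentException(msg)

  /**
   * @param castAwareDistinctEnabled hypothetical feature flag
   * @param needSaveOrderValue       child and ORDER BY expressions differ
   * @param mismatchIsSafeCast       new cast-safety analysis, evaluated lazily
   */
  def validate(
      castAwareDistinctEnabled: Boolean,
      needSaveOrderValue: Boolean,
      mismatchIsSafeCast: => Boolean): Unit = {
    if (needSaveOrderValue) {
      // Flag off: old behavior, any mismatch is an error.
      // Flag on: tolerate mismatches caused by a cast that is safe for DISTINCT.
      val allowed = castAwareDistinctEnabled && mismatchIsSafeCast
      if (!allowed) throw new MismatchError("function and ORDER BY expressions mismatch")
    }
  }
}
```

Passing the analysis by name means that, with the flag off, the new code path is never even evaluated, which keeps the old behavior fully isolated.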
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]