comphead commented on code in PR #4015:
URL: https://github.com/apache/datafusion-comet/pull/4015#discussion_r3261939377
##########
spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala:
##########
@@ -778,4 +794,127 @@ case class CometExecRule(session: SparkSession)
}
}
+ /**
+ * Walk the plan to find Final-mode aggregates that cannot be converted to Comet. For each such
+ * Final, if the aggregate functions have incompatible intermediate buffer formats, tag the
+ * corresponding Partial-mode aggregate so it will also be skipped during conversion.
+ *
+ * This prevents the crash described in issue #1389 where a Comet Partial produces intermediate
+ * data in a format that the Spark Final cannot interpret.
+ */
+ private def tagUnsafePartialAggregates(plan: SparkPlan): Unit = {
+ plan.foreach {
+ case agg: BaseAggregateExec =>
+ // Only consider single-mode Final aggregates. Multi-mode Finals come from Spark's
+ // distinct-aggregate rewrite, where the Comet partial (if any) feeds into a Spark
+ // PartialMerge rather than directly into a Final, which is a different code path
+ // than the Comet-Partial → Spark-Final crash scenario from issue #1389.
+ val modes = agg.aggregateExpressions.map(_.mode).distinct
+ if (modes == Seq(Final) &&
+ !QueryPlanSerde.allAggsSupportMixedExecution(agg.aggregateExpressions) &&
+ !canAggregateBeConverted(agg, Final)) {
+ findPartialAggInPlan(agg.child).foreach { partial =>
+ // Only tag if the Partial would otherwise have been converted. If the Partial
+ // itself cannot be converted (e.g. the aggregate function is incompatible for the
+ // input type), there is no buffer-format mismatch to guard against, and tagging
+ // would mask the natural, more specific fallback reason.
+ if (canAggregateBeConverted(partial, Partial)) {
+ partial.setTagValue(
+ CometExecRule.COMET_UNSAFE_PARTIAL,
+ "Partial aggregate disabled: corresponding final aggregate " +
+ "cannot be converted to Comet and intermediate buffer formats are incompatible")
+ }
+ }
+ }
+ case _ =>
+ }
+ }
+
+ /**
+ * Conservative check for whether an aggregate could be converted to Comet. Checks operator
+ * enablement, grouping expressions, aggregate expressions, and result expressions.
+ * Intentionally skips the sparkFinalMode / child-native checks since those depend on
+ * transformation state.
+ *
+ * WARNING: this intentionally mirrors the predicate checks in `CometBaseAggregate.doConvert`
+ * (operators.scala). Any change to the convertibility rules there must be reflected here or
+ * this tagging pass will drift and either crash (missed tag) or over-disable (spurious tag). A
+ * shared predicate helper would be preferable.
+ */
+ private def canAggregateBeConverted(
+ agg: BaseAggregateExec,
+ expectedMode: AggregateMode): Boolean = {
+ val handler = allExecs.get(agg.getClass)
+ if (handler.isEmpty) return false
+ val serde = handler.get.asInstanceOf[CometOperatorSerde[SparkPlan]]
+ if (!isOperatorEnabled(serde, agg.asInstanceOf[SparkPlan])) return false
+
+ // ObjectHashAggregate has an extra shuffle-enabled guard in its convert method
+ agg match {
+ case _: ObjectHashAggregateExec if !isCometShuffleEnabled(agg.conf) => return false
+ case _ =>
+ }
+
+ val aggregateExpressions = agg.aggregateExpressions
+ val groupingExpressions = agg.groupingExpressions
+
+ if (groupingExpressions.isEmpty && aggregateExpressions.isEmpty) return false
+
+ if (groupingExpressions.exists(e => QueryPlanSerde.containsMapType(e.dataType))) return false
+
+ if (!groupingExpressions.forall(e =>
+ QueryPlanSerde.exprToProto(e, agg.child.output).isDefined)) {
+ return false
+ }
+
+ if (aggregateExpressions.isEmpty) {
+ // Result expressions always checked when there are no aggregate expressions
+ val attributes =
+ groupingExpressions.map(_.toAttribute) ++ agg.aggregateAttributes
+ return agg.resultExpressions.forall(e =>
+ QueryPlanSerde.exprToProto(e, attributes).isDefined)
+ }
+
+ val modes = aggregateExpressions.map(_.mode).distinct
+ if (modes.size != 1 || modes.head != expectedMode) return false
+
+ // In Final mode, exprToProto resolves against the child's output; in Partial/non-Final mode
+ // it must bind to input attributes. This mirrors the `binding` calculation in
+ // `CometBaseAggregate.doConvert`.
+ val binding = expectedMode != Final
+ if (!aggregateExpressions.forall(e =>
+ QueryPlanSerde.aggExprToProto(e, agg.child.output, binding, agg.conf).isDefined)) {
+ return false
+ }
+
+ // doConvert only checks resultExpressions in Final mode when aggregate expressions exist
+ // (Partial emits the buffer directly). Mirror that here to avoid false negatives.
+ if (expectedMode == Final) {
+ val attributes =
+ groupingExpressions.map(_.toAttribute) ++ agg.aggregateAttributes
+ agg.resultExpressions.forall(e => QueryPlanSerde.exprToProto(e, attributes).isDefined)
+ } else {
+ true
+ }
+ }
+
+ /**
+ * Look for a Partial-mode aggregate that feeds directly into the given plan (the child of a
+ * Final). Walks through exchanges and AQE stages only, stopping at anything else including
+ * other aggregate stages. This avoids tagging unrelated Partials found deeper in the plan (e.g.
+ * the non-distinct Partial in a distinct-aggregate rewrite, which is separated from the Final
+ * by intermediate PartialMerge stages). Requires `aggregateExpressions.nonEmpty` so that
+ * group-by-only dedup stages are not mistaken for the partial we want to tag.
+ */
+ private def findPartialAggInPlan(plan: SparkPlan): Option[BaseAggregateExec] = plan match {
+ case agg: BaseAggregateExec
+ if agg.aggregateExpressions.nonEmpty &&
+ agg.aggregateExpressions.forall(e => e.mode == Partial) =>
+ Some(agg)
+ case a: AQEShuffleReadExec => findPartialAggInPlan(a.child)
+ case s: ShuffleQueryStageExec => findPartialAggInPlan(s.plan)
+ case e: ShuffleExchangeExec => findPartialAggInPlan(e.child)
+ case _ => None
Review Comment:
Does it make sense to log the unsupported exec here (in the `case _ => None` fallthrough)?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]