github-actions[bot] commented on code in PR #63690:
URL: https://github.com/apache/doris/pull/63690#discussion_r3435196824


##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/eageraggregation/EagerAggRewriter.java:
##########
@@ -70,77 +84,204 @@
  *         ->T2(D)
  */
 public class EagerAggRewriter extends DefaultPlanRewriter<PushDownAggContext> {
+    public static final int BIG_JOIN_BUILD_SIZE = 400_000;
     private static final double LOWER_AGGREGATE_EFFECT_COEFFICIENT = 10000;
     private static final double LOW_AGGREGATE_EFFECT_COEFFICIENT = 1000;
     private static final double MEDIUM_AGGREGATE_EFFECT_COEFFICIENT = 100;
+    private static final String JOIN_CNT = "joinCnt";
     private final StatsDerive derive = new StatsDerive(false);
 
     @Override
     public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> 
join, PushDownAggContext context) {
-        boolean toLeft = false;
-        boolean toRight = false;
-        boolean pushHere = false;
-        if (join.getJoinType().isAsofJoin()) {
-            // do nothing for asof join
-            return join;
+        Pair<Boolean, Boolean> pushSide = decideJoinPushSide(join, context);
+        boolean toLeft = pushSide.first;
+        boolean toRight = pushSide.second;
+        if (!toLeft && !toRight) {
+            if (SessionVariable.isEagerAggregationOnJoin()) {
+                return genAggregate(join, context);
+            } else {
+                return join;
+            }
         }
-        if (context.getAggFunctions().isEmpty()) {
-            // select t1.v from t1 join t2 on t1.id = t2.id group by t1.v, t2.v
-            // if no agg function, try to push agg to the child which contains 
all group keys
-            // TODO: consider t1.rows/(t1.id, t1.v).ndv and t2.rows/(t2.id, 
t2.v).ndv to determine push target
-            if 
(join.left().getOutputSet().containsAll(context.getGroupKeys())) {
-                toLeft = true;
-            } else if 
(join.right().getOutputSet().containsAll(context.getGroupKeys())) {
-                toRight = true;
+
+        // construct left and right group by keys
+        List<SlotReference> leftChildGroupByKeys = new ArrayList<>();
+        List<SlotReference> rightChildGroupByKeys = new ArrayList<>();
+        if (toLeft) {
+            fillGroupByKeys(join, join.left(), context, leftChildGroupByKeys);
+        }
+        if (toRight) {
+            fillGroupByKeys(join, join.right(), context, 
rightChildGroupByKeys);
+        }
+        // construct left and right aggFuncs and aliasMap
+        List<AggregateFunction> leftFuncs = new ArrayList<>();
+        List<AggregateFunction> rightFuncs = new ArrayList<>();
+        Map<AggregateFunction, Alias> leftAliasMap = new IdentityHashMap<>();
+        Map<AggregateFunction, Alias> rightAliasMap = new IdentityHashMap<>();
+        for (AggregateFunction f : context.getAggFunctions()) {
+            Set<Slot> inputs = f.getInputSlots();
+            Alias a = context.getAliasMap().get(f);
+            if (inputs.isEmpty()) {
+                if (join.getJoinType().isRightSemiOrAntiJoin()) {
+                    rightFuncs.add(f);
+                    rightAliasMap.put(f, a);
+                } else {
+                    leftFuncs.add(f);
+                    leftAliasMap.put(f, a);
+                }
+                continue;
+            }
+            if (join.left().getOutputSet().containsAll(inputs)) {
+                leftFuncs.add(f);
+                leftAliasMap.put(f, a);
+            } else if (join.right().getOutputSet().containsAll(inputs)) {
+                rightFuncs.add(f);
+                rightAliasMap.put(f, a);
             } else {
-                pushHere = true;
+                return join;
             }
+        }
+
+        boolean passThroughBigJoin = isPassThroughBigJoin(join, context);
+        boolean leftNeedOutputCount = needOutputCountForJoinChild(join, 
toLeft, toRight,
+                context.needOutputCount(), rightFuncs);
+        boolean rightNeedOutputCount = needOutputCountForJoinChild(join, 
toRight, toLeft,
+                context.needOutputCount(), leftFuncs);
+        Optional<PushDownAggContext> leftChildContext = toLeft ? 
Optional.of(context.forOneBranch(leftFuncs,
+                leftAliasMap, leftChildGroupByKeys, passThroughBigJoin, 
leftNeedOutputCount)) : Optional.empty();
+        Optional<PushDownAggContext> rightChildContext = toRight ? 
Optional.of(context.forOneBranch(rightFuncs,
+                rightAliasMap, rightChildGroupByKeys, passThroughBigJoin, 
rightNeedOutputCount)) : Optional.empty();
+
+        Plan newLeft = join.left();
+        Plan newRight = join.right();
+        if (leftChildContext.isPresent() && 
!leftChildContext.get().noGroupKeyAndNoAggFunc()) {
+            newLeft = join.left().accept(this, leftChildContext.get());
+        }
+        if (rightChildContext.isPresent() && 
!rightChildContext.get().noGroupKeyAndNoAggFunc()) {
+            newRight = join.right().accept(this, rightChildContext.get());
+        }
+
+        if (newLeft == join.left() && newRight == join.right()) {
+            context.getBilateralState().registerNoCountSlot(join);
+            return join;
+        }
+        Optional<Slot> leftChildCountSlot = 
context.getBilateralState().getCountSlot(newLeft);
+        Optional<Slot> rightChildCountSlot = 
context.getBilateralState().getCountSlot(newRight);
+        LogicalJoin<? extends Plan, ? extends Plan> newJoin = (LogicalJoin<? 
extends Plan, ? extends Plan>)
+                join.withChildren(newLeft, newRight);
+
+        if (leftChildCountSlot.isPresent() || rightChildCountSlot.isPresent()) 
{
+            return buildCanonicalJoinProject(newJoin, context, 
leftChildContext, rightChildContext,
+                    leftChildCountSlot, rightChildCountSlot);
+        }
+        context.getBilateralState().registerNoCountSlot(newJoin);
+        return newJoin;
+    }
+
+    private Pair<Boolean, Boolean> decideJoinPushSide(
+            LogicalJoin<? extends Plan, ? extends Plan> join, 
PushDownAggContext context) {
+        if (join.getJoinType().isAsofJoin() || join.isMarkJoin()) {
+            // do nothing for asof join and mark join
+            return Pair.of(false, false);

Review Comment:
   This check only guards against volatile join predicates. Volatile aggregate 
arguments can still enter the pushdown path, because `getInputSlots()` returns 
an empty set for `random()` and for slot-free volatile UDFs, so such functions 
fall into the empty-input branch. Reduced plan tree:
   
   ```text
   Aggregate(sum(random()) AS s, group by l.k)
     Join(l.k = r.k)
       Scan l
       Scan r
   ```
   
   `decideJoinPushSide()` classifies `sum(random())` as a left-side aggregate 
via the empty-input branch, and `genAggregate()` can then build 
`Aggregate(sum(random()) by l.k)` under `l`. After the rewrite, one random value 
per left group is replicated across all matching `r` rows. The original plan, by 
contrast, evaluates `random()` once per joined row, so the result changes. The 
same issue can appear after `createContextFromProject()` substitutes `sum(x)` 
where `x := random()`. Please bail out whenever any candidate aggregate function 
contains a volatile expression, including after project/union substitution.
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to