sohardforaname commented on code in PR #17968:
URL: https://github.com/apache/doris/pull/17968#discussion_r1147082277


##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/AggScalarSubQueryToWindowFunction.java:
##########
@@ -0,0 +1,371 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite.logical;
+
+import org.apache.doris.nereids.jobs.JobContext;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
+import org.apache.doris.nereids.trees.expressions.EqualTo;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.WindowExpression;
+import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Avg;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
+import 
org.apache.doris.nereids.trees.expressions.functions.agg.NullableAggregateFunction;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
+import org.apache.doris.nereids.trees.expressions.literal.Literal;
+import 
org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
+import 
org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor;
+import org.apache.doris.nereids.trees.plans.GroupPlan;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalApply;
+import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
+import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
+import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
+import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+import org.apache.doris.nereids.trees.plans.logical.LogicalRelation;
+import org.apache.doris.nereids.trees.plans.logical.LogicalWindow;
+import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
+import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
+import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanVisitor;
+import org.apache.doris.nereids.util.ExpressionUtils;
+import org.apache.doris.nereids.util.PlanUtils;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
+
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * change scalar sub query containing agg to window function. such as:
+ * SELECT SUM(l_extendedprice) / 7.0 AS avg_yearly
+ *  FROM lineitem, part
+ *  WHERE p_partkey = l_partkey AND
+ *  p_brand = 'Brand#23' AND
+ *  p_container = 'MED BOX' AND
+ *  l_quantity<(SELECT 0.2*avg(l_quantity)
+ *  FROM lineitem
+ *  WHERE l_partkey = p_partkey);
+ * to:
+ * SELECT SUM(l_extendedprice) / 7.0 as avg_yearly
+ *  FROM (SELECT l_extendedprice, l_quantity,
+ *    0.2 * avg(l_quantity)over(partition by p_partkey)
+ *    AS avg_l_quantity
+ *    FROM lineitem, part
+ *    WHERE p_partkey = l_partkey and
+ *    p_brand = 'Brand#23' and
+ *    p_container = 'MED BOX') t
+ * WHERE l_quantity < avg_l_quantity;
+ */
+
+public class AggScalarSubQueryToWindowFunction extends 
DefaultPlanRewriter<JobContext> implements CustomRewriter {
+    private static final ImmutableSet<Class<? extends AggregateFunction>> 
SUPPORTED_FUNCTION = ImmutableSet.of(
+            Min.class, Max.class, Count.class, Sum.class, Avg.class
+    );
+    private static final ImmutableSet<Class<? extends LogicalPlan>> 
LEFT_SUPPORTED_PLAN = ImmutableSet.of(
+            LogicalOlapScan.class, LogicalLimit.class, LogicalJoin.class, 
LogicalProject.class
+    );
+    private static final ImmutableSet<Class<? extends LogicalPlan>> 
RIGHT_SUPPORTED_PLAN = ImmutableSet.of(
+            LogicalOlapScan.class, LogicalJoin.class, LogicalProject.class, 
LogicalAggregate.class, LogicalFilter.class
+    );
+    private List<LogicalPlan> outerPlans = null;
+    private List<LogicalPlan> innerPlans = null;
+    private LogicalAggregate aggOp = null;
+    private List<AggregateFunction> functions = null;
+
+    /**
+     * Starts the rewrite by dispatching the root plan to this visitor via
+     * {@code plan.accept(this, context)} and returns whatever that visit produces.
+     *
+     * @param plan the root plan to rewrite
+     * @param context the job context passed through to every visit method
+     * @return the (possibly rewritten) plan
+     */
+    @Override
+    public Plan rewriteRoot(Plan plan, JobContext context) {
+        return plan.accept(this, context);
+    }
+
+    /**
+     * Tries to rewrite a filter(project(apply(outer, aggregate))) shape.
+     *
+     * <p>Returns {@code filter} unchanged when the plan shape does not match
+     * ({@code checkPattern}) or when the follow-up checks in {@code check} fail;
+     * otherwise returns the result of {@code trans(node)}.
+     *
+     * <p>NOTE(review): a non-matching filter is returned as-is without visiting
+     * its children, so a matching shape nested below it is not reached from
+     * this method - confirm whether delegating to the default child traversal
+     * is intended here.
+     *
+     * @param filter the filter node being visited
+     * @param context the job context
+     * @return the rewritten plan, or {@code filter} if the rule does not apply
+     */
+    @Override
+    public Plan visitLogicalFilter(LogicalFilter filter, JobContext context) {
+        if (!checkPattern(filter)) {
+            return filter;
+        }
+        // checkPattern has verified the node classes; the generic parameters
+        // of this cast are still unchecked at runtime.
+        LogicalFilter<LogicalProject<LogicalApply<Plan, 
LogicalAggregate<Plan>>>> node
+                = ((LogicalFilter<LogicalProject<LogicalApply<Plan, 
LogicalAggregate<Plan>>>>) filter);
+        if (!check(node)) {
+            return filter;
+        }
+        return trans(node);
+    }
+
+    /**
+     * Checks the plan shape this rule requires: filter -> project -> apply.
+     * The apply must be scalar and correlated, and must have
+     * {@code getSubCorrespondingConject()} present. It must have a non-null
+     * left child and a LogicalAggregate right child whose own child is non-null.
+     *
+     * @param filter the filter node being visited
+     * @return true if every structural condition holds, false otherwise
+     */
+    private boolean checkPattern(LogicalFilter filter) {
+        if (!(filter.child() instanceof LogicalProject)) {
+            return false;
+        }
+        LogicalProject project = (LogicalProject) filter.child();
+        // NOTE(review): the null check is redundant - instanceof already
+        // evaluates to false for a null operand.
+        if (project.child() == null || !(project.child() instanceof 
LogicalApply)) {
+            return false;
+        }
+        // filter(apply()) may be also ok.
+        // but nereids will match the pattern filter(project(apply()))
+        LogicalApply apply = ((LogicalApply<?, ?>) project.child());
+        if (!apply.isScalar() || !apply.isCorrelated() || 
!apply.getSubCorrespondingConject().isPresent()) {
+            return false;
+        }
+        return apply.left() != null && apply.right() instanceof 
LogicalAggregate
+                && ((LogicalAggregate<?>) apply.right()).child() != null;
+    }
+
+    private boolean check(LogicalFilter<LogicalProject<LogicalApply<Plan, 
LogicalAggregate<Plan>>>> node) {
+        LogicalApply<?, ?> apply = node.child().child();
+        LogicalPlan outer = ((LogicalPlan) apply.child(0));
+        LogicalPlan inner = ((LogicalPlan) apply.child(1));
+        outerPlans = new PlanCollector().collect(outer);
+        innerPlans = new PlanCollector().collect(inner);
+        LogicalFilter outerFilter = node;
+        Optional<LogicalFilter> innerFilter = innerPlans.stream()
+                .filter(LogicalFilter.class::isInstance)
+                .map(LogicalFilter.class::cast).findFirst();
+        return innerFilter.filter(
+                logicalFilter -> checkPlanType() && checkAggType()

Review Comment:
   IntelliJ IDEA recommends it.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to