This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch branch-2.1
in repository https://gitbox.apache.org/repos/asf/doris.git

commit 81bcb9d4909916dd4bdaee45228838e9e405c69d
Author: Xujian Duan <50550370+darvend...@users.noreply.github.com>
AuthorDate: Fri May 17 16:53:37 2024 +0800

    [opt](planner)(Nereids) support auto aggregation for random distributed 
table (#33630)
    
    support auto aggregation for querying detail data of random distributed 
table:
    rows with the same key columns are aggregated, so only one row is 
returned per key.
---
 .../org/apache/doris/analysis/StmtRewriter.java    | 236 +++++++++++++++++++
 .../doris/nereids/jobs/executor/Analyzer.java      |   3 +
 .../org/apache/doris/nereids/rules/RuleType.java   |   5 +-
 .../BuildAggForRandomDistributedTable.java         | 257 +++++++++++++++++++++
 .../java/org/apache/doris/qe/StmtExecutor.java     |   9 +-
 .../aggregate/select_random_distributed_tbl.out    | 217 +++++++++++++++++
 .../aggregate/select_random_distributed_tbl.groovy | 134 +++++++++++
 7 files changed, 857 insertions(+), 4 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/analysis/StmtRewriter.java 
b/fe/fe-core/src/main/java/org/apache/doris/analysis/StmtRewriter.java
index 93823cf398c..8fcd54b4a1d 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/StmtRewriter.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/StmtRewriter.java
@@ -20,9 +20,16 @@
 
 package org.apache.doris.analysis;
 
+import org.apache.doris.catalog.AggStateType;
+import org.apache.doris.catalog.AggregateType;
 import org.apache.doris.catalog.Column;
+import org.apache.doris.catalog.DistributionInfo;
 import org.apache.doris.catalog.Env;
+import org.apache.doris.catalog.FunctionSet;
+import org.apache.doris.catalog.KeysType;
+import org.apache.doris.catalog.OlapTable;
 import org.apache.doris.catalog.ScalarType;
+import org.apache.doris.catalog.TableIf;
 import org.apache.doris.catalog.Type;
 import org.apache.doris.common.AnalysisException;
 import org.apache.doris.common.TableAliasGenerator;
@@ -1366,4 +1373,233 @@ public class StmtRewriter {
         }
         return reAnalyze;
     }
+
+    /**
+     * Add the select item for one table column to the pre-agg select list.
+     * A key column is selected as-is and also added to the group by exprs; an aggregated
+     * column is wrapped by the function returned from generateAggFunction.
+     *
+     * @param column the column of SlotRef
+     * @param selectList new selectList for selectStmt
+     * @param groupByExprs group by Exprs for selectStmt
+     * @return true if an item was added for the column, false if the column can not be rewritten
+     */
+    private static boolean rewriteSelectList(Column column, SelectList selectList, ArrayList<Expr> groupByExprs) {
+        SlotRef columnRef = new SlotRef(null, column.getName());
+        if (column.isKey()) {
+            // key column: project it unchanged and group by it
+            selectList.addItem(new SelectListItem(columnRef, column.getName()));
+            groupByExprs.add(columnRef);
+            return true;
+        }
+        if (!column.isAggregated()) {
+            return false;
+        }
+        FunctionCallExpr aggCall = generateAggFunction(columnRef, column);
+        if (aggCall == null) {
+            // no aggregate function available for this column
+            return false;
+        }
+        selectList.addItem(new SelectListItem(aggCall, column.getName()));
+        return true;
+    }
+
+    /**
+     * Rewrite a stmt that queries a random distributed AGG_KEYS table: the table ref is replaced
+     * by an inline view that aggregates the table columns (pre-agg).
+     * e.g., for the table
+     * CREATE TABLE `tbl` (
+     *   `k1` BIGINT NULL DEFAULT "10",
+     *   `k2` SMALLINT NULL,
+     *   `a` BIGINT SUM NULL DEFAULT "0"
+     * ) ENGINE=OLAP
+     * AGGREGATE KEY(`k1`, `k2`)
+     * DISTRIBUTED BY RANDOM BUCKETS 1
+     * PROPERTIES (
+     * "replication_allocation" = "tag.location.default: 1"
+     * )
+     * original: select * from tbl
+     * rewrite: select * from (select k1, k2, sum(a) from tbl group by k1, k2) t
+     * do not rewrite if two phase agg is not needed,
+     * e.g.,
+     *     1. select max(k1) from tbl
+     *     2. select sum(a) from tbl
+     *
+     * @param statementBase stmt to rewrite; anything that is not a SelectStmt is left untouched
+     * @param analyzer the analyzer, used to analyze the stmt again after a table ref is replaced
+     * @return true if rewritten
+     * @throws UserException NOTE(review): presumably raised by selectStmt.analyze - confirm
+     */
+    public static boolean rewriteForRandomDistribution(StatementBase 
statementBase, Analyzer analyzer)
+            throws UserException {
+        boolean reAnalyze = false;
+        if (!(statementBase instanceof SelectStmt)) {
+            return false;
+        }
+        SelectStmt selectStmt = (SelectStmt) statementBase;
+        for (int i = 0; i < selectStmt.fromClause.size(); i++) {
+            TableRef tableRef = selectStmt.fromClause.get(i);
+            // Recursively rewrite subquery
+            if (tableRef instanceof InlineViewRef) {
+                InlineViewRef viewRef = (InlineViewRef) tableRef;
+                if (rewriteForRandomDistribution(viewRef.getQueryStmt(), 
viewRef.getAnalyzer())) {
+                    reAnalyze = true;
+                }
+                continue;
+            }
+            TableIf table = tableRef.getTable();
+            if (!(table instanceof OlapTable)) {
+                continue;
+            }
+            // only rewrite random distributed AGG_KEY table
+            OlapTable olapTable = (OlapTable) table;
+            if (olapTable.getKeysType() != KeysType.AGG_KEYS) {
+                continue;
+            }
+            DistributionInfo distributionInfo = 
olapTable.getDefaultDistributionInfo();
+            if (distributionInfo.getType() != 
DistributionInfo.DistributionInfoType.RANDOM) {
+                continue;
+            }
+
+            // check agg function and column agg type
+            AggregateInfo aggInfo = selectStmt.getAggInfo();
+            GroupByClause groupByClause = selectStmt.getGroupByClause();
+            boolean aggTypeMatch = true;
+            if (aggInfo != null || groupByClause != null) {
+                if (aggInfo != null) {
+                    ArrayList<FunctionCallExpr> aggExprs = 
aggInfo.getAggregateExprs();
+                    if (aggExprs.stream().anyMatch(expr -> 
!aggTypeMatch(expr.getFnName().getFunction(), expr))) {
+                        aggTypeMatch = false;
+                    }
+                    List<Expr> groupExprs = aggInfo.getGroupingExprs();
+                    if (groupExprs.stream().anyMatch(expr -> 
!isKeyOrConstantExpr(expr))) {
+                        aggTypeMatch = false;
+                    }
+                }
+                if (groupByClause != null) {
+                    List<Expr> groupByExprs = groupByClause.getGroupingExprs();
+                    if (groupByExprs.stream().anyMatch(expr -> 
!isKeyOrConstantExpr(expr))) {
+                        aggTypeMatch = false;
+                    }
+                }
+                // the stmt already aggregates in a way that matches the column agg types: keep the table ref
+                if (aggTypeMatch) {
+                    continue;
+                }
+            }
+            // construct a new InlineViewRef for pre-agg
+            boolean canRewrite = true;
+            SelectList selectList = new SelectList();
+            ArrayList<Expr> groupingExprs = new ArrayList<>();
+            List<Column> columns = olapTable.getBaseSchema();
+            for (Column col : columns) {
+                if (!rewriteSelectList(col, selectList, groupingExprs)) {
+                    canRewrite = false;
+                    break;
+                }
+            }
+            if (!canRewrite) {
+                continue;
+            }
+            Expr whereClause = selectStmt.getWhereClause() == null ? null : 
selectStmt.getWhereClause().clone();
+            SelectStmt newSelectSmt = new SelectStmt(selectList,
+                    new FromClause(Lists.newArrayList(tableRef)),
+                    whereClause,
+                    new GroupByClause(groupingExprs, 
GroupByClause.GroupingType.GROUP_BY),
+                    null,
+                    null,
+                    LimitElement.NO_LIMIT);
+            InlineViewRef inlineViewRef = new 
InlineViewRef(tableRef.getAliasAsName().getTbl(), newSelectSmt);
+            // the inline view takes over the join settings of the table ref it replaces
+            inlineViewRef.setJoinOp(tableRef.getJoinOp());
+            inlineViewRef.setLeftTblRef(tableRef.getLeftTblRef());
+            inlineViewRef.setOnClause(tableRef.getOnClause());
+            tableRef.setOnClause(null);
+            tableRef.setLeftTblRef(null);
+            tableRef.setOnClause(null); // NOTE(review): repeats the setOnClause(null) call two lines above
+            if (selectStmt.fromClause.size() > i + 1) {
+                selectStmt.fromClause.get(i + 1).setLeftTblRef(inlineViewRef);
+            }
+            selectStmt.fromClause.set(i, inlineViewRef);
+            selectStmt.analyze(analyzer);
+            reAnalyze = true;
+        }
+        return reAnalyze;
+    }
+
+    /**
+     * Check whether an aggregate function is compatible with the aggregation type of every
+     * column referenced below the given expr.
+     *
+     * @param functionName the functionName of functionCall
+     * @param expr expr to inspect; a leaf is checked directly, any other node through its children
+     * @return true if agg type match
+     */
+    private static boolean aggTypeMatch(String functionName, Expr expr) {
+        List<Expr> children = expr.getChildren();
+        if (!children.isEmpty()) {
+            // inner node: all children have to match
+            return children.stream().allMatch(child -> aggTypeMatch(functionName, child));
+        }
+        // leaf node: only a column reference can match
+        if (!(expr instanceof SlotRef)) {
+            return false;
+        }
+        Column column = ((SlotRef) expr).getDesc().getColumn();
+        if (column.isKey()) {
+            return functionName.equalsIgnoreCase("MAX") || functionName.equalsIgnoreCase("MIN");
+        }
+        if (!column.isAggregated()) {
+            return false;
+        }
+        AggregateType columnAggType = column.getAggregationType();
+        switch (columnAggType) {
+            case GENERIC:
+                return column.getType().isAggStateType();
+            case HLL_UNION:
+                return functionName.equalsIgnoreCase(FunctionSet.HLL_UNION)
+                        || functionName.equalsIgnoreCase(FunctionSet.HLL_UNION_AGG);
+            case BITMAP_UNION:
+                return functionName.equalsIgnoreCase(FunctionSet.BITMAP_UNION)
+                        || functionName.equalsIgnoreCase(FunctionSet.BITMAP_UNION_COUNT)
+                        || functionName.equalsIgnoreCase(FunctionSet.BITMAP_INTERSECT);
+            default:
+                // every other agg type matches only the function with the same name
+                return functionName.equalsIgnoreCase(columnAggType.name());
+        }
+    }
+
+    /**
+     * Check that an expr is built only from key columns and constants. A group by clause that
+     * contains a value column needs the rewrite.
+     *
+     * @param expr expr to check
+     * @return true if all columns is key column or constant
+     */
+    private static boolean isKeyOrConstantExpr(Expr expr) {
+        if (expr instanceof SlotRef) {
+            return ((SlotRef) expr).getDesc().getColumn().isKey();
+        }
+        if (expr.isConstant()) {
+            return true;
+        }
+        for (Expr child : expr.getChildren()) {
+            if (!isKeyOrConstantExpr(child)) {
+                return false;
+            }
+        }
+        return true;
+    }
+
+    /**
+     * generate aggregation function according to the aggType of column
+     *
+     * @param slot slot of column
+     * @param column the column whose aggregation type selects the function
+     * @return aggFunction generated, or null if no function is generated for the agg type
+     *         (including a GENERIC column whose type is not an agg_state type)
+     */
+    private static FunctionCallExpr generateAggFunction(SlotRef slot, Column column) {
+        AggregateType aggregateType = column.getAggregationType();
+        switch (aggregateType) {
+            case SUM:
+            case MAX:
+            case MIN:
+            case HLL_UNION:
+            case BITMAP_UNION:
+            case QUANTILE_UNION:
+                // Lower-case with Locale.ROOT: the no-arg toLowerCase() uses the JVM default locale,
+                // which can change the function name (e.g. "MIN" under a Turkish locale).
+                String lowerCaseName = aggregateType.toString().toLowerCase(java.util.Locale.ROOT);
+                FunctionName funcName = new FunctionName(lowerCaseName);
+                return new FunctionCallExpr(funcName, new FunctionParams(false, Lists.newArrayList(slot)));
+            case GENERIC:
+                Type type = column.getType();
+                if (!type.isAggStateType()) {
+                    return null;
+                }
+                AggStateType aggState = (AggStateType) type;
+                // use AGGREGATE_FUNCTION_UNION to aggregate multiple agg_state into one
+                FunctionName functionName = new FunctionName(aggState.getFunctionName() + "_union");
+                return new FunctionCallExpr(functionName, new FunctionParams(false, Lists.newArrayList(slot)));
+            default:
+                return null;
+        }
+    }
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java
index ac0a4421071..77d23464e65 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java
@@ -26,6 +26,7 @@ import org.apache.doris.nereids.rules.analysis.BindRelation;
 import 
org.apache.doris.nereids.rules.analysis.BindRelation.CustomTableResolver;
 import org.apache.doris.nereids.rules.analysis.BindSink;
 import org.apache.doris.nereids.rules.analysis.BindSlotWithPaths;
+import 
org.apache.doris.nereids.rules.analysis.BuildAggForRandomDistributedTable;
 import org.apache.doris.nereids.rules.analysis.CheckAfterBind;
 import org.apache.doris.nereids.rules.analysis.CheckAnalysis;
 import org.apache.doris.nereids.rules.analysis.CheckPolicy;
@@ -176,6 +177,8 @@ public class Analyzer extends AbstractBatchJobExecutor {
             topDown(new EliminateGroupByConstant()),
 
             topDown(new SimplifyAggGroupBy()),
+            // run BuildAggForRandomDistributedTable before NormalizeAggregate 
in order to optimize the agg plan
+            topDown(new BuildAggForRandomDistributedTable()),
             topDown(new NormalizeAggregate()),
             topDown(new HavingToFilter()),
             bottomUp(new SemiJoinCommute()),
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
index c688d7d5b3f..c4eb7fe9b06 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
@@ -313,7 +313,10 @@ public enum RuleType {
 
     // topn opts
     DEFER_MATERIALIZE_TOP_N_RESULT(RuleTypeClass.REWRITE),
-
+    // pre agg for random distributed table
+    BUILD_AGG_FOR_RANDOM_DISTRIBUTED_TABLE_PROJECT_SCAN(RuleTypeClass.REWRITE),
+    BUILD_AGG_FOR_RANDOM_DISTRIBUTED_TABLE_FILTER_SCAN(RuleTypeClass.REWRITE),
+    BUILD_AGG_FOR_RANDOM_DISTRIBUTED_TABLE_AGG_SCAN(RuleTypeClass.REWRITE),
     // exploration rules
     TEST_EXPLORATION(RuleTypeClass.EXPLORATION),
     OR_EXPANSION(RuleTypeClass.EXPLORATION),
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BuildAggForRandomDistributedTable.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BuildAggForRandomDistributedTable.java
new file mode 100644
index 00000000000..86c89e49d3d
--- /dev/null
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BuildAggForRandomDistributedTable.java
@@ -0,0 +1,257 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.analysis;
+
+import org.apache.doris.catalog.AggStateType;
+import org.apache.doris.catalog.AggregateType;
+import org.apache.doris.catalog.Column;
+import org.apache.doris.catalog.DistributionInfo;
+import org.apache.doris.catalog.DistributionInfo.DistributionInfoType;
+import org.apache.doris.catalog.Env;
+import org.apache.doris.catalog.FunctionRegistry;
+import org.apache.doris.catalog.KeysType;
+import org.apache.doris.catalog.OlapTable;
+import org.apache.doris.catalog.Type;
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.ExprId;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import 
org.apache.doris.nereids.trees.expressions.functions.AggCombinerFunctionBuilder;
+import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
+import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapFunction;
+import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnion;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
+import org.apache.doris.nereids.trees.expressions.functions.agg.HllFunction;
+import org.apache.doris.nereids.trees.expressions.functions.agg.HllUnion;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
+import org.apache.doris.nereids.trees.expressions.functions.agg.QuantileUnion;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
+
+import com.google.common.collect.ImmutableList;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Set;
+
+/**
+ * build agg plan for querying random distributed table
+ */
+public class BuildAggForRandomDistributedTable implements AnalysisRuleFactory {
+
+    @Override
+    public List<Rule> buildRules() {
+        return ImmutableList.of(
+                // Project(Scan) -> project(agg(scan))
+                logicalProject(logicalOlapScan()).when(project -> 
isRandomDistributedTbl(project.child()))
+                        .then(project -> preAggForRandomDistribution(project, 
project.child()))
+                        
.toRule(RuleType.BUILD_AGG_FOR_RANDOM_DISTRIBUTED_TABLE_PROJECT_SCAN),
+                // agg(scan) -> agg(agg(scan)), agg(agg) may optimized by 
MergeAggregate
+                logicalAggregate(logicalOlapScan()).when(agg -> 
isRandomDistributedTbl(agg.child())).whenNot(agg -> {
+                    Set<AggregateFunction> functions = 
agg.getAggregateFunctions();
+                    List<Expression> groupByExprs = 
agg.getGroupByExpressions();
+                    // check if need generate an inner agg plan or not
+                    // should not rewrite twice if we had rewritten olapScan 
to aggregate(olapScan)
+                    return functions.stream().allMatch(this::aggTypeMatch) && 
groupByExprs.stream()
+                                    .allMatch(this::isKeyOrConstantExpr);
+                })
+                        .then(agg -> preAggForRandomDistribution(agg, 
agg.child()))
+                        
.toRule(RuleType.BUILD_AGG_FOR_RANDOM_DISTRIBUTED_TABLE_AGG_SCAN),
+                // filter(scan) -> filter(agg(scan))
+                logicalFilter(logicalOlapScan()).when(filter -> 
isRandomDistributedTbl(filter.child()))
+                        .then(filter -> preAggForRandomDistribution(filter, 
filter.child()))
+                        
.toRule(RuleType.BUILD_AGG_FOR_RANDOM_DISTRIBUTED_TABLE_FILTER_SCAN));
+
+    }
+
+    /**
+     * Tell whether the table read by olapScan is an AGG_KEYS table whose default
+     * distribution is RANDOM.
+     *
+     * @param olapScan olap scan plan
+     * @return true if olapTable is randomDistributed table
+     */
+    private boolean isRandomDistributedTbl(LogicalOlapScan olapScan) {
+        OlapTable table = olapScan.getTable();
+        KeysType tableKeysType = table.getKeysType();
+        DistributionInfo defaultDistribution = table.getDefaultDistributionInfo();
+        if (tableKeysType != KeysType.AGG_KEYS) {
+            return false;
+        }
+        return defaultDistribution.getType() == DistributionInfoType.RANDOM;
+    }
+
+    /**
+     * add LogicalAggregate above olapScan for preAgg
+     *
+     * @param logicalPlan parent plan of olapScan, it may be LogicalProject, LogicalFilter, LogicalAggregate
+     * @param olapScan olap scan plan
+     * @return rewritten plan, or logicalPlan itself when no agg function can be generated
+     *         for one of the non-key columns
+     */
+    private Plan preAggForRandomDistribution(LogicalPlan logicalPlan, LogicalOlapScan olapScan) {
+        OlapTable olapTable = olapScan.getTable();
+        List<Slot> childOutputSlots = olapScan.computeOutput();
+        List<Expression> groupByExpressions = new ArrayList<>();
+        List<NamedExpression> outputExpressions = new ArrayList<>();
+        List<Column> columns = olapTable.getBaseSchema();
+
+        for (Column col : columns) {
+            // use exist slot in the plan
+            SlotReference slot = SlotReference.fromColumn(olapTable, col, col.getName(), olapScan.getQualifier());
+            ExprId exprId = slot.getExprId();
+            for (Slot childSlot : childOutputSlots) {
+                // Compare the names by value. Reference comparison (==) only matches when both
+                // names are the very same String object, and then the scan's ExprId is not reused.
+                if (childSlot instanceof SlotReference
+                        && ((SlotReference) childSlot).getName().equals(col.getName())) {
+                    exprId = childSlot.getExprId();
+                    slot = slot.withExprId(exprId);
+                    break;
+                }
+            }
+            if (col.isKey()) {
+                groupByExpressions.add(slot);
+                outputExpressions.add(slot);
+            } else {
+                Expression function = generateAggFunction(slot, col);
+                // DO NOT rewrite
+                if (function == null) {
+                    return logicalPlan;
+                }
+                // the alias reuses the ExprId and the name of the column it aggregates
+                Alias alias = new Alias(exprId, function, col.getName());
+                outputExpressions.add(alias);
+            }
+        }
+        LogicalAggregate<LogicalOlapScan> aggregate = new LogicalAggregate<>(groupByExpressions, outputExpressions,
+                olapScan);
+        return logicalPlan.withChildren(aggregate);
+    }
+
+    /**
+     * Build the aggregate function that corresponds to the aggregation type of a column.
+     *
+     * @param slot slot of column
+     * @param column the column whose aggregation type selects the function
+     * @return aggFunction generated, or null if no function is generated for the agg type
+     */
+    private Expression generateAggFunction(SlotReference slot, Column column) {
+        switch (column.getAggregationType()) {
+            case SUM:
+                return new Sum(slot);
+            case MAX:
+                return new Max(slot);
+            case MIN:
+                return new Min(slot);
+            case HLL_UNION:
+                return new HllUnion(slot);
+            case BITMAP_UNION:
+                return new BitmapUnion(slot);
+            case QUANTILE_UNION:
+                return new QuantileUnion(slot);
+            case GENERIC: {
+                Type columnType = column.getType();
+                if (!columnType.isAggStateType()) {
+                    return null;
+                }
+                // use AGGREGATE_FUNCTION_UNION to aggregate multiple agg_state into one
+                String unionName = ((AggStateType) columnType).getFunctionName()
+                        + AggCombinerFunctionBuilder.UNION_SUFFIX;
+                FunctionBuilder unionBuilder = Env.getCurrentEnv().getFunctionRegistry()
+                        .findFunctionBuilder(unionName, slot);
+                return unionBuilder.build(unionName, ImmutableList.of(slot)).first;
+            }
+            default:
+                return null;
+        }
+    }
+
+    /**
+     * If the agg type of the AggregateFunction is the same as the agg type of the columns it
+     * reads, there is no need to rewrite.
+     *
+     * @param function agg function to check
+     * @return true if agg type match
+     */
+    private boolean aggTypeMatch(AggregateFunction function) {
+        List<Expression> arguments = function.children();
+        if (function.getName().equalsIgnoreCase("count")) {
+            Count countFunction = (Count) function;
+            if (countFunction.isDistinct()) {
+                // do not rewrite for count distinct for key column
+                for (Expression argument : arguments) {
+                    if (!isKeyOrConstantExpr(argument)) {
+                        return false;
+                    }
+                }
+                return true;
+            }
+            if (countFunction.isStar()) {
+                return false;
+            }
+        }
+        for (Expression argument : arguments) {
+            if (!aggTypeMatch(function, argument)) {
+                return false;
+            }
+        }
+        return true;
+    }
+
+    /**
+     * Check whether the agg type of the function matches the agg type of every column
+     * referenced below the given expression.
+     *
+     * @param function the functionCall
+     * @param expression expr to check; a leaf is checked directly, any other node through its children
+     * @return true if agg type match
+     */
+    private boolean aggTypeMatch(AggregateFunction function, Expression expression) {
+        List<Expression> children = expression.children();
+        if (!children.isEmpty()) {
+            // inner node: all children have to match
+            return children.stream().allMatch(child -> aggTypeMatch(function, child));
+        }
+        // leaf node: only a slot that is backed by a column can match
+        if (!(expression instanceof SlotReference)) {
+            return false;
+        }
+        SlotReference slotReference = (SlotReference) expression;
+        if (!slotReference.getColumn().isPresent()) {
+            return false;
+        }
+        Column column = slotReference.getColumn().get();
+        String functionName = function.getName();
+        if (column.isKey()) {
+            return functionName.equalsIgnoreCase("max") || functionName.equalsIgnoreCase("min");
+        }
+        if (!column.isAggregated()) {
+            return false;
+        }
+        AggregateType columnAggType = column.getAggregationType();
+        switch (columnAggType) {
+            case GENERIC:
+                return column.getType().isAggStateType();
+            case HLL_UNION:
+                return function instanceof HllFunction;
+            case BITMAP_UNION:
+                return function instanceof BitmapFunction;
+            default:
+                // every other agg type matches only the function with the same name
+                return functionName.equalsIgnoreCase(columnAggType.name());
+        }
+    }
+
+    /**
+     * Check that an expr is built only from key columns and constants. A group by clause that
+     * contains a value column needs the rewrite.
+     *
+     * @param expr expr to check
+     * @return true if all columns is key column or constant
+     */
+    private boolean isKeyOrConstantExpr(Expression expr) {
+        if (expr instanceof SlotReference) {
+            SlotReference slotReference = (SlotReference) expr;
+            if (slotReference.getColumn().isPresent()) {
+                return slotReference.getColumn().get().isKey();
+            }
+        }
+        if (expr.isConstant()) {
+            return true;
+        }
+        for (Expression child : expr.children()) {
+            if (!isKeyOrConstantExpr(child)) {
+                return false;
+            }
+        }
+        return true;
+    }
+}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/StmtExecutor.java 
b/fe/fe-core/src/main/java/org/apache/doris/qe/StmtExecutor.java
index 1cf499f0c6c..c2b0e9dd444 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/qe/StmtExecutor.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/qe/StmtExecutor.java
@@ -1305,21 +1305,24 @@ public class StmtExecutor {
                 reAnalyze = true;
             }
             if (parsedStmt instanceof SelectStmt) {
-                if (StmtRewriter.rewriteByPolicy(parsedStmt, analyzer)) {
+                if (StmtRewriter.rewriteByPolicy(parsedStmt, analyzer)
+                        || 
StmtRewriter.rewriteForRandomDistribution(parsedStmt, analyzer)) {
                     reAnalyze = true;
                 }
             }
             if (parsedStmt instanceof SetOperationStmt) {
                 List<SetOperationStmt.SetOperand> operands = 
((SetOperationStmt) parsedStmt).getOperands();
                 for (SetOperationStmt.SetOperand operand : operands) {
-                    if (StmtRewriter.rewriteByPolicy(operand.getQueryStmt(), 
analyzer)) {
+                    if (StmtRewriter.rewriteByPolicy(operand.getQueryStmt(), 
analyzer)
+                            || 
StmtRewriter.rewriteForRandomDistribution(operand.getQueryStmt(), analyzer)) {
                         reAnalyze = true;
                     }
                 }
             }
             if (parsedStmt instanceof InsertStmt) {
                 QueryStmt queryStmt = ((InsertStmt) parsedStmt).getQueryStmt();
-                if (queryStmt != null && 
StmtRewriter.rewriteByPolicy(queryStmt, analyzer)) {
+                if (queryStmt != null && 
(StmtRewriter.rewriteByPolicy(queryStmt, analyzer)
+                        || 
StmtRewriter.rewriteForRandomDistribution(queryStmt, analyzer))) {
                     reAnalyze = true;
                 }
             }
diff --git 
a/regression-test/data/query_p0/aggregate/select_random_distributed_tbl.out 
b/regression-test/data/query_p0/aggregate/select_random_distributed_tbl.out
new file mode 100644
index 00000000000..1afb2a06762
--- /dev/null
+++ b/regression-test/data/query_p0/aggregate/select_random_distributed_tbl.out
@@ -0,0 +1,217 @@
+-- This file is automatically generated. You should know what you did if you 
want to edit this
+-- !sql_1 --
+1      a       6       3       1       \N      \N      \N
+2      b       15      6       4       \N      \N      \N
+
+-- !sql_2 --
+1      a       6       3       1       \N      \N      \N
+2      b       15      6       4       \N      \N      \N
+
+-- !sql_3 --
+2      a       6
+3      b       15
+
+-- !sql_4 --
+1      a       7
+2      b       16
+
+-- !sql_5 --
+1      6       3       1       2.0     3       3       2.0
+2      15      6       4       5.0     3       3       5.0
+
+-- !sql_6 --
+2
+
+-- !sql_7 --
+2
+
+-- !sql_8 --
+2
+
+-- !sql_9 --
+15
+
+-- !sql_10 --
+9
+
+-- !sql_11 --
+5
+
+-- !sql_12 --
+1      6
+4      15
+
+-- !sql_13 --
+2
+
+-- !sql_14 --
+2
+
+-- !sql_15 --
+2
+
+-- !sql_16 --
+2
+
+-- !sql_1 --
+1      a       6       3       1       \N      \N      \N
+2      b       15      6       4       \N      \N      \N
+
+-- !sql_2 --
+1      a       6       3       1       \N      \N      \N
+2      b       15      6       4       \N      \N      \N
+
+-- !sql_3 --
+2      a       6
+3      b       15
+
+-- !sql_4 --
+1      a       7
+2      b       16
+
+-- !sql_5 --
+1      6       3       1       2.0     3       3       2.0
+2      15      6       4       5.0     3       3       5.0
+
+-- !sql_6 --
+2
+
+-- !sql_7 --
+2
+
+-- !sql_8 --
+2
+
+-- !sql_9 --
+15
+
+-- !sql_10 --
+9
+
+-- !sql_11 --
+5
+
+-- !sql_12 --
+1      6
+4      15
+
+-- !sql_13 --
+2
+
+-- !sql_14 --
+2
+
+-- !sql_15 --
+2
+
+-- !sql_16 --
+2
+
+-- !sql_1 --
+1      a       6       3       1       \N      \N      \N
+2      b       15      6       4       \N      \N      \N
+
+-- !sql_2 --
+1      a       6       3       1       \N      \N      \N
+2      b       15      6       4       \N      \N      \N
+
+-- !sql_3 --
+2      a       6
+3      b       15
+
+-- !sql_4 --
+1      a       7
+2      b       16
+
+-- !sql_5 --
+1      6       3       1       2.0     3       3       2.0
+2      15      6       4       5.0     3       3       5.0
+
+-- !sql_6 --
+2
+
+-- !sql_7 --
+2
+
+-- !sql_8 --
+2
+
+-- !sql_9 --
+15
+
+-- !sql_10 --
+9
+
+-- !sql_11 --
+5
+
+-- !sql_12 --
+1      6
+4      15
+
+-- !sql_13 --
+2
+
+-- !sql_14 --
+2
+
+-- !sql_15 --
+2
+
+-- !sql_16 --
+2
+
+-- !sql_1 --
+1      a       6       3       1       \N      \N      \N
+2      b       15      6       4       \N      \N      \N
+
+-- !sql_2 --
+1      a       6       3       1       \N      \N      \N
+2      b       15      6       4       \N      \N      \N
+
+-- !sql_3 --
+2      a       6
+3      b       15
+
+-- !sql_4 --
+1      a       7
+2      b       16
+
+-- !sql_5 --
+1      6       3       1       2.0     3       3       2.0
+2      15      6       4       5.0     3       3       5.0
+
+-- !sql_6 --
+2
+
+-- !sql_7 --
+2
+
+-- !sql_8 --
+2
+
+-- !sql_9 --
+15
+
+-- !sql_10 --
+9
+
+-- !sql_11 --
+5
+
+-- !sql_12 --
+1      6
+4      15
+
+-- !sql_13 --
+2
+
+-- !sql_14 --
+2
+
+-- !sql_15 --
+2
+
+-- !sql_16 --
+2
+
diff --git 
a/regression-test/suites/query_p0/aggregate/select_random_distributed_tbl.groovy
 
b/regression-test/suites/query_p0/aggregate/select_random_distributed_tbl.groovy
new file mode 100644
index 00000000000..ff0df74589a
--- /dev/null
+++ 
b/regression-test/suites/query_p0/aggregate/select_random_distributed_tbl.groovy
@@ -0,0 +1,134 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+suite("select_random_distributed_tbl") { // regression test: plain selects on a RANDOM-distributed AGGREGATE KEY table must be aggregated per key, under both planners
+    def tableName = "random_distributed_tbl_test"
+
+    sql "drop table if exists ${tableName};"
+    sql "set enable_agg_state=true;" // presumably needed for the AGG_STATE column (v_generic) created below - confirm
+    sql """ admin set frontend config("enable_quantile_state_type"="true"); """
+    sql """
+    CREATE TABLE ${tableName}
+    (
+        `k1` LARGEINT NOT NULL,
+        `k2` VARCHAR(20) NULL,
+        `v_sum` BIGINT SUM NULL DEFAULT "0",
+        `v_max` INT MAX NULL DEFAULT "0",
+        `v_min` INT MIN NULL DEFAULT "99999",
+        `v_generic` AGG_STATE<avg(int NULL)> GENERIC,
+        `v_hll` HLL HLL_UNION NOT NULL,
+        `v_bitmap` BITMAP BITMAP_UNION NOT NULL,
+        `v_quantile_union` QUANTILE_STATE QUANTILE_UNION NOT NULL
+    ) ENGINE=OLAP
+    AGGREGATE KEY(`k1`, `k2`)
+    COMMENT 'OLAP'
+    DISTRIBUTED BY RANDOM BUCKETS 10
+    PROPERTIES (
+    "replication_allocation" = "tag.location.default: 1"
+    );
+    """
+    
+    sql """ insert into ${tableName} 
values(1,"a",1,1,1,avg_state(1),hll_hash(1),bitmap_hash(1),to_quantile_state(1, 
2048)) """
+    sql """ insert into ${tableName} 
values(1,"a",2,2,2,avg_state(2),hll_hash(2),bitmap_hash(2),to_quantile_state(2, 
2048)) """
+    sql """ insert into ${tableName} 
values(1,"a",3,3,3,avg_state(3),hll_hash(3),bitmap_hash(3),to_quantile_state(3, 
2048)) """
+    sql """ insert into ${tableName} 
values(2,"b",4,4,4,avg_state(4),hll_hash(4),bitmap_hash(4),to_quantile_state(4, 
2048)) """
+    sql """ insert into ${tableName} 
values(2,"b",5,5,5,avg_state(5),hll_hash(5),bitmap_hash(5),to_quantile_state(5, 
2048)) """
+    sql """ insert into ${tableName} 
values(2,"b",6,6,6,avg_state(6),hll_hash(6),bitmap_hash(6),to_quantile_state(6, 
2048)) """
+
+    for (int i = 0; i < 2; ++i) { // one pass per planner; the six rows above cover two keys, (1,"a") and (2,"b"), three rows each
+        if (i == 0) {
+            // first pass: run every query below with the legacy planner (Nereids disabled)
+            sql "set enable_nereids_planner = false;"
+        } else if (i == 1) {
+            // second pass: run the same queries with the Nereids planner
+            sql "set enable_nereids_planner = true;"
+        }
+
+        def whereStr = "" // no filter on the first inner pass
+        for (int j = 0; j < 2; ++j) {
+            if (j == 1) {
+                // second inner pass: repeat the queries with a filter; k1 > 0 matches every inserted row (k1 is 1 or 2)
+                whereStr = "where k1 > 0"
+            }
+            def sql1 = "select * except (v_generic) from ${tableName} 
${whereStr} order by k1, k2"
+            qt_sql_1 "${sql1}"
+            def res1 = sql """ explain ${sql1} """
+            assertTrue(res1.toString().contains("VAGGREGATE")) // the explain text must mention VAGGREGATE although the query has no GROUP BY
+
+            def sql2 = "select k1 ,k2 ,v_sum ,v_max ,v_min ,v_hll ,v_bitmap 
,v_quantile_union from ${tableName} ${whereStr} order by k1, k2"
+            qt_sql_2 "${sql2}"
+            def res2 = sql """ explain ${sql2} """
+            assertTrue(res2.toString().contains("VAGGREGATE"))
+
+            def sql3 = "select k1+1, k2, v_sum from ${tableName} ${whereStr} 
order by k1, k2"
+            qt_sql_3 "${sql3}"
+            def res3 = sql """ explain ${sql3} """
+            assertTrue(res3.toString().contains("VAGGREGATE"))
+
+            def sql4 = "select k1, k2, v_sum+1 from ${tableName} ${whereStr} 
order by k1, k2"
+            qt_sql_4 "${sql4}"
+            def res4 = sql """ explain ${sql4} """
+            assertTrue(res4.toString().contains("VAGGREGATE"))
+
+            def sql5 =  """ select k1, sum(v_sum), max(v_max), min(v_min), 
avg_merge(v_generic), 
+                hll_union_agg(v_hll), bitmap_union_count(v_bitmap), 
quantile_percent(quantile_union(v_quantile_union),0.5) 
+                from ${tableName} ${whereStr} group by k1 order by k1 """
+            qt_sql_5 "${sql5}"
+
+            def sql6 = "select count(1) from ${tableName} ${whereStr}"
+            qt_sql_6 "${sql6}"
+
+            def sql7 = "select count(*) from ${tableName} ${whereStr}"
+            qt_sql_7 "${sql7}"
+
+            def sql8 = "select max(k1) from ${tableName} ${whereStr}"
+            qt_sql_8 "${sql8}"
+            def res8 = sql """ explain ${sql8} """
+            // max(k1) touches only a key column, so the plan is expected to contain no "sum" (no pre-aggregation of value columns)
+            assertFalse(res8.toString().contains("sum"))
+
+            def sql9 = "select max(v_sum) from ${tableName} ${whereStr}"
+            qt_sql_9 "${sql9}"
+            def res9 = sql """ explain ${sql9} """
+            assertTrue(res9.toString().contains("sum")) // v_sum is a SUM value column: the plan is expected to sum it before max() is applied
+
+            def sql10 = "select sum(v_max) from ${tableName} ${whereStr}"
+            qt_sql_10 "${sql10}"
+
+            def sql11 = "select sum(v_min) from ${tableName} ${whereStr}"
+            qt_sql_11 "${sql11}"
+
+            // group by a value column (v_min) rather than a key column
+            def sql12 = "select v_min, sum(v_sum) from ${tableName} 
${whereStr} group by v_min order by v_min"
+            qt_sql_12 "${sql12}"
+
+            def sql13 = "select count(k1) from ${tableName} ${whereStr}"
+            qt_sql_13 "${sql13}"
+
+            def sql14 = "select count(distinct k1) from ${tableName} 
${whereStr}"
+            qt_sql_14 "${sql14}"
+
+            def sql15 = "select count(v_sum) from ${tableName} ${whereStr}"
+            qt_sql_15 "${sql15}"
+
+            def sql16 = "select count(distinct v_sum) from ${tableName} 
${whereStr}"
+            qt_sql_16 "${sql16}"
+        }
+    }
+
+    sql "drop table ${tableName};" // remove the test table once all passes are done
+}
\ No newline at end of file


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org


Reply via email to