This is an automated email from the ASF dual-hosted git repository.

morningman pushed a commit to branch dev-1.1.2
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/dev-1.1.2 by this push:
     new ccf2f82be4 [fix](function)fix max_by function bug  (#11745)
ccf2f82be4 is described below

commit ccf2f82be43d9170cb0c3373bd27338573e6d7c9
Author: starocean999 <40539150+starocean...@users.noreply.github.com>
AuthorDate: Mon Aug 15 09:06:54 2022 +0800

    [fix](function)fix max_by function bug  (#11745)
    
    This PR does the same thing as #10650. Because the code base has 
diverged so much, it is easier to make the changes directly on dev-1.1.2 
than to cherry-pick.
---
 be/src/vec/CMakeLists.txt                          |   1 +
 be/src/vec/data_types/data_type_factory.cpp        | 100 +++++++++++++++++++++
 be/src/vec/data_types/data_type_factory.hpp        |   8 +-
 be/src/vec/exprs/vectorized_agg_fn.cpp             |  28 +++---
 be/src/vec/exprs/vectorized_agg_fn.h               |   2 +-
 .../org/apache/doris/analysis/AggregateInfo.java   |  18 ++--
 .../apache/doris/analysis/FunctionCallExpr.java    |  20 ++++-
 .../org/apache/doris/analysis/FunctionParams.java  |  17 ++++
 gensrc/thrift/Exprs.thrift                         |   1 +
 gensrc/thrift/Types.thrift                         |   1 +
 10 files changed, 166 insertions(+), 30 deletions(-)

diff --git a/be/src/vec/CMakeLists.txt b/be/src/vec/CMakeLists.txt
index d8adfe6d40..56cfdfcc14 100644
--- a/be/src/vec/CMakeLists.txt
+++ b/be/src/vec/CMakeLists.txt
@@ -73,6 +73,7 @@ set(VEC_FILES
   data_types/nested_utils.cpp
   data_types/data_type_date.cpp
   data_types/data_type_date_time.cpp
+  data_types/data_type_factory.cpp
   exec/vaggregation_node.cpp
   exec/ves_http_scan_node.cpp
   exec/ves_http_scanner.cpp
diff --git a/be/src/vec/data_types/data_type_factory.cpp 
b/be/src/vec/data_types/data_type_factory.cpp
new file mode 100644
index 0000000000..d78679d377
--- /dev/null
+++ b/be/src/vec/data_types/data_type_factory.cpp
@@ -0,0 +1,100 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+// This file is copied from
+// 
https://github.com/ClickHouse/ClickHouse/blob/master/src/DataTypes/DataTypeFactory.cpp
+// and modified by Doris
+
+#include "vec/data_types/data_type_factory.hpp"
+#include "runtime/types.h"
+#include "vec/data_types/data_type.h"
+#include "vec/data_types/data_type_bitmap.h"
+#include "vec/data_types/data_type_date.h"
+#include "vec/data_types/data_type_date_time.h"
+#include "vec/data_types/data_type_decimal.h"
+#include "vec/data_types/data_type_nothing.h"
+#include "vec/data_types/data_type_nullable.h"
+#include "vec/data_types/data_type_number.h"
+#include "vec/data_types/data_type_string.h"
+
+namespace doris::vectorized {
+
+DataTypePtr DataTypeFactory::create_data_type(const TypeDescriptor& col_desc, 
bool is_nullable) {
+    DataTypePtr nested = nullptr;
+    switch (col_desc.type) {
+    case TYPE_BOOLEAN:
+        nested = std::make_shared<vectorized::DataTypeUInt8>();
+        break;
+    case TYPE_TINYINT:
+        nested = std::make_shared<vectorized::DataTypeInt8>();
+        break;
+    case TYPE_SMALLINT:
+        nested = std::make_shared<vectorized::DataTypeInt16>();
+        break;
+    case TYPE_INT:
+        nested = std::make_shared<vectorized::DataTypeInt32>();
+        break;
+    case TYPE_FLOAT:
+        nested = std::make_shared<vectorized::DataTypeFloat32>();
+        break;
+    case TYPE_BIGINT:
+        nested = std::make_shared<vectorized::DataTypeInt64>();
+        break;
+    case TYPE_LARGEINT:
+        nested = std::make_shared<vectorized::DataTypeInt128>();
+        break;
+    case TYPE_DATE:
+        nested = std::make_shared<vectorized::DataTypeDate>();
+        break;
+    case TYPE_DATETIME:
+        nested = std::make_shared<vectorized::DataTypeDateTime>();
+        break;
+    case TYPE_TIME:
+    case TYPE_DOUBLE:
+        nested = std::make_shared<vectorized::DataTypeFloat64>();
+        break;
+    case TYPE_STRING:
+    case TYPE_CHAR:
+    case TYPE_VARCHAR:
+        nested = std::make_shared<vectorized::DataTypeString>();
+        break;
+    case TYPE_HLL:
+        nested = std::make_shared<vectorized::DataTypeHLL>();
+        break;
+    case TYPE_OBJECT:
+        nested = std::make_shared<vectorized::DataTypeBitMap>();
+        break;
+    case TYPE_DECIMALV2:
+        nested = 
std::make_shared<vectorized::DataTypeDecimal<vectorized::Decimal128>>(27, 9);
+        break;
+    // Just mock a NULL type in the vectorized execution engine
+    case TYPE_NULL:
+        nested = std::make_shared<vectorized::DataTypeUInt8>();
+        break;
+    case INVALID_TYPE:
+    default:
+        DCHECK(false) << "invalid PrimitiveType:" << (int)col_desc.type;
+        break;
+    }
+
+    if (nested && is_nullable) {
+        return std::make_shared<vectorized::DataTypeNullable>(nested);
+    }
+    return nested;
+}
+
+
+} // namespace doris::vectorized
diff --git a/be/src/vec/data_types/data_type_factory.hpp 
b/be/src/vec/data_types/data_type_factory.hpp
index e06a962c2f..abe84cfe52 100644
--- a/be/src/vec/data_types/data_type_factory.hpp
+++ b/be/src/vec/data_types/data_type_factory.hpp
@@ -21,7 +21,7 @@
 #pragma once
 #include <mutex>
 #include <string>
-
+#include "runtime/types.h"
 #include "vec/data_types/data_type.h"
 #include "vec/data_types/data_type_date.h"
 #include "vec/data_types/data_type_date_time.h"
@@ -74,6 +74,12 @@ public:
         return _empty_string;
     }
 
+    DataTypePtr create_data_type(const TypeDescriptor& col_desc, bool 
is_nullable = true);
+
+    DataTypePtr create_data_type(const TTypeDesc& raw_type) {
+        return create_data_type(TypeDescriptor::from_thrift(raw_type), 
raw_type.is_nullable);
+    }
+
 private:
     void regist_data_type(const std::string& name, const DataTypePtr& 
data_type) {
         _data_type_map.emplace(name, data_type);
diff --git a/be/src/vec/exprs/vectorized_agg_fn.cpp 
b/be/src/vec/exprs/vectorized_agg_fn.cpp
index 0987190069..cab89ae7ff 100644
--- a/be/src/vec/exprs/vectorized_agg_fn.cpp
+++ b/be/src/vec/exprs/vectorized_agg_fn.cpp
@@ -23,6 +23,7 @@
 #include "vec/aggregate_functions/aggregate_function_simple_factory.h"
 #include "vec/columns/column_nullable.h"
 #include "vec/core/materialize_block.h"
+#include "vec/data_types/data_type_factory.hpp"
 #include "vec/data_types/data_type_nullable.h"
 #include "vec/exprs/vexpr.h"
 
@@ -32,18 +33,23 @@ AggFnEvaluator::AggFnEvaluator(const TExprNode& desc)
         : _fn(desc.fn),
           _is_merge(desc.agg_expr.is_merge_agg),
           _return_type(TypeDescriptor::from_thrift(desc.fn.ret_type)),
-          
_intermediate_type(TypeDescriptor::from_thrift(desc.fn.aggregate_fn.intermediate_type)),
           _intermediate_slot_desc(nullptr),
           _output_slot_desc(nullptr),
           _exec_timer(nullptr),
           _merge_timer(nullptr),
           _expr_timer(nullptr) {
-        if (desc.__isset.is_nullable) {
-          _data_type = IDataType::from_thrift(_return_type.type, 
desc.is_nullable);
-        } else {
-          _data_type = IDataType::from_thrift(_return_type.type);
+    if (desc.__isset.is_nullable) {
+        _data_type = IDataType::from_thrift(_return_type.type, 
desc.is_nullable);
+    } else {
+        _data_type = IDataType::from_thrift(_return_type.type);
+    }
+    if (desc.agg_expr.__isset.param_types) {
+        auto& param_types = desc.agg_expr.param_types;
+        for (auto raw_type : param_types) {
+            
_argument_types.push_back(DataTypeFactory::instance().create_data_type(raw_type));
         }
     }
+}
 
 Status AggFnEvaluator::create(ObjectPool* pool, const TExpr& desc, 
AggFnEvaluator** result) {
     *result = pool->add(new AggFnEvaluator(desc.nodes[0]));
@@ -73,21 +79,21 @@ Status AggFnEvaluator::prepare(RuntimeState* state, const 
RowDescriptor& desc, M
     Status status = VExpr::prepare(_input_exprs_ctxs, state, desc, 
mem_tracker);
     RETURN_IF_ERROR(status);
 
-    DataTypes argument_types;
-    argument_types.reserve(_input_exprs_ctxs.size());
+    DataTypes tmp_argument_types;
+    tmp_argument_types.reserve(_input_exprs_ctxs.size());
 
     std::vector<std::string_view> child_expr_name;
 
-    doris::vectorized::Array params;
     // prepare for argument
     for (int i = 0; i < _input_exprs_ctxs.size(); ++i) {
         auto data_type = _input_exprs_ctxs[i]->root()->data_type();
-        argument_types.emplace_back(data_type);
+        tmp_argument_types.emplace_back(data_type);
         
child_expr_name.emplace_back(_input_exprs_ctxs[i]->root()->expr_name());
     }
 
-    _function = 
AggregateFunctionSimpleFactory::instance().get(_fn.name.function_name, 
argument_types,
-                                                               params, 
_data_type->is_nullable());
+    _function = AggregateFunctionSimpleFactory::instance().get(
+            _fn.name.function_name, _argument_types.empty() ? 
tmp_argument_types : _argument_types,
+            {}, _data_type->is_nullable());
     if (_function == nullptr) {
         return Status::InternalError(
                 fmt::format("Agg Function {} is not implemented", 
_fn.name.function_name));
diff --git a/be/src/vec/exprs/vectorized_agg_fn.h 
b/be/src/vec/exprs/vectorized_agg_fn.h
index 0f1f145ced..ea130c94de 100644
--- a/be/src/vec/exprs/vectorized_agg_fn.h
+++ b/be/src/vec/exprs/vectorized_agg_fn.h
@@ -79,7 +79,7 @@ private:
     void _calc_argment_columns(Block* block);
 
     const TypeDescriptor _return_type;
-    const TypeDescriptor _intermediate_type;
+    DataTypes _argument_types;
 
     const SlotDescriptor* _intermediate_slot_desc;
     const SlotDescriptor* _output_slot_desc;
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfo.java 
b/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfo.java
index 128e60df7a..a82f2a7071 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfo.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfo.java
@@ -19,7 +19,6 @@ package org.apache.doris.analysis;
 
 import org.apache.doris.catalog.FunctionSet;
 import org.apache.doris.common.AnalysisException;
-import org.apache.doris.common.util.VectorizedUtil;
 import org.apache.doris.planner.DataPartition;
 import org.apache.doris.thrift.TPartitionType;
 
@@ -459,17 +458,10 @@ public final class AggregateInfo extends 
AggregateInfoBase {
         for (int i = 0; i < getAggregateExprs().size(); ++i) {
             FunctionCallExpr inputExpr = getAggregateExprs().get(i);
             Preconditions.checkState(inputExpr.isAggregateFunction());
-            List<Expr> paramExprs = new ArrayList<>();
-            // TODO(zhannngchen), change intermediate argument to a list, and 
remove this
-            // ad-hoc logic
-            if ((inputExpr.fn.functionName().equals("max_by") ||
-                    inputExpr.fn.functionName().equals("min_by")) && 
VectorizedUtil.isVectorized()) {
-                paramExprs.addAll(inputExpr.getFnParams().exprs());
-            } else {
-                paramExprs.add(new SlotRef(inputDesc.getSlots().get(i + 
getGroupingExprs().size())));
-            }
+            Expr aggExprParam = new SlotRef(inputDesc.getSlots().get(i + 
getGroupingExprs().size()));
             FunctionCallExpr aggExpr = FunctionCallExpr.createMergeAggCall(
-                    inputExpr, paramExprs);
+            inputExpr, Lists.newArrayList(aggExprParam), 
inputExpr.getFnParams().exprs());
+
             aggExpr.analyzeNoThrow(analyzer);
             aggExprs.add(aggExpr);
         }
@@ -586,11 +578,11 @@ public final class AggregateInfo extends 
AggregateInfoBase {
         for (int i = 0; i < aggregateExprs_.size(); ++i) {
             FunctionCallExpr inputExpr = aggregateExprs_.get(i);
             Preconditions.checkState(inputExpr.isAggregateFunction());
-            // we're aggregating an output slot of the 1st agg phase
             Expr aggExprParam =
                     new SlotRef(inputDesc.getSlots().get(i + 
getGroupingExprs().size()));
+
             FunctionCallExpr aggExpr = FunctionCallExpr.createMergeAggCall(
-                    inputExpr, Lists.newArrayList(aggExprParam));
+                inputExpr, Lists.newArrayList(aggExprParam), 
inputExpr.getFnParams().exprs());
             secondPhaseAggExprs.add(aggExpr);
         }
         Preconditions.checkState(
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java 
b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java
index 0021a1cf9b..58bca1cbf3 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java
@@ -33,7 +33,6 @@ import org.apache.doris.common.ErrorReport;
 import org.apache.doris.common.util.VectorizedUtil;
 import org.apache.doris.mysql.privilege.PrivPredicate;
 import org.apache.doris.qe.ConnectContext;
-import org.apache.doris.thrift.TAggregateExpr;
 import org.apache.doris.thrift.TExprNode;
 import org.apache.doris.thrift.TExprNodeType;
 
@@ -66,6 +65,9 @@ public class FunctionCallExpr extends Expr {
     // private BuiltinAggregateFunction.Operator aggOp;
     private FunctionParams fnParams;
 
+    // the original parameters of the aggregate function call
+    private FunctionParams aggFnParams;
+
     // check analytic function
     private boolean isAnalyticFnCall = false;
     // check table function
@@ -89,6 +91,10 @@ public class FunctionCallExpr extends Expr {
 
     private boolean isRewrote = false;
 
+    public void setAggFnParams(FunctionParams aggFnParams) {
+        this.aggFnParams = aggFnParams;
+    }
+
     public void setIsAnalyticFnCall(boolean v) {
         isAnalyticFnCall = v;
     }
@@ -150,6 +156,7 @@ public class FunctionCallExpr extends Expr {
         // aggOp = e.aggOp;
         isAnalyticFnCall = e.isAnalyticFnCall;
         fnParams = params;
+        aggFnParams = e.aggFnParams;
         // Just inherit the function object from 'e'.
         fn = e.fn;
         this.isMergeAggFn = e.isMergeAggFn;
@@ -172,6 +179,7 @@ public class FunctionCallExpr extends Expr {
         } else {
             fnParams = new FunctionParams(other.fnParams.isDistinct(), 
children);
         }
+        aggFnParams = other.aggFnParams;
         this.isMergeAggFn = other.isMergeAggFn;
         fn = other.fn;
         this.isTableFnCall = other.isTableFnCall;
@@ -354,7 +362,10 @@ public class FunctionCallExpr extends Expr {
         if (isAggregate() || isAnalyticFnCall) {
             msg.node_type = TExprNodeType.AGG_EXPR;
             if (!isAnalyticFnCall) {
-                msg.setAggExpr(new TAggregateExpr(isMergeAggFn));
+                if (aggFnParams == null) {
+                    aggFnParams = fnParams;
+                }
+                msg.setAggExpr(aggFnParams.createTAggregateExpr(isMergeAggFn));
             }
         } else {
             msg.node_type = TExprNodeType.FUNCTION_CALL;
@@ -1041,14 +1052,15 @@ public class FunctionCallExpr extends Expr {
     }
 
     public static FunctionCallExpr createMergeAggCall(
-            FunctionCallExpr agg, List<Expr> params) {
+            FunctionCallExpr agg, List<Expr> intermediateParams, List<Expr> 
realParams) {
         Preconditions.checkState(agg.isAnalyzed);
         Preconditions.checkState(agg.isAggregateFunction());
         FunctionCallExpr result = new FunctionCallExpr(
-                agg.fnName, new FunctionParams(false, params), true);
+                agg.fnName, new FunctionParams(false, intermediateParams), 
true);
         // Inherit the function object from 'agg'.
         result.fn = agg.fn;
         result.type = agg.type;
+        result.setAggFnParams(new FunctionParams(false, realParams));
         return result;
     }
 
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionParams.java 
b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionParams.java
index 59d85ace26..742cbf4b4c 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionParams.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionParams.java
@@ -18,12 +18,15 @@
 package org.apache.doris.analysis;
 
 import org.apache.doris.common.io.Writable;
+import org.apache.doris.thrift.TAggregateExpr;
+import org.apache.doris.thrift.TTypeDesc;
 
 import com.google.common.collect.Lists;
 
 import java.io.DataInput;
 import java.io.DataOutput;
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.List;
 import java.util.Objects;
 
@@ -118,4 +121,18 @@ public class FunctionParams implements Writable {
         }
         return result;
     }
+
+    public TAggregateExpr createTAggregateExpr(boolean isMergeAggFn) {
+        List<TTypeDesc> paramTypes = new ArrayList<TTypeDesc>();
+        if (exprs != null) {
+            for (Expr expr : exprs) {
+                TTypeDesc desc = expr.getType().toThrift();
+                desc.setIsNullable(expr.isNullable());
+                paramTypes.add(desc);
+            }
+        }
+        TAggregateExpr aggExpr = new TAggregateExpr(isMergeAggFn);
+        aggExpr.setParamTypes(paramTypes);
+        return aggExpr;
+    }
 }
diff --git a/gensrc/thrift/Exprs.thrift b/gensrc/thrift/Exprs.thrift
index 450148f381..df44b5ae60 100644
--- a/gensrc/thrift/Exprs.thrift
+++ b/gensrc/thrift/Exprs.thrift
@@ -73,6 +73,7 @@ enum TExprNodeType {
 struct TAggregateExpr {
   // Indicates whether this expr is the merge() of an aggregation.
   1: required bool is_merge_agg
+  2: optional list<Types.TTypeDesc> param_types
 }
 struct TBoolLiteral {
   1: required bool value
diff --git a/gensrc/thrift/Types.thrift b/gensrc/thrift/Types.thrift
index 9a232915ee..ae5b5e1853 100644
--- a/gensrc/thrift/Types.thrift
+++ b/gensrc/thrift/Types.thrift
@@ -135,6 +135,7 @@ struct TTypeNode {
 // to TTypeDesc. In future, we merge these two to one
 struct TTypeDesc {
     1: list<TTypeNode> types
+    2: optional bool is_nullable
 }
 
 enum TAggregationType {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to