This is an automated email from the ASF dual-hosted git repository. morningman pushed a commit to branch dev-1.1.2 in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/dev-1.1.2 by this push: new ccf2f82be4 [fix](function)fix max_by function bug (#11745) ccf2f82be4 is described below commit ccf2f82be43d9170cb0c3373bd27338573e6d7c9 Author: starocean999 <40539150+starocean...@users.noreply.github.com> AuthorDate: Mon Aug 15 09:06:54 2022 +0800 [fix](function)fix max_by function bug (#11745) This PR does the same thing as #10650. Because the code bases have diverged so much, it is easier to make the changes directly on dev-1.1.2 than to cherry-pick. --- be/src/vec/CMakeLists.txt | 1 + be/src/vec/data_types/data_type_factory.cpp | 100 +++++++++++++++++++++ be/src/vec/data_types/data_type_factory.hpp | 8 +- be/src/vec/exprs/vectorized_agg_fn.cpp | 28 +++--- be/src/vec/exprs/vectorized_agg_fn.h | 2 +- .../org/apache/doris/analysis/AggregateInfo.java | 18 ++-- .../apache/doris/analysis/FunctionCallExpr.java | 20 ++++- .../org/apache/doris/analysis/FunctionParams.java | 17 ++++ gensrc/thrift/Exprs.thrift | 1 + gensrc/thrift/Types.thrift | 1 + 10 files changed, 166 insertions(+), 30 deletions(-) diff --git a/be/src/vec/CMakeLists.txt b/be/src/vec/CMakeLists.txt index d8adfe6d40..56cfdfcc14 100644 --- a/be/src/vec/CMakeLists.txt +++ b/be/src/vec/CMakeLists.txt @@ -73,6 +73,7 @@ set(VEC_FILES data_types/nested_utils.cpp data_types/data_type_date.cpp data_types/data_type_date_time.cpp + data_types/data_type_factory.cpp exec/vaggregation_node.cpp exec/ves_http_scan_node.cpp exec/ves_http_scanner.cpp diff --git a/be/src/vec/data_types/data_type_factory.cpp b/be/src/vec/data_types/data_type_factory.cpp new file mode 100644 index 0000000000..d78679d377 --- /dev/null +++ b/be/src/vec/data_types/data_type_factory.cpp @@ -0,0 +1,100 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership.
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// This file is copied from +// https://github.com/ClickHouse/ClickHouse/blob/master/src/DataTypes/DataTypeFactory.cpp +// and modified by Doris + +#include "vec/data_types/data_type_factory.hpp" +#include "runtime/types.h" +#include "vec/data_types/data_type.h" +#include "vec/data_types/data_type_bitmap.h" +#include "vec/data_types/data_type_date.h" +#include "vec/data_types/data_type_date_time.h" +#include "vec/data_types/data_type_decimal.h" +#include "vec/data_types/data_type_nothing.h" +#include "vec/data_types/data_type_nullable.h" +#include "vec/data_types/data_type_number.h" +#include "vec/data_types/data_type_string.h" + +namespace doris::vectorized { + +DataTypePtr DataTypeFactory::create_data_type(const TypeDescriptor& col_desc, bool is_nullable) { + DataTypePtr nested = nullptr; + switch (col_desc.type) { + case TYPE_BOOLEAN: + nested = std::make_shared<vectorized::DataTypeUInt8>(); + break; + case TYPE_TINYINT: + nested = std::make_shared<vectorized::DataTypeInt8>(); + break; + case TYPE_SMALLINT: + nested = std::make_shared<vectorized::DataTypeInt16>(); + break; + case TYPE_INT: + nested = std::make_shared<vectorized::DataTypeInt32>(); + break; + case TYPE_FLOAT: + nested = std::make_shared<vectorized::DataTypeFloat32>(); + break; + case TYPE_BIGINT: + nested = std::make_shared<vectorized::DataTypeInt64>(); + break; + case TYPE_LARGEINT: + nested = 
std::make_shared<vectorized::DataTypeInt128>(); + break; + case TYPE_DATE: + nested = std::make_shared<vectorized::DataTypeDate>(); + break; + case TYPE_DATETIME: + nested = std::make_shared<vectorized::DataTypeDateTime>(); + break; + case TYPE_TIME: + case TYPE_DOUBLE: + nested = std::make_shared<vectorized::DataTypeFloat64>(); + break; + case TYPE_STRING: + case TYPE_CHAR: + case TYPE_VARCHAR: + nested = std::make_shared<vectorized::DataTypeString>(); + break; + case TYPE_HLL: + nested = std::make_shared<vectorized::DataTypeHLL>(); + break; + case TYPE_OBJECT: + nested = std::make_shared<vectorized::DataTypeBitMap>(); + break; + case TYPE_DECIMALV2: + nested = std::make_shared<vectorized::DataTypeDecimal<vectorized::Decimal128>>(27, 9); + break; + // Just Mock A NULL Type in Vec Exec Engine + case TYPE_NULL: + nested = std::make_shared<vectorized::DataTypeUInt8>(); + break; + case INVALID_TYPE: + default: + DCHECK(false) << "invalid PrimitiveType:" << (int)col_desc.type; + break; + } + + if (nested && is_nullable) { + return std::make_shared<vectorized::DataTypeNullable>(nested); + } + return nested; +} + + +} // namespace doris::vectorized diff --git a/be/src/vec/data_types/data_type_factory.hpp b/be/src/vec/data_types/data_type_factory.hpp index e06a962c2f..abe84cfe52 100644 --- a/be/src/vec/data_types/data_type_factory.hpp +++ b/be/src/vec/data_types/data_type_factory.hpp @@ -21,7 +21,7 @@ #pragma once #include <mutex> #include <string> - +#include "runtime/types.h" #include "vec/data_types/data_type.h" #include "vec/data_types/data_type_date.h" #include "vec/data_types/data_type_date_time.h" @@ -74,6 +74,12 @@ public: return _empty_string; } + DataTypePtr create_data_type(const TypeDescriptor& col_desc, bool is_nullable = true); + + DataTypePtr create_data_type(const TTypeDesc& raw_type) { + return create_data_type(TypeDescriptor::from_thrift(raw_type), raw_type.is_nullable); + } + private: void regist_data_type(const std::string& name, const DataTypePtr& 
data_type) { _data_type_map.emplace(name, data_type); diff --git a/be/src/vec/exprs/vectorized_agg_fn.cpp b/be/src/vec/exprs/vectorized_agg_fn.cpp index 0987190069..cab89ae7ff 100644 --- a/be/src/vec/exprs/vectorized_agg_fn.cpp +++ b/be/src/vec/exprs/vectorized_agg_fn.cpp @@ -23,6 +23,7 @@ #include "vec/aggregate_functions/aggregate_function_simple_factory.h" #include "vec/columns/column_nullable.h" #include "vec/core/materialize_block.h" +#include "vec/data_types/data_type_factory.hpp" #include "vec/data_types/data_type_nullable.h" #include "vec/exprs/vexpr.h" @@ -32,18 +33,23 @@ AggFnEvaluator::AggFnEvaluator(const TExprNode& desc) : _fn(desc.fn), _is_merge(desc.agg_expr.is_merge_agg), _return_type(TypeDescriptor::from_thrift(desc.fn.ret_type)), - _intermediate_type(TypeDescriptor::from_thrift(desc.fn.aggregate_fn.intermediate_type)), _intermediate_slot_desc(nullptr), _output_slot_desc(nullptr), _exec_timer(nullptr), _merge_timer(nullptr), _expr_timer(nullptr) { - if (desc.__isset.is_nullable) { - _data_type = IDataType::from_thrift(_return_type.type, desc.is_nullable); - } else { - _data_type = IDataType::from_thrift(_return_type.type); + if (desc.__isset.is_nullable) { + _data_type = IDataType::from_thrift(_return_type.type, desc.is_nullable); + } else { + _data_type = IDataType::from_thrift(_return_type.type); + } + if (desc.agg_expr.__isset.param_types) { + auto& param_types = desc.agg_expr.param_types; + for (auto raw_type : param_types) { + _argument_types.push_back(DataTypeFactory::instance().create_data_type(raw_type)); } } +} Status AggFnEvaluator::create(ObjectPool* pool, const TExpr& desc, AggFnEvaluator** result) { *result = pool->add(new AggFnEvaluator(desc.nodes[0])); @@ -73,21 +79,21 @@ Status AggFnEvaluator::prepare(RuntimeState* state, const RowDescriptor& desc, M Status status = VExpr::prepare(_input_exprs_ctxs, state, desc, mem_tracker); RETURN_IF_ERROR(status); - DataTypes argument_types; - argument_types.reserve(_input_exprs_ctxs.size()); + 
DataTypes tmp_argument_types; + tmp_argument_types.reserve(_input_exprs_ctxs.size()); std::vector<std::string_view> child_expr_name; - doris::vectorized::Array params; // prepare for argument for (int i = 0; i < _input_exprs_ctxs.size(); ++i) { auto data_type = _input_exprs_ctxs[i]->root()->data_type(); - argument_types.emplace_back(data_type); + tmp_argument_types.emplace_back(data_type); child_expr_name.emplace_back(_input_exprs_ctxs[i]->root()->expr_name()); } - _function = AggregateFunctionSimpleFactory::instance().get(_fn.name.function_name, argument_types, - params, _data_type->is_nullable()); + _function = AggregateFunctionSimpleFactory::instance().get( + _fn.name.function_name, _argument_types.empty() ? tmp_argument_types : _argument_types, + {}, _data_type->is_nullable()); if (_function == nullptr) { return Status::InternalError( fmt::format("Agg Function {} is not implemented", _fn.name.function_name)); diff --git a/be/src/vec/exprs/vectorized_agg_fn.h b/be/src/vec/exprs/vectorized_agg_fn.h index 0f1f145ced..ea130c94de 100644 --- a/be/src/vec/exprs/vectorized_agg_fn.h +++ b/be/src/vec/exprs/vectorized_agg_fn.h @@ -79,7 +79,7 @@ private: void _calc_argment_columns(Block* block); const TypeDescriptor _return_type; - const TypeDescriptor _intermediate_type; + DataTypes _argument_types; const SlotDescriptor* _intermediate_slot_desc; const SlotDescriptor* _output_slot_desc; diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfo.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfo.java index 128e60df7a..a82f2a7071 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfo.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfo.java @@ -19,7 +19,6 @@ package org.apache.doris.analysis; import org.apache.doris.catalog.FunctionSet; import org.apache.doris.common.AnalysisException; -import org.apache.doris.common.util.VectorizedUtil; import org.apache.doris.planner.DataPartition; import 
org.apache.doris.thrift.TPartitionType; @@ -459,17 +458,10 @@ public final class AggregateInfo extends AggregateInfoBase { for (int i = 0; i < getAggregateExprs().size(); ++i) { FunctionCallExpr inputExpr = getAggregateExprs().get(i); Preconditions.checkState(inputExpr.isAggregateFunction()); - List<Expr> paramExprs = new ArrayList<>(); - // TODO(zhannngchen), change intermediate argument to a list, and remove this - // ad-hoc logic - if ((inputExpr.fn.functionName().equals("max_by") || - inputExpr.fn.functionName().equals("min_by")) && VectorizedUtil.isVectorized()) { - paramExprs.addAll(inputExpr.getFnParams().exprs()); - } else { - paramExprs.add(new SlotRef(inputDesc.getSlots().get(i + getGroupingExprs().size()))); - } + Expr aggExprParam = new SlotRef(inputDesc.getSlots().get(i + getGroupingExprs().size())); FunctionCallExpr aggExpr = FunctionCallExpr.createMergeAggCall( - inputExpr, paramExprs); + inputExpr, Lists.newArrayList(aggExprParam), inputExpr.getFnParams().exprs()); + aggExpr.analyzeNoThrow(analyzer); aggExprs.add(aggExpr); } @@ -586,11 +578,11 @@ public final class AggregateInfo extends AggregateInfoBase { for (int i = 0; i < aggregateExprs_.size(); ++i) { FunctionCallExpr inputExpr = aggregateExprs_.get(i); Preconditions.checkState(inputExpr.isAggregateFunction()); - // we're aggregating an output slot of the 1st agg phase Expr aggExprParam = new SlotRef(inputDesc.getSlots().get(i + getGroupingExprs().size())); + FunctionCallExpr aggExpr = FunctionCallExpr.createMergeAggCall( - inputExpr, Lists.newArrayList(aggExprParam)); + inputExpr, Lists.newArrayList(aggExprParam), inputExpr.getFnParams().exprs()); secondPhaseAggExprs.add(aggExpr); } Preconditions.checkState( diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java index 0021a1cf9b..58bca1cbf3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java +++ 
b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java @@ -33,7 +33,6 @@ import org.apache.doris.common.ErrorReport; import org.apache.doris.common.util.VectorizedUtil; import org.apache.doris.mysql.privilege.PrivPredicate; import org.apache.doris.qe.ConnectContext; -import org.apache.doris.thrift.TAggregateExpr; import org.apache.doris.thrift.TExprNode; import org.apache.doris.thrift.TExprNodeType; @@ -66,6 +65,9 @@ public class FunctionCallExpr extends Expr { // private BuiltinAggregateFunction.Operator aggOp; private FunctionParams fnParams; + // the original parameters of the aggregate function + private FunctionParams aggFnParams; + // check analytic function private boolean isAnalyticFnCall = false; // check table function @@ -89,6 +91,10 @@ public class FunctionCallExpr extends Expr { private boolean isRewrote = false; + public void setAggFnParams(FunctionParams aggFnParams) { + this.aggFnParams = aggFnParams; + } + public void setIsAnalyticFnCall(boolean v) { isAnalyticFnCall = v; } @@ -150,6 +156,7 @@ public class FunctionCallExpr extends Expr { // aggOp = e.aggOp; isAnalyticFnCall = e.isAnalyticFnCall; fnParams = params; + aggFnParams = e.aggFnParams; // Just inherit the function object from 'e'.
fn = e.fn; this.isMergeAggFn = e.isMergeAggFn; @@ -172,6 +179,7 @@ public class FunctionCallExpr extends Expr { } else { fnParams = new FunctionParams(other.fnParams.isDistinct(), children); } + aggFnParams = other.aggFnParams; this.isMergeAggFn = other.isMergeAggFn; fn = other.fn; this.isTableFnCall = other.isTableFnCall; @@ -354,7 +362,10 @@ public class FunctionCallExpr extends Expr { if (isAggregate() || isAnalyticFnCall) { msg.node_type = TExprNodeType.AGG_EXPR; if (!isAnalyticFnCall) { - msg.setAggExpr(new TAggregateExpr(isMergeAggFn)); + if (aggFnParams == null) { + aggFnParams = fnParams; + } + msg.setAggExpr(aggFnParams.createTAggregateExpr(isMergeAggFn)); } } else { msg.node_type = TExprNodeType.FUNCTION_CALL; @@ -1041,14 +1052,15 @@ public class FunctionCallExpr extends Expr { } public static FunctionCallExpr createMergeAggCall( - FunctionCallExpr agg, List<Expr> params) { + FunctionCallExpr agg, List<Expr> intermediateParams, List<Expr> realParams) { Preconditions.checkState(agg.isAnalyzed); Preconditions.checkState(agg.isAggregateFunction()); FunctionCallExpr result = new FunctionCallExpr( - agg.fnName, new FunctionParams(false, params), true); + agg.fnName, new FunctionParams(false, intermediateParams), true); // Inherit the function object from 'agg'. 
result.fn = agg.fn; result.type = agg.type; + result.setAggFnParams(new FunctionParams(false, realParams)); return result; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionParams.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionParams.java index 59d85ace26..742cbf4b4c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionParams.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionParams.java @@ -18,12 +18,15 @@ package org.apache.doris.analysis; import org.apache.doris.common.io.Writable; +import org.apache.doris.thrift.TAggregateExpr; +import org.apache.doris.thrift.TTypeDesc; import com.google.common.collect.Lists; import java.io.DataInput; import java.io.DataOutput; import java.io.IOException; +import java.util.ArrayList; import java.util.List; import java.util.Objects; @@ -118,4 +121,18 @@ public class FunctionParams implements Writable { } return result; } + + public TAggregateExpr createTAggregateExpr(boolean isMergeAggFn) { + List<TTypeDesc> paramTypes = new ArrayList<TTypeDesc>(); + if (exprs != null) { + for (Expr expr : exprs) { + TTypeDesc desc = expr.getType().toThrift(); + desc.setIsNullable(expr.isNullable()); + paramTypes.add(desc); + } + } + TAggregateExpr aggExpr = new TAggregateExpr(isMergeAggFn); + aggExpr.setParamTypes(paramTypes); + return aggExpr; + } } diff --git a/gensrc/thrift/Exprs.thrift b/gensrc/thrift/Exprs.thrift index 450148f381..df44b5ae60 100644 --- a/gensrc/thrift/Exprs.thrift +++ b/gensrc/thrift/Exprs.thrift @@ -73,6 +73,7 @@ enum TExprNodeType { struct TAggregateExpr { // Indicates whether this expr is the merge() of an aggregation. 
1: required bool is_merge_agg + 2: optional list<Types.TTypeDesc> param_types } struct TBoolLiteral { 1: required bool value diff --git a/gensrc/thrift/Types.thrift b/gensrc/thrift/Types.thrift index 9a232915ee..ae5b5e1853 100644 --- a/gensrc/thrift/Types.thrift +++ b/gensrc/thrift/Types.thrift @@ -135,6 +135,7 @@ struct TTypeNode { // to TTypeDesc. In future, we merge these two to one struct TTypeDesc { 1: list<TTypeNode> types + 2: optional bool is_nullable } enum TAggregationType { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org