This is an automated email from the ASF dual-hosted git repository. lihaopeng pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-doris.git
The following commit(s) were added to refs/heads/master by this push: new 143c408 [Feature][Vectorized] support aggregate function ndv()/approx_count_distinct() (#8044) 143c408 is described below commit 143c4085ee58007954f3eef8910556f5b8ce6b39 Author: Pxl <952130...@qq.com> AuthorDate: Wed Feb 16 14:30:13 2022 +0800 [Feature][Vectorized] support aggregate function ndv()/approx_count_distinct() (#8044) --- be/src/vec/CMakeLists.txt | 1 + .../aggregate_function_approx_count_distinct.cpp | 50 ++++++++++ .../aggregate_function_approx_count_distinct.h | 107 +++++++++++++++++++++ .../aggregate_functions/aggregate_function_avg.cpp | 4 - .../aggregate_function_simple_factory.cpp | 5 + .../aggregate_function_simple_factory.h | 7 +- be/src/vec/common/string_ref.h | 6 +- be/src/vec/functions/function_case.h | 55 ++--------- be/src/vec/functions/function_coalesce.cpp | 102 ++++++++------------ be/src/vec/functions/function_hash.cpp | 80 ++++++--------- be/src/vec/utils/template_helpers.hpp | 69 +++++++++++++ .../java/org/apache/doris/catalog/FunctionSet.java | 54 +++++++---- 12 files changed, 352 insertions(+), 188 deletions(-) diff --git a/be/src/vec/CMakeLists.txt b/be/src/vec/CMakeLists.txt index c526cf2..7f8878f 100644 --- a/be/src/vec/CMakeLists.txt +++ b/be/src/vec/CMakeLists.txt @@ -33,6 +33,7 @@ set(VEC_FILES aggregate_functions/aggregate_function_window.cpp aggregate_functions/aggregate_function_stddev.cpp aggregate_functions/aggregate_function_topn.cpp + aggregate_functions/aggregate_function_approx_count_distinct.cpp aggregate_functions/aggregate_function_simple_factory.cpp columns/collator.cpp columns/column.cpp diff --git a/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.cpp b/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.cpp new file mode 100644 index 0000000..fc68d85 --- /dev/null +++ b/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.cpp @@ -0,0 +1,50 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "vec/aggregate_functions/aggregate_function_approx_count_distinct.h" + +#include "vec/utils/template_helpers.hpp" + +namespace doris::vectorized { + +AggregateFunctionPtr create_aggregate_function_approx_count_distinct( + const std::string& name, const DataTypes& argument_types, const Array& parameters, + const bool result_is_nullable) { + AggregateFunctionPtr res = nullptr; + WhichDataType which(argument_types[0]->is_nullable() + ? reinterpret_cast<const DataTypeNullable*>(argument_types[0].get()) + ->get_nested_type() + : argument_types[0]); + + res.reset(create_class_with_type<AggregateFunctionApproxCountDistinct>(*argument_types[0], + argument_types)); + + if (!res) { + LOG(WARNING) << fmt::format("Illegal type {} of argument for aggregate function {}", + argument_types[0]->get_name(), name); + } + + return res; +} + +void register_aggregate_function_approx_count_distinct(AggregateFunctionSimpleFactory& factory) { + factory.register_function("approx_count_distinct", + create_aggregate_function_approx_count_distinct); + factory.register_alias("approx_count_distinct", "ndv"); +} + +} // namespace doris::vectorized diff --git a/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.h b/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.h new file mode 100644 index 0000000..ed393af --- /dev/null +++ b/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.h @@ -0,0 +1,107 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "exprs/anyval_util.h" +#include "olap/hll.h" +#include "udf/udf.h" +#include "vec/aggregate_functions/aggregate_function.h" +#include "vec/aggregate_functions/aggregate_function_simple_factory.h" +#include "vec/common/string_ref.h" +#include "vec/io/io_helper.h" + +namespace doris::vectorized { + +struct AggregateFunctionApproxCountDistinctData { + HyperLogLog hll_data; + + void add(StringRef value) { + StringVal sv = value.to_string_val(); + uint64_t hash_value = AnyValUtil::hash64_murmur(sv, HashUtil::MURMUR_SEED); + if (hash_value != 0) { + hll_data.update(hash_value); + } + } + + void merge(const AggregateFunctionApproxCountDistinctData& rhs) { + hll_data.merge(rhs.hll_data); + } + + void write(BufferWritable& buf) const { + std::string result; + result.resize(hll_data.max_serialized_size()); + int size = hll_data.serialize((uint8_t*)result.data()); + result.resize(size); + write_binary(result, buf); + } + + void read(BufferReadable& buf) { + StringRef result; + read_binary(result, buf); + Slice data = Slice(result.data, result.size); + hll_data.deserialize(data); + } + + int64_t get() const { return hll_data.estimate_cardinality(); } + + void reset() { hll_data.clear(); } +}; + +template <typename ColumnDataType> +class AggregateFunctionApproxCountDistinct final + : public IAggregateFunctionDataHelper< + AggregateFunctionApproxCountDistinctData, + AggregateFunctionApproxCountDistinct<ColumnDataType>> { +public: + String get_name() const override { return "approx_count_distinct"; } + + AggregateFunctionApproxCountDistinct(const DataTypes& argument_types_) + : IAggregateFunctionDataHelper<AggregateFunctionApproxCountDistinctData, + AggregateFunctionApproxCountDistinct<ColumnDataType>>( + argument_types_, {}) {} + + DataTypePtr get_return_type() const override { return std::make_shared<DataTypeInt64>(); } + + void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, + Arena*) const override { + this->data(place).add(static_cast<const ColumnDataType*>(columns[0])->get_data_at(row_num)); + } + + void reset(AggregateDataPtr place) const override { this->data(place).reset(); } + + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, + Arena*) const override { + this->data(place).merge(this->data(rhs)); + } + + void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override { + this->data(place).write(buf); + } + + void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf, + Arena*) const override { + this->data(place).read(buf); + } + + void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { + auto& column = static_cast<ColumnInt64&>(to); + column.get_data().push_back(this->data(place).get()); + } +}; + +} // namespace doris::vectorized diff --git a/be/src/vec/aggregate_functions/aggregate_function_avg.cpp b/be/src/vec/aggregate_functions/aggregate_function_avg.cpp index bb7605b..61687af 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_avg.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_avg.cpp @@ -27,8 +27,6 @@ namespace doris::vectorized { -namespace { - template <typename T> struct Avg { using FieldType = std::conditional_t<IsDecimalNumber<T>, Decimal128, NearestFieldType<T>>; @@ -60,8 +58,6 @@ AggregateFunctionPtr create_aggregate_function_avg(const std::string& name, return res; } -} // namespace - void register_aggregate_function_avg(AggregateFunctionSimpleFactory& factory) { factory.register_function("avg", create_aggregate_function_avg); } diff --git a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp index 4844000..87a52b9 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp @@ -23,7 +23,9 @@ #include "vec/aggregate_functions/aggregate_function_reader.h" namespace doris::vectorized { + class AggregateFunctionSimpleFactory; + void register_aggregate_function_sum(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_combinator_null(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_minmax(AggregateFunctionSimpleFactory& factory); @@ -37,6 +39,8 @@ void register_aggregate_function_window_rank(AggregateFunctionSimpleFactory& fac void register_aggregate_function_window_lead_lag(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_stddev_variance(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_topn(AggregateFunctionSimpleFactory& factory); +void register_aggregate_function_approx_count_distinct(AggregateFunctionSimpleFactory& factory); + AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() { static std::once_flag oc; static AggregateFunctionSimpleFactory instance; @@ -53,6 +57,7 @@ AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() { register_aggregate_function_window_rank(instance); register_aggregate_function_stddev_variance(instance); register_aggregate_function_topn(instance); + register_aggregate_function_approx_count_distinct(instance); // if you only register function with no nullable, and wants to add nullable automatically, you should place function above this line register_aggregate_function_combinator_null(instance); diff --git a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h index 1bac4f1..833e52d 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h +++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h @@ -46,6 +46,7 @@ private: AggregateFunctions aggregate_functions; AggregateFunctions nullable_aggregate_functions; std::unordered_map<std::string, std::string> function_alias; + public: void register_nullable_function_combinator(const Creator& creator) { for (const auto& entity : aggregate_functions) { @@ -86,13 +87,13 @@ public: if (nullable) { return nullable_aggregate_functions.find(name_str) == nullable_aggregate_functions.end() ? nullptr - : nullable_aggregate_functions[name_str](name_str, argument_types, parameters, - result_is_nullable); + : nullable_aggregate_functions[name_str](name_str, argument_types, + parameters, result_is_nullable); } else { return aggregate_functions.find(name_str) == aggregate_functions.end() ? nullptr : aggregate_functions[name_str](name_str, argument_types, parameters, - result_is_nullable); + result_is_nullable); } } diff --git a/be/src/vec/common/string_ref.h b/be/src/vec/common/string_ref.h index 727996e..5dd146e 100644 --- a/be/src/vec/common/string_ref.h +++ b/be/src/vec/common/string_ref.h @@ -55,9 +55,13 @@ struct StringRef { explicit operator std::string() const { return to_string(); } - StringVal to_string_val() const { + StringVal to_string_val() { return StringVal(reinterpret_cast<uint8_t*>(const_cast<char*>(data)), size); } + + static StringRef from_string_val(StringVal sv) { + return StringRef(reinterpret_cast<char*>(sv.ptr), sv.len); + } }; using StringRefs = std::vector<StringRef>; diff --git a/be/src/vec/functions/function_case.h b/be/src/vec/functions/function_case.h index 1113b5f..1b728e9 100644 --- a/be/src/vec/functions/function_case.h +++ b/be/src/vec/functions/function_case.h @@ -17,11 +17,11 @@ #pragma once -#include "vec/data_types/data_type_decimal.h" #include "vec/data_types/data_type_nullable.h" #include "vec/functions/function.h" #include "vec/functions/function_helpers.h" #include "vec/functions/simple_function_factory.h" +#include "vec/utils/template_helpers.hpp" namespace doris::vectorized { @@ -311,51 +311,14 @@ public: ? reinterpret_cast<const DataTypeNullable*>(data_type.get()) ->get_nested_type() : data_type); - - // TODO: use template traits here. - if (which.is_uint8()) { - return execute_get_when_null<ColumnUInt8>(data_type, block, arguments, result, - input_rows_count); - } else if (which.is_int16()) { - return execute_get_when_null<ColumnInt16>(data_type, block, arguments, result, - input_rows_count); - } else if (which.is_uint32()) { - return execute_get_when_null<ColumnUInt32>(data_type, block, arguments, result, - input_rows_count); - } else if (which.is_uint64()) { - return execute_get_when_null<ColumnUInt64>(data_type, block, arguments, result, - input_rows_count); - } else if (which.is_int8()) { - return execute_get_when_null<ColumnInt8>(data_type, block, arguments, result, - input_rows_count); - } else if (which.is_int16()) { - return execute_get_when_null<ColumnInt16>(data_type, block, arguments, result, - input_rows_count); - } else if (which.is_int32()) { - return execute_get_when_null<ColumnInt32>(data_type, block, arguments, result, - input_rows_count); - } else if (which.is_int64()) { - return execute_get_when_null<ColumnInt64>(data_type, block, arguments, result, - input_rows_count); - } else if (which.is_date_or_datetime()) { - return execute_get_when_null<ColumnVector<DateTime>>(data_type, block, arguments, - result, input_rows_count); - } else if (which.is_float32()) { - return execute_get_when_null<ColumnFloat32>(data_type, block, arguments, result, - input_rows_count); - } else if (which.is_float64()) { - return execute_get_when_null<ColumnFloat64>(data_type, block, arguments, result, - input_rows_count); - } else if (which.is_decimal()) { - return execute_get_when_null<ColumnDecimal<Decimal128>>(data_type, block, arguments, - result, input_rows_count); - } else if (which.is_string()) { - return execute_get_when_null<ColumnString>(data_type, block, arguments, result, - input_rows_count); - } else { - return Status::NotSupported(fmt::format("Unexpected type {} of argument of function {}", - data_type->get_name(), get_name())); - } +#define DISPATCH(TYPE, COLUMN_TYPE) \ + if (which.idx == TypeIndex::TYPE) \ + return execute_get_when_null<COLUMN_TYPE>(data_type, block, arguments, result, \ + input_rows_count); + TYPE_TO_COLUMN_TYPE(DISPATCH) +#undef DISPATCH + return Status::NotSupported( + fmt::format("argument_type {} not supported", data_type->get_name())); } Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, diff --git a/be/src/vec/functions/function_coalesce.cpp b/be/src/vec/functions/function_coalesce.cpp index 4991fa8..91d6304 100644 --- a/be/src/vec/functions/function_coalesce.cpp +++ b/be/src/vec/functions/function_coalesce.cpp @@ -16,11 +16,10 @@ // under the License. #include "udf/udf.h" -#include "vec/data_types/data_type_nothing.h" -#include "vec/data_types/data_type_number.h" #include "vec/data_types/get_least_supertype.h" #include "vec/functions/function_helpers.h" #include "vec/functions/simple_function_factory.h" +#include "vec/utils/template_helpers.hpp" #include "vec/utils/util.hpp" namespace doris::vectorized { @@ -53,11 +52,9 @@ public: res = res ? res : arguments[0]; - const ColumnsWithTypeAndName is_not_null_col{ - {nullptr, make_nullable(res), ""} - }; - func_is_not_null = SimpleFunctionFactory::instance(). - get_function("is_not_null_pred", is_not_null_col, std::make_shared<DataTypeUInt8>()); + const ColumnsWithTypeAndName is_not_null_col {{nullptr, make_nullable(res), ""}}; + func_is_not_null = SimpleFunctionFactory::instance().get_function( + "is_not_null_pred", is_not_null_col, std::make_shared<DataTypeUInt8>()); return res; } @@ -74,7 +71,8 @@ public: filtered_args.push_back(arguments[i]); if (!arg_type->is_nullable()) { if (i == 0) { //if the first column not null, return it's directly - block.get_by_position(result).column = block.get_by_position(arguments[0]).column; + block.get_by_position(result).column = + block.get_by_position(arguments[0]).column; return Status::OK(); } else { break; @@ -84,8 +82,12 @@ public: size_t remaining_rows = input_rows_count; size_t argument_size = filtered_args.size(); - std::vector<uint32_t> record_idx(input_rows_count, 0); //used to save column idx, record the result data of each row from which column - std::vector<uint8_t> filled_flags(input_rows_count, 0); //used to save filled flag, in order to check current row whether have filled data + std::vector<uint32_t> record_idx( + input_rows_count, + 0); //used to save column idx, record the result data of each row from which column + std::vector<uint8_t> filled_flags( + input_rows_count, + 0); //used to save filled flag, in order to check current row whether have filled data MutableColumnPtr result_column; if (!result_type->is_nullable()) { @@ -104,7 +106,8 @@ public: } auto return_type = std::make_shared<DataTypeUInt8>(); - auto null_map = ColumnUInt8::create(input_rows_count, 1); //if null_map_data==1, the current row should be null + auto null_map = ColumnUInt8::create( + input_rows_count, 1); //if null_map_data==1, the current row should be null auto* __restrict null_map_data = null_map->get_data().data(); ColumnPtr argument_columns[argument_size]; //use to save nested_column if is nullable column @@ -119,17 +122,17 @@ public: } Block temporary_block { - ColumnsWithTypeAndName { - block.get_by_position(filtered_args[0]), - {nullptr, std::make_shared<DataTypeUInt8>(), ""} - } - }; + ColumnsWithTypeAndName {block.get_by_position(filtered_args[0]), + {nullptr, std::make_shared<DataTypeUInt8>(), ""}}}; for (size_t i = 0; i < argument_size && remaining_rows; ++i) { - temporary_block.get_by_position(0).column = block.get_by_position(filtered_args[i]).column; + temporary_block.get_by_position(0).column = + block.get_by_position(filtered_args[i]).column; func_is_not_null->execute(context, temporary_block, {0}, 1, input_rows_count); - auto res_column = (*temporary_block.get_by_position(1).column->convert_to_full_column_if_const()).mutate(); + auto res_column = + (*temporary_block.get_by_position(1).column->convert_to_full_column_if_const()) + .mutate(); auto& res_map = assert_cast<ColumnVector<UInt8>*>(res_column.get())->get_data(); auto* __restrict res = res_map.data(); @@ -152,7 +155,8 @@ public: if (is_same_column_count == input_rows_count) { if (result_type->is_nullable()) { - block.get_by_position(result).column = make_nullable(argument_columns[i], false); + block.get_by_position(result).column = + make_nullable(argument_columns[i], false); } else { block.get_by_position(result).column = argument_columns[i]; } @@ -170,7 +174,7 @@ public: } if (is_string_result) { - //if string type, should according to the record results, fill in result one by one, + //if string type, should according to the record results, fill in result one by one, for (size_t row = 0; row < input_rows_count; ++row) { if (null_map_data[row]) { //should be null result_column->insert_default(); @@ -181,7 +185,8 @@ public: } if (result_type->is_nullable()) { - block.replace_by_position(result, ColumnNullable::create(std::move(result_column), std::move(null_map))); + block.replace_by_position( + result, ColumnNullable::create(std::move(result_column), std::move(null_map))); } else { block.replace_by_position(result, std::move(result_column)); } @@ -198,18 +203,17 @@ public: auto* __restrict column_raw_data = reinterpret_cast<const ColumnType*>(argument_column.get())->get_data().data(); - // Here it's SIMD thought the compiler automatically also // true: null_map_data[row]==0 && filled_idx[row]==0 // if true, could filled current row data into result column for (size_t row = 0; row < input_rows_count; ++row) { - result_raw_data[row] += (!(null_map_data[row] | filled_flag[row])) * column_raw_data[row]; + result_raw_data[row] += + (!(null_map_data[row] | filled_flag[row])) * column_raw_data[row]; filled_flag[row] += (!(null_map_data[row] | filled_flag[row])); } return Status::OK(); } - //TODO: this function is same as case when, should be replaced by macro Status filled_result_column(const DataTypePtr& data_type, MutableColumnPtr& result_column, ColumnPtr& argument_column, UInt8* __restrict null_map_data, UInt8* __restrict filled_flag, const size_t input_rows_count) { @@ -217,46 +221,16 @@ public: ? reinterpret_cast<const DataTypeNullable*>(data_type.get()) ->get_nested_type() : data_type); - if (which.is_uint8()) { - return insert_result_data<ColumnUInt8>(result_column, argument_column, null_map_data, - filled_flag, input_rows_count); - } else if (which.is_int16()) { - return insert_result_data<ColumnInt16>(result_column, argument_column, null_map_data, - filled_flag, input_rows_count); - } else if (which.is_uint32()) { - return insert_result_data<ColumnUInt32>(result_column, argument_column, null_map_data, - filled_flag, input_rows_count); - } else if (which.is_uint64()) { - return insert_result_data<ColumnUInt64>(result_column, argument_column, null_map_data, - filled_flag, input_rows_count); - } else if (which.is_int8()) { - return insert_result_data<ColumnInt8>(result_column, argument_column, null_map_data, - filled_flag, input_rows_count); - } else if (which.is_int16()) { - return insert_result_data<ColumnInt16>(result_column, argument_column, null_map_data, - filled_flag, input_rows_count); - } else if (which.is_int32()) { - return insert_result_data<ColumnInt32>(result_column, argument_column, null_map_data, - filled_flag, input_rows_count); - } else if (which.is_int64()) { - return insert_result_data<ColumnInt64>(result_column, argument_column, null_map_data, - filled_flag, input_rows_count); - } else if (which.is_date_or_datetime()) { - return insert_result_data<ColumnVector<DateTime>>( - result_column, argument_column, null_map_data, filled_flag, input_rows_count); - } else if (which.is_float32()) { - return insert_result_data<ColumnFloat32>(result_column, argument_column, null_map_data, - filled_flag, input_rows_count); - } else if (which.is_float64()) { - return insert_result_data<ColumnFloat64>(result_column, argument_column, null_map_data, - filled_flag, input_rows_count); - } else if (which.is_decimal()) { - return insert_result_data<ColumnDecimal<Decimal128>>( - result_column, argument_column, null_map_data, filled_flag, input_rows_count); - } else { - return Status::NotSupported(fmt::format("Unexpected type {} of argument of function {}", - data_type->get_name(), get_name())); - } +#define DISPATCH(TYPE, COLUMN_TYPE) \ + if (which.idx == TypeIndex::TYPE) \ + return insert_result_data<COLUMN_TYPE>(result_column, argument_column, null_map_data, \ + filled_flag, input_rows_count); + NUMERIC_TYPE_TO_COLUMN_TYPE(DISPATCH) + DECIMAL_TYPE_TO_COLUMN_TYPE(DISPATCH) + TIME_TYPE_TO_COLUMN_TYPE(DISPATCH) +#undef DISPATCH + return Status::NotSupported( + fmt::format("argument_type {} not supported", data_type->get_name())); } }; diff --git a/be/src/vec/functions/function_hash.cpp b/be/src/vec/functions/function_hash.cpp index 18c7bcc..92e2a55 100644 --- a/be/src/vec/functions/function_hash.cpp +++ b/be/src/vec/functions/function_hash.cpp @@ -23,14 +23,14 @@ #include "util/hash_util.hpp" #include "vec/functions/function_variadic_arguments.h" #include "vec/functions/simple_function_factory.h" +#include "vec/utils/template_helpers.hpp" namespace doris::vectorized { struct MurmurHash2Impl64 { static constexpr auto name = "murmurHash2_64"; using ReturnType = UInt64; - static Status empty_apply(IColumn& icolumn, - size_t input_rows_count) { + static Status empty_apply(IColumn& icolumn, size_t input_rows_count) { ColumnVector<ReturnType>& vec_to = assert_cast<ColumnVector<ReturnType>&>(icolumn); vec_to.get_data().assign(input_rows_count, static_cast<ReturnType>(0xe28dbde7fe22e41c)); return Status::OK(); @@ -42,8 +42,8 @@ struct MurmurHash2Impl64 { return Status::OK(); } - static Status combine_apply(const IDataType* type, const IColumn* column, size_t input_rows_count, - IColumn& icolumn) { + static Status combine_apply(const IDataType* type, const IColumn* column, + size_t input_rows_count, IColumn& icolumn) { execute_any<false>(type, column, icolumn, input_rows_count); return Status::OK(); } @@ -58,7 +58,7 @@ struct MurmurHash2Impl64 { for (size_t i = 0; i < size; ++i) { ReturnType val = HashUtil::murmur_hash2_64( reinterpret_cast<const char*>(reinterpret_cast<const char*>(&vec_from[i])), - sizeof(vec_from[i]), 0); + sizeof(vec_from[i]), 0); if (first) col_to.insert_data(const_cast<const char*>(reinterpret_cast<char*>(&val)), 0); else @@ -137,38 +137,20 @@ struct MurmurHash2Impl64 { } template <bool first> - static Status execute_any(const IDataType* from_type, const IColumn* icolumn, - IColumn& col_to, size_t input_rows_count) { + static Status execute_any(const IDataType* from_type, const IColumn* icolumn, IColumn& col_to, + size_t input_rows_count) { WhichDataType which(from_type); - - if (which.is_uint8()) - execute_int_type<UInt8, first>(icolumn, col_to, input_rows_count); - else if (which.is_int16()) - execute_int_type<UInt16, first>(icolumn, col_to, input_rows_count); - else if (which.is_uint32()) - execute_int_type<UInt32, first>(icolumn, col_to, input_rows_count); - else if (which.is_uint64()) - execute_int_type<UInt64, first>(icolumn, col_to, input_rows_count); - else if (which.is_int8()) - execute_int_type<Int8, first>(icolumn, col_to, input_rows_count); - else if (which.is_int16()) - execute_int_type<Int16, first>(icolumn, col_to, input_rows_count); - else if (which.is_int32()) - execute_int_type<Int32, first>(icolumn, col_to, input_rows_count); - else if (which.is_int64()) - execute_int_type<Int64, first>(icolumn, col_to, input_rows_count); - else if (which.is_float32()) - execute_int_type<Float32, first>(icolumn, col_to, input_rows_count); - else if (which.is_float64()) - execute_int_type<Float64, first>(icolumn, col_to, input_rows_count); - else if (which.is_string()) - execute_string<first>(icolumn, col_to, input_rows_count); - else { - DCHECK(false); - return Status::NotSupported(fmt::format("Illegal column {} of argument of function {}", - icolumn->get_name(), name)); + if (which.is_string()) { + return execute_string<first>(icolumn, col_to, input_rows_count); } - return Status::OK(); + +#define DISPATCH(TYPE, COLUMN_TYPE) \ + if (which.idx == TypeIndex::TYPE) \ + return execute_int_type<TYPE, first>(icolumn, col_to, input_rows_count); + NUMERIC_TYPE_TO_COLUMN_TYPE(DISPATCH) +#undef DISPATCH + return Status::NotSupported( + fmt::format("argument_type {} not supported", from_type->get_name())); } }; using FunctionMurmurHash2_64 = FunctionVariadicArgumentsBase<DataTypeUInt64, MurmurHash2Impl64>; @@ -177,8 +159,7 @@ struct MurmurHash3Impl32 { static constexpr auto name = "murmur_hash3_32"; using ReturnType = Int32; - static Status empty_apply(IColumn& icolumn, - size_t input_rows_count) { + static Status empty_apply(IColumn& icolumn, size_t input_rows_count) { ColumnVector<ReturnType>& vec_to = assert_cast<ColumnVector<ReturnType>&>(icolumn); vec_to.get_data().assign(input_rows_count, static_cast<ReturnType>(0xe28dbde7fe22e41c)); return Status::OK(); @@ -189,8 +170,8 @@ struct MurmurHash3Impl32 { return execute<true>(type, column, input_rows_count, icolumn); } - static Status combine_apply(const IDataType* type, const IColumn* column, size_t input_rows_count, - IColumn& icolumn) { + static Status combine_apply(const IDataType* type, const IColumn* column, + size_t input_rows_count, IColumn& icolumn) { return execute<false>(type, column, input_rows_count, icolumn); } @@ -207,15 +188,14 @@ struct MurmurHash3Impl32 { if (first) { UInt32 val = HashUtil::murmur_hash3_32( reinterpret_cast<const char*>(&data[current_offset]), - offsets[i] - current_offset - 1, - HashUtil::MURMUR3_32_SEED); + offsets[i] - current_offset - 1, HashUtil::MURMUR3_32_SEED); col_to.insert_data(const_cast<const char*>(reinterpret_cast<char*>(&val)), 0); } else { assert_cast<ColumnVector<ReturnType>&>(col_to).get_data()[i] = HashUtil::murmur_hash3_32( - reinterpret_cast<const char*>(&data[current_offset]), - offsets[i] - current_offset - 1, - ext::bit_cast<UInt32>(col_to[i])); + reinterpret_cast<const char*>(&data[current_offset]), + offsets[i] - current_offset - 1, + ext::bit_cast<UInt32>(col_to[i])); } current_offset = offsets[i]; } @@ -224,17 +204,13 @@ struct MurmurHash3Impl32 { String value = col_from_const->get_value<String>().data(); for (size_t i = 0; i < input_rows_count; ++i) { if (first) { - UInt32 val = HashUtil::murmur_hash3_32( - value.data(), - value.size(), - HashUtil::MURMUR3_32_SEED); + UInt32 val = HashUtil::murmur_hash3_32(value.data(), value.size(), + HashUtil::MURMUR3_32_SEED); col_to.insert_data(const_cast<const char*>(reinterpret_cast<char*>(&val)), 0); } else { assert_cast<ColumnVector<ReturnType>&>(col_to).get_data()[i] = - HashUtil::murmur_hash3_32( - value.data(), - value.size(), - ext::bit_cast<UInt32>(col_to[i])); + HashUtil::murmur_hash3_32(value.data(), value.size(), + ext::bit_cast<UInt32>(col_to[i])); } } } else { diff --git a/be/src/vec/utils/template_helpers.hpp b/be/src/vec/utils/template_helpers.hpp new file mode 100644 index 0000000..4d4e1e2 --- /dev/null +++ b/be/src/vec/utils/template_helpers.hpp @@ -0,0 +1,69 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// This file is copied from +// https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/Helpers.h +// and modified by Doris + +#pragma once + +#include "http/http_status.h" +#include "vec/aggregate_functions/aggregate_function.h" +#include "vec/columns/columns_number.h" +#include "vec/data_types/data_type.h" +#include "vec/functions/function.h" + +#define NUMERIC_TYPE_TO_COLUMN_TYPE(M) \ + M(UInt8, ColumnUInt8) \ + M(Int8, ColumnInt8) \ + M(Int16, ColumnInt16) \ + M(Int32, ColumnInt32) \ + M(Int64, ColumnInt64) \ + M(Int128, ColumnInt128) \ + M(Float32, ColumnFloat32) \ + M(Float64, ColumnFloat64) + +#define DECIMAL_TYPE_TO_COLUMN_TYPE(M) \ + M(Decimal32, ColumnDecimal<Decimal32>) \ + M(Decimal64, ColumnDecimal<Decimal64>) \ + M(Decimal128, ColumnDecimal<Decimal128>) + +#define STRING_TYPE_TO_COLUMN_TYPE(M) M(String, ColumnString) + +#define TIME_TYPE_TO_COLUMN_TYPE(M) \ + M(Date, ColumnInt64) \ + M(DateTime, ColumnInt64) + +#define TYPE_TO_COLUMN_TYPE(M) \ + NUMERIC_TYPE_TO_COLUMN_TYPE(M) \ + DECIMAL_TYPE_TO_COLUMN_TYPE(M) \ + STRING_TYPE_TO_COLUMN_TYPE(M) \ + TIME_TYPE_TO_COLUMN_TYPE(M) + +namespace doris::vectorized { + +template <template <typename> typename ClassTemplate, typename... TArgs> +IAggregateFunction* create_class_with_type(const IDataType& argument_type, TArgs&&... args) { + WhichDataType which(argument_type); +#define DISPATCH(TYPE, COLUMN_TYPE) \ + if (which.idx == TypeIndex::TYPE) \ + return new ClassTemplate<COLUMN_TYPE>(std::forward<TArgs>(args)...); + TYPE_TO_COLUMN_TYPE(DISPATCH) +#undef DISPATCH + return nullptr; +} + +} // namespace doris::vectorized diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java index 5f435d1..ca97cbc 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java @@ -1469,25 +1469,43 @@ public class FunctionSet<min_initIN9doris_udf12DecimalV2ValEEEvPNS2_15FunctionCo // NDV // ndv return string - addBuiltin(AggregateFunction.createBuiltin("ndv", - Lists.newArrayList(t), Type.BIGINT, Type.VARCHAR, - "_ZN5doris12HllFunctions8hll_initEPN9doris_udf15FunctionContextEPNS1_9StringValE", - "_ZN5doris12HllFunctions" + HLL_UPDATE_SYMBOL.get(t), - "_ZN5doris12HllFunctions9hll_mergeEPN9doris_udf15FunctionContextERKNS1_9StringValEPS4_", - "_ZN5doris12HllFunctions13hll_serializeEPN9doris_udf15FunctionContextERKNS1_9StringValE", - "_ZN5doris12HllFunctions12hll_finalizeEPN9doris_udf15FunctionContextERKNS1_9StringValE", - true, false, true)); + addBuiltin(AggregateFunction.createBuiltin("ndv", Lists.newArrayList(t), Type.BIGINT, Type.VARCHAR, + "_ZN5doris12HllFunctions8hll_initEPN9doris_udf15FunctionContextEPNS1_9StringValE", + "_ZN5doris12HllFunctions" + HLL_UPDATE_SYMBOL.get(t), + "_ZN5doris12HllFunctions9hll_mergeEPN9doris_udf15FunctionContextERKNS1_9StringValEPS4_", + "_ZN5doris12HllFunctions13hll_serializeEPN9doris_udf15FunctionContextERKNS1_9StringValE", + "_ZN5doris12HllFunctions12hll_finalizeEPN9doris_udf15FunctionContextERKNS1_9StringValE", + true, false, true)); - //APPROX_COUNT_DISTINCT - //alias of ndv, compute approx count distinct use HyperLogLog - addBuiltin(AggregateFunction.createBuiltin("approx_count_distinct", - Lists.newArrayList(t), Type.BIGINT, Type.VARCHAR, - "_ZN5doris12HllFunctions8hll_initEPN9doris_udf15FunctionContextEPNS1_9StringValE", - "_ZN5doris12HllFunctions" + HLL_UPDATE_SYMBOL.get(t), - "_ZN5doris12HllFunctions9hll_mergeEPN9doris_udf15FunctionContextERKNS1_9StringValEPS4_", - "_ZN5doris12HllFunctions13hll_serializeEPN9doris_udf15FunctionContextERKNS1_9StringValE", - "_ZN5doris12HllFunctions12hll_finalizeEPN9doris_udf15FunctionContextERKNS1_9StringValE", - true, false, true)); + // vectorized + addBuiltin(AggregateFunction.createBuiltin("ndv", Lists.newArrayList(t), Type.BIGINT, Type.VARCHAR, + "_ZN5doris12HllFunctions8hll_initEPN9doris_udf15FunctionContextEPNS1_9StringValE", + "_ZN5doris12HllFunctions" + HLL_UPDATE_SYMBOL.get(t), + "_ZN5doris12HllFunctions9hll_mergeEPN9doris_udf15FunctionContextERKNS1_9StringValEPS4_", + "_ZN5doris12HllFunctions13hll_serializeEPN9doris_udf15FunctionContextERKNS1_9StringValE", + "_ZN5doris12HllFunctions12hll_finalizeEPN9doris_udf15FunctionContextERKNS1_9StringValE", + true, false, true, true)); + + // APPROX_COUNT_DISTINCT + // alias of ndv, compute approx count distinct use HyperLogLog + addBuiltin(AggregateFunction.createBuiltin("approx_count_distinct", Lists.newArrayList(t), Type.BIGINT, + Type.VARCHAR, + "_ZN5doris12HllFunctions8hll_initEPN9doris_udf15FunctionContextEPNS1_9StringValE", + "_ZN5doris12HllFunctions" + HLL_UPDATE_SYMBOL.get(t), + "_ZN5doris12HllFunctions9hll_mergeEPN9doris_udf15FunctionContextERKNS1_9StringValEPS4_", + "_ZN5doris12HllFunctions13hll_serializeEPN9doris_udf15FunctionContextERKNS1_9StringValE", + "_ZN5doris12HllFunctions12hll_finalizeEPN9doris_udf15FunctionContextERKNS1_9StringValE", + true, false, true)); + + // vectorized + addBuiltin(AggregateFunction.createBuiltin("approx_count_distinct", Lists.newArrayList(t), Type.BIGINT, + Type.VARCHAR, + "_ZN5doris12HllFunctions8hll_initEPN9doris_udf15FunctionContextEPNS1_9StringValE", + "_ZN5doris12HllFunctions" + HLL_UPDATE_SYMBOL.get(t), + "_ZN5doris12HllFunctions9hll_mergeEPN9doris_udf15FunctionContextERKNS1_9StringValEPS4_", + "_ZN5doris12HllFunctions13hll_serializeEPN9doris_udf15FunctionContextERKNS1_9StringValE", + "_ZN5doris12HllFunctions12hll_finalizeEPN9doris_udf15FunctionContextERKNS1_9StringValE", + true, false, true, true)); // BITMAP_UNION_INT addBuiltin(AggregateFunction.createBuiltin(BITMAP_UNION_INT, --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org