This is an automated email from the ASF dual-hosted git repository.

lihaopeng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 143c408  [Feature][Vectorized] support aggregate function 
ndv()/approx_count_distinct() (#8044)
143c408 is described below

commit 143c4085ee58007954f3eef8910556f5b8ce6b39
Author: Pxl <952130...@qq.com>
AuthorDate: Wed Feb 16 14:30:13 2022 +0800

    [Feature][Vectorized] support aggregate function 
ndv()/approx_count_distinct() (#8044)
---
 be/src/vec/CMakeLists.txt                          |   1 +
 .../aggregate_function_approx_count_distinct.cpp   |  50 ++++++++++
 .../aggregate_function_approx_count_distinct.h     | 107 +++++++++++++++++++++
 .../aggregate_functions/aggregate_function_avg.cpp |   4 -
 .../aggregate_function_simple_factory.cpp          |   5 +
 .../aggregate_function_simple_factory.h            |   7 +-
 be/src/vec/common/string_ref.h                     |   6 +-
 be/src/vec/functions/function_case.h               |  55 ++---------
 be/src/vec/functions/function_coalesce.cpp         | 102 ++++++++------------
 be/src/vec/functions/function_hash.cpp             |  80 ++++++---------
 be/src/vec/utils/template_helpers.hpp              |  69 +++++++++++++
 .../java/org/apache/doris/catalog/FunctionSet.java |  54 +++++++----
 12 files changed, 352 insertions(+), 188 deletions(-)

diff --git a/be/src/vec/CMakeLists.txt b/be/src/vec/CMakeLists.txt
index c526cf2..7f8878f 100644
--- a/be/src/vec/CMakeLists.txt
+++ b/be/src/vec/CMakeLists.txt
@@ -33,6 +33,7 @@ set(VEC_FILES
   aggregate_functions/aggregate_function_window.cpp
   aggregate_functions/aggregate_function_stddev.cpp
   aggregate_functions/aggregate_function_topn.cpp
+  aggregate_functions/aggregate_function_approx_count_distinct.cpp
   aggregate_functions/aggregate_function_simple_factory.cpp
   columns/collator.cpp
   columns/column.cpp
diff --git 
a/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.cpp 
b/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.cpp
new file mode 100644
index 0000000..fc68d85
--- /dev/null
+++ 
b/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.cpp
@@ -0,0 +1,50 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "vec/aggregate_functions/aggregate_function_approx_count_distinct.h"
+
+#include "vec/utils/template_helpers.hpp"
+
+namespace doris::vectorized {
+
+AggregateFunctionPtr create_aggregate_function_approx_count_distinct(
+        const std::string& name, const DataTypes& argument_types, const Array& 
parameters,
+        const bool result_is_nullable) {
+    AggregateFunctionPtr res = nullptr;
+    WhichDataType which(argument_types[0]->is_nullable()
+                                ? reinterpret_cast<const 
DataTypeNullable*>(argument_types[0].get())
+                                          ->get_nested_type()
+                                : argument_types[0]);
+
+    
res.reset(create_class_with_type<AggregateFunctionApproxCountDistinct>(*argument_types[0],
+                                                                     
argument_types));
+
+    if (!res) {
+        LOG(WARNING) << fmt::format("Illegal type {} of argument for aggregate 
function {}",
+                                    argument_types[0]->get_name(), name);
+    }
+
+    return res;
+}
+
+void 
register_aggregate_function_approx_count_distinct(AggregateFunctionSimpleFactory&
 factory) {
+    factory.register_function("approx_count_distinct",
+                              create_aggregate_function_approx_count_distinct);
+    factory.register_alias("approx_count_distinct", "ndv");
+}
+
+} // namespace doris::vectorized
diff --git 
a/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.h 
b/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.h
new file mode 100644
index 0000000..ed393af
--- /dev/null
+++ b/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.h
@@ -0,0 +1,107 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "exprs/anyval_util.h"
+#include "olap/hll.h"
+#include "udf/udf.h"
+#include "vec/aggregate_functions/aggregate_function.h"
+#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
+#include "vec/common/string_ref.h"
+#include "vec/io/io_helper.h"
+
+namespace doris::vectorized {
+
+struct AggregateFunctionApproxCountDistinctData {
+    HyperLogLog hll_data;
+
+    void add(StringRef value) {
+        StringVal sv = value.to_string_val();
+        uint64_t hash_value = AnyValUtil::hash64_murmur(sv, 
HashUtil::MURMUR_SEED);
+        if (hash_value != 0) {
+            hll_data.update(hash_value);
+        }
+    }
+
+    void merge(const AggregateFunctionApproxCountDistinctData& rhs) {
+        hll_data.merge(rhs.hll_data);
+    }
+
+    void write(BufferWritable& buf) const {
+        std::string result;
+        result.resize(hll_data.max_serialized_size());
+        int size = hll_data.serialize((uint8_t*)result.data());
+        result.resize(size);
+        write_binary(result, buf);
+    }
+
+    void read(BufferReadable& buf) {
+        StringRef result;
+        read_binary(result, buf);
+        Slice data = Slice(result.data, result.size);
+        hll_data.deserialize(data);
+    }
+
+    int64_t get() const { return hll_data.estimate_cardinality(); }
+
+    void reset() { hll_data.clear(); }
+};
+
+template <typename ColumnDataType>
+class AggregateFunctionApproxCountDistinct final
+        : public IAggregateFunctionDataHelper<
+                  AggregateFunctionApproxCountDistinctData,
+                  AggregateFunctionApproxCountDistinct<ColumnDataType>> {
+public:
+    String get_name() const override { return "approx_count_distinct"; }
+
+    AggregateFunctionApproxCountDistinct(const DataTypes& argument_types_)
+            : 
IAggregateFunctionDataHelper<AggregateFunctionApproxCountDistinctData,
+                                           
AggregateFunctionApproxCountDistinct<ColumnDataType>>(
+                      argument_types_, {}) {}
+
+    DataTypePtr get_return_type() const override { return 
std::make_shared<DataTypeInt64>(); }
+
+    void add(AggregateDataPtr __restrict place, const IColumn** columns, 
size_t row_num,
+             Arena*) const override {
+        this->data(place).add(static_cast<const 
ColumnDataType*>(columns[0])->get_data_at(row_num));
+    }
+
+    void reset(AggregateDataPtr place) const override { 
this->data(place).reset(); }
+
+    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
+               Arena*) const override {
+        this->data(place).merge(this->data(rhs));
+    }
+
+    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& 
buf) const override {
+        this->data(place).write(buf);
+    }
+
+    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
+                     Arena*) const override {
+        this->data(place).read(buf);
+    }
+
+    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& 
to) const override {
+        auto& column = static_cast<ColumnInt64&>(to);
+        column.get_data().push_back(this->data(place).get());
+    }
+};
+
+} // namespace doris::vectorized
diff --git a/be/src/vec/aggregate_functions/aggregate_function_avg.cpp 
b/be/src/vec/aggregate_functions/aggregate_function_avg.cpp
index bb7605b..61687af 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_avg.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_avg.cpp
@@ -27,8 +27,6 @@
 
 namespace doris::vectorized {
 
-namespace {
-
 template <typename T>
 struct Avg {
     using FieldType = std::conditional_t<IsDecimalNumber<T>, Decimal128, 
NearestFieldType<T>>;
@@ -60,8 +58,6 @@ AggregateFunctionPtr create_aggregate_function_avg(const 
std::string& name,
     return res;
 }
 
-} // namespace
-
 void register_aggregate_function_avg(AggregateFunctionSimpleFactory& factory) {
     factory.register_function("avg", create_aggregate_function_avg);
 }
diff --git 
a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp 
b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
index 4844000..87a52b9 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
@@ -23,7 +23,9 @@
 #include "vec/aggregate_functions/aggregate_function_reader.h"
 
 namespace doris::vectorized {
+
 class AggregateFunctionSimpleFactory;
+
 void register_aggregate_function_sum(AggregateFunctionSimpleFactory& factory);
 void 
register_aggregate_function_combinator_null(AggregateFunctionSimpleFactory& 
factory);
 void register_aggregate_function_minmax(AggregateFunctionSimpleFactory& 
factory);
@@ -37,6 +39,8 @@ void 
register_aggregate_function_window_rank(AggregateFunctionSimpleFactory& fac
 void 
register_aggregate_function_window_lead_lag(AggregateFunctionSimpleFactory& 
factory);
 void 
register_aggregate_function_stddev_variance(AggregateFunctionSimpleFactory& 
factory);
 void register_aggregate_function_topn(AggregateFunctionSimpleFactory& factory);
+void 
register_aggregate_function_approx_count_distinct(AggregateFunctionSimpleFactory&
 factory);
+
 AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() {
     static std::once_flag oc;
     static AggregateFunctionSimpleFactory instance;
@@ -53,6 +57,7 @@ AggregateFunctionSimpleFactory& 
AggregateFunctionSimpleFactory::instance() {
         register_aggregate_function_window_rank(instance);
         register_aggregate_function_stddev_variance(instance);
         register_aggregate_function_topn(instance);
+        register_aggregate_function_approx_count_distinct(instance);
 
         // if you only register function with no nullable, and wants to add 
nullable automatically, you should place function above this line
         register_aggregate_function_combinator_null(instance);
diff --git a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h 
b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h
index 1bac4f1..833e52d 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h
@@ -46,6 +46,7 @@ private:
     AggregateFunctions aggregate_functions;
     AggregateFunctions nullable_aggregate_functions;
     std::unordered_map<std::string, std::string> function_alias;
+
 public:
     void register_nullable_function_combinator(const Creator& creator) {
         for (const auto& entity : aggregate_functions) {
@@ -86,13 +87,13 @@ public:
         if (nullable) {
             return nullable_aggregate_functions.find(name_str) == 
nullable_aggregate_functions.end()
                            ? nullptr
-                           : nullable_aggregate_functions[name_str](name_str, 
argument_types, parameters,
-                                                                
result_is_nullable);
+                           : nullable_aggregate_functions[name_str](name_str, 
argument_types,
+                                                                    
parameters, result_is_nullable);
         } else {
             return aggregate_functions.find(name_str) == 
aggregate_functions.end()
                            ? nullptr
                            : aggregate_functions[name_str](name_str, 
argument_types, parameters,
-                                                       result_is_nullable);
+                                                           result_is_nullable);
         }
     }
 
diff --git a/be/src/vec/common/string_ref.h b/be/src/vec/common/string_ref.h
index 727996e..5dd146e 100644
--- a/be/src/vec/common/string_ref.h
+++ b/be/src/vec/common/string_ref.h
@@ -55,9 +55,13 @@ struct StringRef {
 
     explicit operator std::string() const { return to_string(); }
 
-    StringVal to_string_val() const {
+    StringVal to_string_val() {
         return StringVal(reinterpret_cast<uint8_t*>(const_cast<char*>(data)), 
size);
     }
+
+    static StringRef from_string_val(StringVal sv) {
+        return StringRef(reinterpret_cast<char*>(sv.ptr), sv.len);
+    }
 };
 
 using StringRefs = std::vector<StringRef>;
diff --git a/be/src/vec/functions/function_case.h 
b/be/src/vec/functions/function_case.h
index 1113b5f..1b728e9 100644
--- a/be/src/vec/functions/function_case.h
+++ b/be/src/vec/functions/function_case.h
@@ -17,11 +17,11 @@
 
 #pragma once
 
-#include "vec/data_types/data_type_decimal.h"
 #include "vec/data_types/data_type_nullable.h"
 #include "vec/functions/function.h"
 #include "vec/functions/function_helpers.h"
 #include "vec/functions/simple_function_factory.h"
+#include "vec/utils/template_helpers.hpp"
 
 namespace doris::vectorized {
 
@@ -311,51 +311,14 @@ public:
                                     ? reinterpret_cast<const 
DataTypeNullable*>(data_type.get())
                                               ->get_nested_type()
                                     : data_type);
-
-        // TODO: use template traits here.
-        if (which.is_uint8()) {
-            return execute_get_when_null<ColumnUInt8>(data_type, block, 
arguments, result,
-                                                      input_rows_count);
-        } else if (which.is_int16()) {
-            return execute_get_when_null<ColumnInt16>(data_type, block, 
arguments, result,
-                                                      input_rows_count);
-        } else if (which.is_uint32()) {
-            return execute_get_when_null<ColumnUInt32>(data_type, block, 
arguments, result,
-                                                       input_rows_count);
-        } else if (which.is_uint64()) {
-            return execute_get_when_null<ColumnUInt64>(data_type, block, 
arguments, result,
-                                                       input_rows_count);
-        } else if (which.is_int8()) {
-            return execute_get_when_null<ColumnInt8>(data_type, block, 
arguments, result,
-                                                     input_rows_count);
-        } else if (which.is_int16()) {
-            return execute_get_when_null<ColumnInt16>(data_type, block, 
arguments, result,
-                                                      input_rows_count);
-        } else if (which.is_int32()) {
-            return execute_get_when_null<ColumnInt32>(data_type, block, 
arguments, result,
-                                                      input_rows_count);
-        } else if (which.is_int64()) {
-            return execute_get_when_null<ColumnInt64>(data_type, block, 
arguments, result,
-                                                      input_rows_count);
-        } else if (which.is_date_or_datetime()) {
-            return execute_get_when_null<ColumnVector<DateTime>>(data_type, 
block, arguments,
-                                                                 result, 
input_rows_count);
-        } else if (which.is_float32()) {
-            return execute_get_when_null<ColumnFloat32>(data_type, block, 
arguments, result,
-                                                        input_rows_count);
-        } else if (which.is_float64()) {
-            return execute_get_when_null<ColumnFloat64>(data_type, block, 
arguments, result,
-                                                        input_rows_count);
-        } else if (which.is_decimal()) {
-            return execute_get_when_null<ColumnDecimal<Decimal128>>(data_type, 
block, arguments,
-                                                                    result, 
input_rows_count);
-        } else if (which.is_string()) {
-            return execute_get_when_null<ColumnString>(data_type, block, 
arguments, result,
-                                                       input_rows_count);
-        } else {
-            return Status::NotSupported(fmt::format("Unexpected type {} of 
argument of function {}",
-                                                    data_type->get_name(), 
get_name()));
-        }
+#define DISPATCH(TYPE, COLUMN_TYPE)                                            
        \
+    if (which.idx == TypeIndex::TYPE)                                          
        \
+        return execute_get_when_null<COLUMN_TYPE>(data_type, block, arguments, 
result, \
+                                                  input_rows_count);
+        TYPE_TO_COLUMN_TYPE(DISPATCH)
+#undef DISPATCH
+        return Status::NotSupported(
+                fmt::format("argument_type {} not supported", 
data_type->get_name()));
     }
 
     Status execute_impl(FunctionContext* context, Block& block, const 
ColumnNumbers& arguments,
diff --git a/be/src/vec/functions/function_coalesce.cpp 
b/be/src/vec/functions/function_coalesce.cpp
index 4991fa8..91d6304 100644
--- a/be/src/vec/functions/function_coalesce.cpp
+++ b/be/src/vec/functions/function_coalesce.cpp
@@ -16,11 +16,10 @@
 // under the License.
 
 #include "udf/udf.h"
-#include "vec/data_types/data_type_nothing.h"
-#include "vec/data_types/data_type_number.h"
 #include "vec/data_types/get_least_supertype.h"
 #include "vec/functions/function_helpers.h"
 #include "vec/functions/simple_function_factory.h"
+#include "vec/utils/template_helpers.hpp"
 #include "vec/utils/util.hpp"
 
 namespace doris::vectorized {
@@ -53,11 +52,9 @@ public:
 
         res = res ? res : arguments[0];
 
-        const ColumnsWithTypeAndName is_not_null_col{
-                {nullptr, make_nullable(res), ""}
-        };
-        func_is_not_null = SimpleFunctionFactory::instance().
-                get_function("is_not_null_pred", is_not_null_col, 
std::make_shared<DataTypeUInt8>());
+        const ColumnsWithTypeAndName is_not_null_col {{nullptr, 
make_nullable(res), ""}};
+        func_is_not_null = SimpleFunctionFactory::instance().get_function(
+                "is_not_null_pred", is_not_null_col, 
std::make_shared<DataTypeUInt8>());
 
         return res;
     }
@@ -74,7 +71,8 @@ public:
             filtered_args.push_back(arguments[i]);
             if (!arg_type->is_nullable()) {
                 if (i == 0) { //if the first column not null, return it's 
directly
-                    block.get_by_position(result).column = 
block.get_by_position(arguments[0]).column;
+                    block.get_by_position(result).column =
+                            block.get_by_position(arguments[0]).column;
                     return Status::OK();
                 } else {
                     break;
@@ -84,8 +82,12 @@ public:
 
         size_t remaining_rows = input_rows_count;
         size_t argument_size = filtered_args.size();
-        std::vector<uint32_t> record_idx(input_rows_count, 0); //used to save 
column idx, record the result data of each row from which column
-        std::vector<uint8_t> filled_flags(input_rows_count, 0); //used to save 
filled flag, in order to check current row whether have filled data
+        std::vector<uint32_t> record_idx(
+                input_rows_count,
+                0); //used to save column idx, record the result data of each 
row from which column
+        std::vector<uint8_t> filled_flags(
+                input_rows_count,
+                0); //used to save filled flag, in order to check current row 
whether have filled data
 
         MutableColumnPtr result_column;
         if (!result_type->is_nullable()) {
@@ -104,7 +106,8 @@ public:
         }
 
         auto return_type = std::make_shared<DataTypeUInt8>();
-        auto null_map = ColumnUInt8::create(input_rows_count, 1);  //if 
null_map_data==1, the current row should be null
+        auto null_map = ColumnUInt8::create(
+                input_rows_count, 1); //if null_map_data==1, the current row 
should be null
         auto* __restrict null_map_data = null_map->get_data().data();
         ColumnPtr argument_columns[argument_size]; //use to save nested_column 
if is nullable column
 
@@ -119,17 +122,17 @@ public:
         }
 
         Block temporary_block {
-            ColumnsWithTypeAndName {
-                    block.get_by_position(filtered_args[0]),
-                    {nullptr, std::make_shared<DataTypeUInt8>(), ""}
-            }
-        };
+                ColumnsWithTypeAndName 
{block.get_by_position(filtered_args[0]),
+                                        {nullptr, 
std::make_shared<DataTypeUInt8>(), ""}}};
 
         for (size_t i = 0; i < argument_size && remaining_rows; ++i) {
-            temporary_block.get_by_position(0).column = 
block.get_by_position(filtered_args[i]).column;
+            temporary_block.get_by_position(0).column =
+                    block.get_by_position(filtered_args[i]).column;
             func_is_not_null->execute(context, temporary_block, {0}, 1, 
input_rows_count);
 
-            auto res_column = 
(*temporary_block.get_by_position(1).column->convert_to_full_column_if_const()).mutate();
+            auto res_column =
+                    
(*temporary_block.get_by_position(1).column->convert_to_full_column_if_const())
+                            .mutate();
             auto& res_map = 
assert_cast<ColumnVector<UInt8>*>(res_column.get())->get_data();
             auto* __restrict res = res_map.data();
 
@@ -152,7 +155,8 @@ public:
 
                 if (is_same_column_count == input_rows_count) {
                     if (result_type->is_nullable()) {
-                        block.get_by_position(result).column = 
make_nullable(argument_columns[i], false);
+                        block.get_by_position(result).column =
+                                make_nullable(argument_columns[i], false);
                     } else {
                         block.get_by_position(result).column = 
argument_columns[i];
                     }
@@ -170,7 +174,7 @@ public:
         }
 
         if (is_string_result) {
-            //if string type,  should according to the record results, fill in 
result one by one, 
+            //if string type,  should according to the record results, fill in 
result one by one,
             for (size_t row = 0; row < input_rows_count; ++row) {
                 if (null_map_data[row]) { //should be null
                     result_column->insert_default();
@@ -181,7 +185,8 @@ public:
         }
 
         if (result_type->is_nullable()) {
-            block.replace_by_position(result, 
ColumnNullable::create(std::move(result_column), std::move(null_map)));
+            block.replace_by_position(
+                    result, ColumnNullable::create(std::move(result_column), 
std::move(null_map)));
         } else {
             block.replace_by_position(result, std::move(result_column));
         }
@@ -198,18 +203,17 @@ public:
         auto* __restrict column_raw_data =
                 reinterpret_cast<const 
ColumnType*>(argument_column.get())->get_data().data();
 
-
         // Here it's SIMD thought the compiler automatically also
         // true: null_map_data[row]==0 && filled_idx[row]==0
         // if true, could filled current row data into result column
         for (size_t row = 0; row < input_rows_count; ++row) {
-            result_raw_data[row] += (!(null_map_data[row] | filled_flag[row])) 
* column_raw_data[row];
+            result_raw_data[row] +=
+                    (!(null_map_data[row] | filled_flag[row])) * 
column_raw_data[row];
             filled_flag[row] += (!(null_map_data[row] | filled_flag[row]));
         }
         return Status::OK();
     }
 
-    //TODO: this function is same as case when, should be replaced by macro
     Status filled_result_column(const DataTypePtr& data_type, 
MutableColumnPtr& result_column,
                                 ColumnPtr& argument_column, UInt8* __restrict 
null_map_data,
                                 UInt8* __restrict filled_flag, const size_t 
input_rows_count) {
@@ -217,46 +221,16 @@ public:
                                     ? reinterpret_cast<const 
DataTypeNullable*>(data_type.get())
                                               ->get_nested_type()
                                     : data_type);
-        if (which.is_uint8()) {
-            return insert_result_data<ColumnUInt8>(result_column, 
argument_column, null_map_data,
-                                                   filled_flag, 
input_rows_count);
-        } else if (which.is_int16()) {
-            return insert_result_data<ColumnInt16>(result_column, 
argument_column, null_map_data,
-                                                   filled_flag, 
input_rows_count);
-        } else if (which.is_uint32()) {
-            return insert_result_data<ColumnUInt32>(result_column, 
argument_column, null_map_data,
-                                                    filled_flag, 
input_rows_count);
-        } else if (which.is_uint64()) {
-            return insert_result_data<ColumnUInt64>(result_column, 
argument_column, null_map_data,
-                                                    filled_flag, 
input_rows_count);
-        } else if (which.is_int8()) {
-            return insert_result_data<ColumnInt8>(result_column, 
argument_column, null_map_data,
-                                                  filled_flag, 
input_rows_count);
-        } else if (which.is_int16()) {
-            return insert_result_data<ColumnInt16>(result_column, 
argument_column, null_map_data,
-                                                   filled_flag, 
input_rows_count);
-        } else if (which.is_int32()) {
-            return insert_result_data<ColumnInt32>(result_column, 
argument_column, null_map_data,
-                                                   filled_flag, 
input_rows_count);
-        } else if (which.is_int64()) {
-            return insert_result_data<ColumnInt64>(result_column, 
argument_column, null_map_data,
-                                                   filled_flag, 
input_rows_count);
-        } else if (which.is_date_or_datetime()) {
-            return insert_result_data<ColumnVector<DateTime>>(
-                    result_column, argument_column, null_map_data, 
filled_flag, input_rows_count);
-        } else if (which.is_float32()) {
-            return insert_result_data<ColumnFloat32>(result_column, 
argument_column, null_map_data,
-                                                     filled_flag, 
input_rows_count);
-        } else if (which.is_float64()) {
-            return insert_result_data<ColumnFloat64>(result_column, 
argument_column, null_map_data,
-                                                     filled_flag, 
input_rows_count);
-        } else if (which.is_decimal()) {
-            return insert_result_data<ColumnDecimal<Decimal128>>(
-                    result_column, argument_column, null_map_data, 
filled_flag, input_rows_count);
-        } else {
-            return Status::NotSupported(fmt::format("Unexpected type {} of 
argument of function {}",
-                                                    data_type->get_name(), 
get_name()));
-        }
+#define DISPATCH(TYPE, COLUMN_TYPE)                                            
               \
+    if (which.idx == TypeIndex::TYPE)                                          
               \
+        return insert_result_data<COLUMN_TYPE>(result_column, argument_column, 
null_map_data, \
+                                               filled_flag, input_rows_count);
+        NUMERIC_TYPE_TO_COLUMN_TYPE(DISPATCH)
+        DECIMAL_TYPE_TO_COLUMN_TYPE(DISPATCH)
+        TIME_TYPE_TO_COLUMN_TYPE(DISPATCH)
+#undef DISPATCH
+        return Status::NotSupported(
+                fmt::format("argument_type {} not supported", 
data_type->get_name()));
     }
 };
 
diff --git a/be/src/vec/functions/function_hash.cpp 
b/be/src/vec/functions/function_hash.cpp
index 18c7bcc..92e2a55 100644
--- a/be/src/vec/functions/function_hash.cpp
+++ b/be/src/vec/functions/function_hash.cpp
@@ -23,14 +23,14 @@
 #include "util/hash_util.hpp"
 #include "vec/functions/function_variadic_arguments.h"
 #include "vec/functions/simple_function_factory.h"
+#include "vec/utils/template_helpers.hpp"
 
 namespace doris::vectorized {
 struct MurmurHash2Impl64 {
     static constexpr auto name = "murmurHash2_64";
     using ReturnType = UInt64;
 
-    static Status empty_apply(IColumn& icolumn,
-                              size_t input_rows_count) {
+    static Status empty_apply(IColumn& icolumn, size_t input_rows_count) {
         ColumnVector<ReturnType>& vec_to = 
assert_cast<ColumnVector<ReturnType>&>(icolumn);
         vec_to.get_data().assign(input_rows_count, 
static_cast<ReturnType>(0xe28dbde7fe22e41c));
         return Status::OK();
@@ -42,8 +42,8 @@ struct MurmurHash2Impl64 {
         return Status::OK();
     }
 
-    static Status combine_apply(const IDataType* type, const IColumn* column, 
size_t input_rows_count,
-                                IColumn& icolumn) {
+    static Status combine_apply(const IDataType* type, const IColumn* column,
+                                size_t input_rows_count, IColumn& icolumn) {
         execute_any<false>(type, column, icolumn, input_rows_count);
         return Status::OK();
     }
@@ -58,7 +58,7 @@ struct MurmurHash2Impl64 {
             for (size_t i = 0; i < size; ++i) {
                 ReturnType val = HashUtil::murmur_hash2_64(
                         reinterpret_cast<const char*>(reinterpret_cast<const 
char*>(&vec_from[i])),
-                                                      sizeof(vec_from[i]), 0);
+                        sizeof(vec_from[i]), 0);
                 if (first)
                     col_to.insert_data(const_cast<const 
char*>(reinterpret_cast<char*>(&val)), 0);
                 else
@@ -137,38 +137,20 @@ struct MurmurHash2Impl64 {
     }
 
     template <bool first>
-    static Status execute_any(const IDataType* from_type, const IColumn* 
icolumn,
-                              IColumn& col_to, size_t input_rows_count) {
+    static Status execute_any(const IDataType* from_type, const IColumn* 
icolumn, IColumn& col_to,
+                              size_t input_rows_count) { // Hashes icolumn (whose type is from_type) into col_to; returns NotSupported for types not handled below.
         WhichDataType which(from_type);
-
-        if (which.is_uint8())
-            execute_int_type<UInt8, first>(icolumn, col_to, input_rows_count);
-        else if (which.is_int16())
-            execute_int_type<UInt16, first>(icolumn, col_to, input_rows_count);
-        else if (which.is_uint32())
-            execute_int_type<UInt32, first>(icolumn, col_to, input_rows_count);
-        else if (which.is_uint64())
-            execute_int_type<UInt64, first>(icolumn, col_to, input_rows_count);
-        else if (which.is_int8())
-            execute_int_type<Int8, first>(icolumn, col_to, input_rows_count);
-        else if (which.is_int16())
-            execute_int_type<Int16, first>(icolumn, col_to, input_rows_count);
-        else if (which.is_int32())
-            execute_int_type<Int32, first>(icolumn, col_to, input_rows_count);
-        else if (which.is_int64())
-            execute_int_type<Int64, first>(icolumn, col_to, input_rows_count);
-        else if (which.is_float32())
-            execute_int_type<Float32, first>(icolumn, col_to, 
input_rows_count);
-        else if (which.is_float64())
-            execute_int_type<Float64, first>(icolumn, col_to, 
input_rows_count);
-        else if (which.is_string())
-            execute_string<first>(icolumn, col_to, input_rows_count);
-        else {
-            DCHECK(false);
-            return Status::NotSupported(fmt::format("Illegal column {} of 
argument of function {}",
-                                                    icolumn->get_name(), 
name));
+        if (which.is_string()) { // strings take their own hashing path
+            return execute_string<first>(icolumn, col_to, input_rows_count);
         }
-        return Status::OK();
+        // Numeric types: DISPATCH expands to one early-return branch per NUMERIC_TYPE_TO_COLUMN_TYPE entry, matched on which.idx.
+#define DISPATCH(TYPE, COLUMN_TYPE)   \
+    if (which.idx == TypeIndex::TYPE) \
+        return execute_int_type<TYPE, first>(icolumn, col_to, 
input_rows_count);
+        NUMERIC_TYPE_TO_COLUMN_TYPE(DISPATCH) // NOTE(review): this list has no UInt32/UInt64 entries, which the removed if-chain hashed — confirm that is intended.
+#undef DISPATCH
+        return Status::NotSupported( // reached for any type not matched above
+                fmt::format("argument_type {} not supported", 
from_type->get_name()));
     }
 };
 using FunctionMurmurHash2_64 = FunctionVariadicArgumentsBase<DataTypeUInt64, 
MurmurHash2Impl64>;
@@ -177,8 +159,7 @@ struct MurmurHash3Impl32 {
     static constexpr auto name = "murmur_hash3_32";
     using ReturnType = Int32;
 
-    static Status empty_apply(IColumn& icolumn,
-                              size_t input_rows_count) {
+    static Status empty_apply(IColumn& icolumn, size_t input_rows_count) {
         ColumnVector<ReturnType>& vec_to = 
assert_cast<ColumnVector<ReturnType>&>(icolumn);
         vec_to.get_data().assign(input_rows_count, 
static_cast<ReturnType>(0xe28dbde7fe22e41c));
         return Status::OK();
@@ -189,8 +170,8 @@ struct MurmurHash3Impl32 {
         return execute<true>(type, column, input_rows_count, icolumn);
     }
 
-    static Status combine_apply(const IDataType* type, const IColumn* column, 
size_t input_rows_count,
-                                IColumn& icolumn) {
+    static Status combine_apply(const IDataType* type, const IColumn* column,
+                                size_t input_rows_count, IColumn& icolumn) {
         return execute<false>(type, column, input_rows_count, icolumn);
     }
 
@@ -207,15 +188,14 @@ struct MurmurHash3Impl32 {
                 if (first) {
                     UInt32 val = HashUtil::murmur_hash3_32(
                             reinterpret_cast<const 
char*>(&data[current_offset]),
-                            offsets[i] - current_offset - 1,
-                            HashUtil::MURMUR3_32_SEED);
+                            offsets[i] - current_offset - 1, 
HashUtil::MURMUR3_32_SEED);
                     col_to.insert_data(const_cast<const 
char*>(reinterpret_cast<char*>(&val)), 0);
                 } else {
                     
assert_cast<ColumnVector<ReturnType>&>(col_to).get_data()[i] =
                             HashUtil::murmur_hash3_32(
-                            reinterpret_cast<const 
char*>(&data[current_offset]),
-                            offsets[i] - current_offset - 1,
-                            ext::bit_cast<UInt32>(col_to[i]));
+                                    reinterpret_cast<const 
char*>(&data[current_offset]),
+                                    offsets[i] - current_offset - 1,
+                                    ext::bit_cast<UInt32>(col_to[i]));
                 }
                 current_offset = offsets[i];
             }
@@ -224,17 +204,13 @@ struct MurmurHash3Impl32 {
             String value = col_from_const->get_value<String>().data();
             for (size_t i = 0; i < input_rows_count; ++i) {
                 if (first) {
-                    UInt32 val = HashUtil::murmur_hash3_32(
-                            value.data(),
-                            value.size(),
-                            HashUtil::MURMUR3_32_SEED);
+                    UInt32 val = HashUtil::murmur_hash3_32(value.data(), 
value.size(),
+                                                           
HashUtil::MURMUR3_32_SEED);
                     col_to.insert_data(const_cast<const 
char*>(reinterpret_cast<char*>(&val)), 0);
                 } else {
                     
assert_cast<ColumnVector<ReturnType>&>(col_to).get_data()[i] =
-                            HashUtil::murmur_hash3_32(
-                            value.data(),
-                            value.size(),
-                            ext::bit_cast<UInt32>(col_to[i]));
+                            HashUtil::murmur_hash3_32(value.data(), 
value.size(),
+                                                      
ext::bit_cast<UInt32>(col_to[i]));
                 }
             }
         } else {
diff --git a/be/src/vec/utils/template_helpers.hpp 
b/be/src/vec/utils/template_helpers.hpp
new file mode 100644
index 0000000..4d4e1e2
--- /dev/null
+++ b/be/src/vec/utils/template_helpers.hpp
@@ -0,0 +1,69 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+// This file is copied from
+// 
https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/Helpers.h
+// and modified by Doris
+
+#pragma once
+
+#include "http/http_status.h"
+#include "vec/aggregate_functions/aggregate_function.h"
+#include "vec/columns/columns_number.h"
+#include "vec/data_types/data_type.h"
+#include "vec/functions/function.h"
+// Type-list macros: each *_TYPE_TO_COLUMN_TYPE(M) expands to M(TypeIndex name, column type) once per listed entry.
+#define NUMERIC_TYPE_TO_COLUMN_TYPE(M) \
+    M(UInt8, ColumnUInt8)              \
+    M(Int8, ColumnInt8)                \
+    M(Int16, ColumnInt16)              \
+    M(Int32, ColumnInt32)              \
+    M(Int64, ColumnInt64)              \
+    M(Int128, ColumnInt128)            \
+    M(Float32, ColumnFloat32)          \
+    M(Float64, ColumnFloat64) // NOTE(review): no UInt16/UInt32/UInt64 entries
+// Each decimal type is paired with ColumnDecimal of the same decimal type.
+#define DECIMAL_TYPE_TO_COLUMN_TYPE(M)     \
+    M(Decimal32, ColumnDecimal<Decimal32>) \
+    M(Decimal64, ColumnDecimal<Decimal64>) \
+    M(Decimal128, ColumnDecimal<Decimal128>)
+
+#define STRING_TYPE_TO_COLUMN_TYPE(M) M(String, ColumnString)
+// Date and DateTime are both paired with ColumnInt64 (NOTE(review): assumes both are stored as 64-bit ints — confirm).
+#define TIME_TYPE_TO_COLUMN_TYPE(M) \
+    M(Date, ColumnInt64)            \
+    M(DateTime, ColumnInt64)
+// Union of every list above; this is the set of types create_class_with_type can instantiate.
+#define TYPE_TO_COLUMN_TYPE(M)     \
+    NUMERIC_TYPE_TO_COLUMN_TYPE(M) \
+    DECIMAL_TYPE_TO_COLUMN_TYPE(M) \
+    STRING_TYPE_TO_COLUMN_TYPE(M)  \
+    TIME_TYPE_TO_COLUMN_TYPE(M)
+
+namespace doris::vectorized {
+// Returns `new ClassTemplate<ColumnType>(args...)` for the column type paired with argument_type's TypeIndex in TYPE_TO_COLUMN_TYPE, or nullptr if the type is not listed; the caller owns the returned raw pointer.
+template <template <typename> typename ClassTemplate, typename... TArgs>
+IAggregateFunction* create_class_with_type(const IDataType& argument_type, 
TArgs&&... args) {
+    WhichDataType which(argument_type);
+#define DISPATCH(TYPE, COLUMN_TYPE)   \
+    if (which.idx == TypeIndex::TYPE) \
+        return new ClassTemplate<COLUMN_TYPE>(std::forward<TArgs>(args)...);
+    TYPE_TO_COLUMN_TYPE(DISPATCH)
+#undef DISPATCH
+    return nullptr; // argument_type is not in TYPE_TO_COLUMN_TYPE; callers must check for nullptr
+}
+
+} // namespace doris::vectorized
diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java 
b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java
index 5f435d1..ca97cbc 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java
@@ -1469,25 +1469,43 @@ public class 
FunctionSet<min_initIN9doris_udf12DecimalV2ValEEEvPNS2_15FunctionCo
 
             // NDV
             // ndv return string
-            addBuiltin(AggregateFunction.createBuiltin("ndv",
-                    Lists.newArrayList(t), Type.BIGINT, Type.VARCHAR,
-                    
"_ZN5doris12HllFunctions8hll_initEPN9doris_udf15FunctionContextEPNS1_9StringValE",
-                    "_ZN5doris12HllFunctions" + HLL_UPDATE_SYMBOL.get(t),
-                    
"_ZN5doris12HllFunctions9hll_mergeEPN9doris_udf15FunctionContextERKNS1_9StringValEPS4_",
-                    
"_ZN5doris12HllFunctions13hll_serializeEPN9doris_udf15FunctionContextERKNS1_9StringValE",
-                    
"_ZN5doris12HllFunctions12hll_finalizeEPN9doris_udf15FunctionContextERKNS1_9StringValE",
-                    true, false, true));
+            addBuiltin(AggregateFunction.createBuiltin("ndv", 
Lists.newArrayList(t), Type.BIGINT, Type.VARCHAR,
+                            
"_ZN5doris12HllFunctions8hll_initEPN9doris_udf15FunctionContextEPNS1_9StringValE",
+                            "_ZN5doris12HllFunctions" + 
HLL_UPDATE_SYMBOL.get(t),
+                            
"_ZN5doris12HllFunctions9hll_mergeEPN9doris_udf15FunctionContextERKNS1_9StringValEPS4_",
+                            
"_ZN5doris12HllFunctions13hll_serializeEPN9doris_udf15FunctionContextERKNS1_9StringValE",
+                            
"_ZN5doris12HllFunctions12hll_finalizeEPN9doris_udf15FunctionContextERKNS1_9StringValE",
+                            true, false, true));
 
-            //APPROX_COUNT_DISTINCT
-            //alias of ndv, compute approx count distinct use HyperLogLog
-            addBuiltin(AggregateFunction.createBuiltin("approx_count_distinct",
-                    Lists.newArrayList(t), Type.BIGINT, Type.VARCHAR,
-                    
"_ZN5doris12HllFunctions8hll_initEPN9doris_udf15FunctionContextEPNS1_9StringValE",
-                    "_ZN5doris12HllFunctions" + HLL_UPDATE_SYMBOL.get(t),
-                    
"_ZN5doris12HllFunctions9hll_mergeEPN9doris_udf15FunctionContextERKNS1_9StringValEPS4_",
-                    
"_ZN5doris12HllFunctions13hll_serializeEPN9doris_udf15FunctionContextERKNS1_9StringValE",
-                    
"_ZN5doris12HllFunctions12hll_finalizeEPN9doris_udf15FunctionContextERKNS1_9StringValE",
-                    true, false, true));
+            // vectorized
+            addBuiltin(AggregateFunction.createBuiltin("ndv", 
Lists.newArrayList(t), Type.BIGINT, Type.VARCHAR,
+                            
"_ZN5doris12HllFunctions8hll_initEPN9doris_udf15FunctionContextEPNS1_9StringValE",
+                            "_ZN5doris12HllFunctions" + 
HLL_UPDATE_SYMBOL.get(t),
+                            
"_ZN5doris12HllFunctions9hll_mergeEPN9doris_udf15FunctionContextERKNS1_9StringValEPS4_",
+                            
"_ZN5doris12HllFunctions13hll_serializeEPN9doris_udf15FunctionContextERKNS1_9StringValE",
+                            
"_ZN5doris12HllFunctions12hll_finalizeEPN9doris_udf15FunctionContextERKNS1_9StringValE",
+                            true, false, true, true));
+
+            // APPROX_COUNT_DISTINCT
+            // alias of ndv, compute approx count distinct use HyperLogLog
+            
addBuiltin(AggregateFunction.createBuiltin("approx_count_distinct", 
Lists.newArrayList(t), Type.BIGINT,
+                            Type.VARCHAR,
+                            
"_ZN5doris12HllFunctions8hll_initEPN9doris_udf15FunctionContextEPNS1_9StringValE",
+                            "_ZN5doris12HllFunctions" + 
HLL_UPDATE_SYMBOL.get(t),
+                            
"_ZN5doris12HllFunctions9hll_mergeEPN9doris_udf15FunctionContextERKNS1_9StringValEPS4_",
+                            
"_ZN5doris12HllFunctions13hll_serializeEPN9doris_udf15FunctionContextERKNS1_9StringValE",
+                            
"_ZN5doris12HllFunctions12hll_finalizeEPN9doris_udf15FunctionContextERKNS1_9StringValE",
+                            true, false, true));
+
+            // vectorized
+            
addBuiltin(AggregateFunction.createBuiltin("approx_count_distinct", 
Lists.newArrayList(t), Type.BIGINT,
+                            Type.VARCHAR,
+                            
"_ZN5doris12HllFunctions8hll_initEPN9doris_udf15FunctionContextEPNS1_9StringValE",
+                            "_ZN5doris12HllFunctions" + 
HLL_UPDATE_SYMBOL.get(t),
+                            
"_ZN5doris12HllFunctions9hll_mergeEPN9doris_udf15FunctionContextERKNS1_9StringValEPS4_",
+                            
"_ZN5doris12HllFunctions13hll_serializeEPN9doris_udf15FunctionContextERKNS1_9StringValE",
+                            
"_ZN5doris12HllFunctions12hll_finalizeEPN9doris_udf15FunctionContextERKNS1_9StringValE",
+                            true, false, true, true));
 
             // BITMAP_UNION_INT
             addBuiltin(AggregateFunction.createBuiltin(BITMAP_UNION_INT,

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to