This is an automated email from the ASF dual-hosted git repository.

lihaopeng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 24ef60b491 [Opt](exec) opt aggregate function performance in nullable column
24ef60b491 is described below

commit 24ef60b491801e91ff37a6b34642d47f8ab46604
Author: HappenLee <happen...@hotmail.com>
AuthorDate: Thu Feb 16 22:26:12 2023 +0800

    [Opt](exec) opt aggregate function performance in nullable column
---
 .../aggregate_functions/aggregate_function_avg.cpp |  21 +-
 .../aggregate_functions/aggregate_function_count.h |   3 +-
 .../aggregate_function_min_max.cpp                 |   1 -
 .../aggregate_function_null.cpp                    |   1 -
 .../aggregate_functions/aggregate_function_null.h  | 245 +++++++++++++++++++++
 .../aggregate_functions/aggregate_function_sum.cpp |  24 +-
 be/src/vec/aggregate_functions/helpers.h           | 149 ++++---------
 be/src/vec/data_types/data_type_nullable.cpp       |  12 +
 be/src/vec/data_types/data_type_nullable.h         |   1 +
 9 files changed, 339 insertions(+), 118 deletions(-)

diff --git a/be/src/vec/aggregate_functions/aggregate_function_avg.cpp 
b/be/src/vec/aggregate_functions/aggregate_function_avg.cpp
index 5875b831f3..7f9295d8e7 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_avg.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_avg.cpp
@@ -45,11 +45,23 @@ AggregateFunctionPtr create_aggregate_function_avg(const 
std::string& name,
 
     AggregateFunctionPtr res;
     DataTypePtr data_type = argument_types[0];
-    if (is_decimal(data_type)) {
-        res.reset(
-                create_with_decimal_type<AggregateFuncAvg>(*data_type, 
*data_type, argument_types));
+    if (data_type->is_nullable()) {
+        auto no_null_argument_types = remove_nullable(argument_types);
+        if (is_decimal(no_null_argument_types[0])) {
+            res.reset(create_with_decimal_type_null<AggregateFuncAvg>(
+                    no_null_argument_types, parameters, 
*no_null_argument_types[0],
+                    no_null_argument_types));
+        } else {
+            res.reset(create_with_numeric_type_null<AggregateFuncAvg>(
+                    no_null_argument_types, parameters, 
no_null_argument_types));
+        }
     } else {
-        res.reset(create_with_numeric_type<AggregateFuncAvg>(*data_type, 
argument_types));
+        if (is_decimal(data_type)) {
+            res.reset(create_with_decimal_type<AggregateFuncAvg>(*data_type, 
*data_type,
+                                                                 
argument_types));
+        } else {
+            res.reset(create_with_numeric_type<AggregateFuncAvg>(*data_type, 
argument_types));
+        }
     }
 
     if (!res) {
@@ -61,5 +73,6 @@ AggregateFunctionPtr create_aggregate_function_avg(const 
std::string& name,
 
 void register_aggregate_function_avg(AggregateFunctionSimpleFactory& factory) {
     factory.register_function("avg", create_aggregate_function_avg);
+    factory.register_function("avg", create_aggregate_function_avg, true);
 }
 } // namespace doris::vectorized
diff --git a/be/src/vec/aggregate_functions/aggregate_function_count.h 
b/be/src/vec/aggregate_functions/aggregate_function_count.h
index 960d4111cb..bc87e4bb10 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_count.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_count.h
@@ -121,7 +121,8 @@ public:
     DataTypePtr get_serialized_type() const override { return 
std::make_shared<DataTypeUInt64>(); }
 };
 
-/// Simply count number of not-NULL values.
+// TODO: Maybe AggregateFunctionCountNotNullUnary should be a subclass of 
AggregateFunctionCount
+// Simply count number of not-NULL values.
 class AggregateFunctionCountNotNullUnary final
         : public IAggregateFunctionDataHelper<AggregateFunctionCountData,
                                               
AggregateFunctionCountNotNullUnary> {
diff --git a/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp 
b/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp
index 83045dbd00..a01e2ce51a 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp
@@ -25,7 +25,6 @@
 #include "vec/aggregate_functions/helpers.h"
 
 namespace doris::vectorized {
-
 /// min, max, any
 template <template <typename, bool> class AggregateFunctionTemplate, template 
<typename> class Data>
 static IAggregateFunction* create_aggregate_function_single_value(const 
String& name,
diff --git a/be/src/vec/aggregate_functions/aggregate_function_null.cpp 
b/be/src/vec/aggregate_functions/aggregate_function_null.cpp
index 495cefcb84..8ae2368864 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_null.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_null.cpp
@@ -85,7 +85,6 @@ public:
 };
 
 void 
register_aggregate_function_combinator_null(AggregateFunctionSimpleFactory& 
factory) {
-    // 
factory.registerCombinator(std::make_shared<AggregateFunctionCombinatorNull>());
     AggregateFunctionCreator creator = [&](const std::string& name, const 
DataTypes& types,
                                            const Array& params, const bool 
result_is_nullable) {
         auto function_combinator = 
std::make_shared<AggregateFunctionCombinatorNull>();
diff --git a/be/src/vec/aggregate_functions/aggregate_function_null.h 
b/be/src/vec/aggregate_functions/aggregate_function_null.h
index 86fe7734e1..69642d0deb 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_null.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_null.h
@@ -40,6 +40,7 @@ namespace doris::vectorized {
 /// If all rows had NULL, the behaviour is determined by "result_is_nullable" 
template parameter.
 ///  true - return NULL; false - return value from empty aggregation state of 
nested function.
 
+// TODO: only keep class xxxInline after we support all aggregate function
 template <bool result_is_nullable, typename Derived>
 class AggregateFunctionNullBase : public IAggregateFunctionHelper<Derived> {
 protected:
@@ -409,4 +410,248 @@ private:
             is_nullable; /// Plain array is better than std::vector due to one 
indirection less.
 };
 
+template <typename NestFunction, bool result_is_nullable, typename Derived>
+class AggregateFunctionNullBaseInline : public 
IAggregateFunctionHelper<Derived> {
+protected:
+    std::unique_ptr<NestFunction> nested_function;
+    size_t prefix_size;
+
+    /** In addition to data for nested aggregate function, we keep a flag
+      *  indicating - was there at least one non-NULL value accumulated.
+      * In case of no not-NULL values, the function will return NULL.
+      *
+      * We use prefix_size bytes for flag to satisfy the alignment requirement 
of nested state.
+      */
+
+    AggregateDataPtr nested_place(AggregateDataPtr __restrict place) const 
noexcept {
+        return place + prefix_size;
+    }
+
+    ConstAggregateDataPtr nested_place(ConstAggregateDataPtr __restrict place) 
const noexcept {
+        return place + prefix_size;
+    }
+
+    static void init_flag(AggregateDataPtr __restrict place) noexcept {
+        if constexpr (result_is_nullable) {
+            place[0] = false;
+        }
+    }
+
+    static void set_flag(AggregateDataPtr __restrict place) noexcept {
+        if constexpr (result_is_nullable) {
+            place[0] = true;
+        }
+    }
+
+    static bool get_flag(ConstAggregateDataPtr __restrict place) noexcept {
+        return result_is_nullable ? place[0] : true;
+    }
+
+public:
+    AggregateFunctionNullBaseInline(IAggregateFunction* nested_function_,
+                                    const DataTypes& arguments, const Array& 
params)
+            : IAggregateFunctionHelper<Derived>(arguments, params),
+              nested_function {assert_cast<NestFunction*>(nested_function_)} {
+        if (result_is_nullable) {
+            prefix_size = nested_function->align_of_data();
+        } else {
+            prefix_size = 0;
+        }
+    }
+
+    String get_name() const override {
+        /// This is just a wrapper. The function for Nullable arguments is 
named the same as the nested function itself.
+        return nested_function->get_name();
+    }
+
+    DataTypePtr get_return_type() const override {
+        return result_is_nullable ? 
make_nullable(nested_function->get_return_type())
+                                  : nested_function->get_return_type();
+    }
+
+    void create(AggregateDataPtr __restrict place) const override {
+        init_flag(place);
+        nested_function->create(nested_place(place));
+    }
+
+    void destroy(AggregateDataPtr __restrict place) const noexcept override {
+        nested_function->destroy(nested_place(place));
+    }
+    void reset(AggregateDataPtr place) const override {
+        init_flag(place);
+        nested_function->reset(nested_place(place));
+    }
+
+    bool has_trivial_destructor() const override {
+        return nested_function->has_trivial_destructor();
+    }
+
+    size_t size_of_data() const override { return prefix_size + 
nested_function->size_of_data(); }
+
+    size_t align_of_data() const override { return 
nested_function->align_of_data(); }
+
+    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
+               Arena* arena) const override {
+        if (result_is_nullable && get_flag(rhs)) {
+            set_flag(place);
+        }
+
+        nested_function->merge(nested_place(place), nested_place(rhs), arena);
+    }
+
+    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& 
buf) const override {
+        bool flag = get_flag(place);
+        if (result_is_nullable) {
+            write_binary(flag, buf);
+        }
+        if (flag) {
+            nested_function->serialize(nested_place(place), buf);
+        }
+    }
+
+    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
+                     Arena* arena) const override {
+        bool flag = true;
+        if (result_is_nullable) {
+            read_binary(flag, buf);
+        }
+        if (flag) {
+            set_flag(place);
+            nested_function->deserialize(nested_place(place), buf, arena);
+        }
+    }
+
+    void deserialize_and_merge(AggregateDataPtr __restrict place, 
BufferReadable& buf,
+                               Arena* arena) const override {
+        bool flag = true;
+        if (result_is_nullable) {
+            read_binary(flag, buf);
+        }
+        if (flag) {
+            set_flag(place);
+            nested_function->deserialize_and_merge(nested_place(place), buf, 
arena);
+        }
+    }
+
+    void deserialize_and_merge_from_column(AggregateDataPtr __restrict place, 
const IColumn& column,
+                                           Arena* arena) const override {
+        size_t num_rows = column.size();
+        for (size_t i = 0; i != num_rows; ++i) {
+            VectorBufferReader buffer_reader(
+                    (assert_cast<const ColumnString&>(column)).get_data_at(i));
+            deserialize_and_merge(place, buffer_reader, arena);
+        }
+    }
+
+    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& 
to) const override {
+        if constexpr (result_is_nullable) {
+            ColumnNullable& to_concrete = assert_cast<ColumnNullable&>(to);
+            if (get_flag(place)) {
+                nested_function->insert_result_into(nested_place(place),
+                                                    
to_concrete.get_nested_column());
+                to_concrete.get_null_map_data().push_back(0);
+            } else {
+                to_concrete.insert_default();
+            }
+        } else {
+            nested_function->insert_result_into(nested_place(place), to);
+        }
+    }
+
+    bool allocates_memory_in_arena() const override {
+        return nested_function->allocates_memory_in_arena();
+    }
+
+    bool is_state() const override { return nested_function->is_state(); }
+};
+
+/** There are two cases: for single argument and variadic.
+  * Code for single argument is much more efficient.
+  */
+template <typename NestFuction, bool result_is_nullable>
+class AggregateFunctionNullUnaryInline final
+        : public AggregateFunctionNullBaseInline<
+                  NestFuction, result_is_nullable,
+                  AggregateFunctionNullUnaryInline<NestFuction, 
result_is_nullable>> {
+public:
+    AggregateFunctionNullUnaryInline(IAggregateFunction* nested_function_,
+                                     const DataTypes& arguments, const Array& 
params)
+            : AggregateFunctionNullBaseInline<
+                      NestFuction, result_is_nullable,
+                      AggregateFunctionNullUnaryInline<NestFuction, 
result_is_nullable>>(
+                      nested_function_, arguments, params) {}
+
+    void add(AggregateDataPtr __restrict place, const IColumn** columns, 
size_t row_num,
+             Arena* arena) const override {
+        const ColumnNullable* column = assert_cast<const 
ColumnNullable*>(columns[0]);
+        if (!column->is_null_at(row_num)) {
+            this->set_flag(place);
+            const IColumn* nested_column = &column->get_nested_column();
+            this->nested_function->add(this->nested_place(place), 
&nested_column, row_num, arena);
+        }
+    }
+
+    void add_not_nullable(AggregateDataPtr __restrict place, const IColumn** 
columns,
+                          size_t row_num, Arena* arena) const {
+        const ColumnNullable* column = assert_cast<const 
ColumnNullable*>(columns[0]);
+        this->set_flag(place);
+        const IColumn* nested_column = &column->get_nested_column();
+        this->nested_function->add(this->nested_place(place), &nested_column, 
row_num, arena);
+    }
+
+    void add_batch(size_t batch_size, AggregateDataPtr* places, size_t 
place_offset,
+                   const IColumn** columns, Arena* arena, bool agg_many) const 
override {
+        const ColumnNullable* column = assert_cast<const 
ColumnNullable*>(columns[0]);
+        // The overhead introduced is negligible here, just an extra memory 
read from NullMap
+        const auto* __restrict null_map_data = 
column->get_null_map_data().data();
+        const IColumn* nested_column = &column->get_nested_column();
+        for (int i = 0; i < batch_size; ++i) {
+            if (!null_map_data[i]) {
+                AggregateDataPtr __restrict place = places[i] + place_offset;
+                this->set_flag(place);
+                this->nested_function->add(this->nested_place(place), 
&nested_column, i, arena);
+            }
+        }
+    }
+
+    void add_batch_single_place(size_t batch_size, AggregateDataPtr place, 
const IColumn** columns,
+                                Arena* arena) const override {
+        const ColumnNullable* column = assert_cast<const 
ColumnNullable*>(columns[0]);
+        bool has_null = column->has_null();
+
+        if (has_null) {
+            for (size_t i = 0; i < batch_size; ++i) {
+                if (!column->is_null_at(i)) {
+                    this->set_flag(place);
+                    this->add(place, columns, i, arena);
+                }
+            }
+        } else {
+            this->set_flag(place);
+            const IColumn* nested_column = &column->get_nested_column();
+            this->nested_function->add_batch_single_place(batch_size, 
this->nested_place(place),
+                                                          &nested_column, 
arena);
+        }
+    }
+
+    void add_batch_range(size_t batch_begin, size_t batch_end, 
AggregateDataPtr place,
+                         const IColumn** columns, Arena* arena, bool has_null) 
override {
+        const ColumnNullable* column = assert_cast<const 
ColumnNullable*>(columns[0]);
+
+        if (has_null) {
+            for (size_t i = batch_begin; i <= batch_end; ++i) {
+                if (!column->is_null_at(i)) {
+                    this->set_flag(place);
+                    this->add(place, columns, i, arena);
+                }
+            }
+        } else {
+            this->set_flag(place);
+            const IColumn* nested_column = &column->get_nested_column();
+            this->nested_function->add_batch_range(batch_begin, batch_end,
+                                                   this->nested_place(place), 
&nested_column, arena,
+                                                   false);
+        }
+    }
+};
 } // namespace doris::vectorized
diff --git a/be/src/vec/aggregate_functions/aggregate_function_sum.cpp 
b/be/src/vec/aggregate_functions/aggregate_function_sum.cpp
index ca40e4196c..75d4d36414 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_sum.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_sum.cpp
@@ -25,6 +25,7 @@
 #include "common/logging.h"
 #include "vec/aggregate_functions/aggregate_function_simple_factory.h"
 #include "vec/aggregate_functions/helpers.h"
+#include "vec/data_types/data_type_nullable.h"
 
 namespace doris::vectorized {
 
@@ -45,15 +46,24 @@ AggregateFunctionPtr create_aggregate_function_sum(const 
std::string& name,
                                                    const DataTypes& 
argument_types,
                                                    const Array& parameters,
                                                    const bool 
result_is_nullable) {
-    // assert_no_parameters(name, parameters);
-    // assert_unary(name, argument_types);
-
     AggregateFunctionPtr res;
     DataTypePtr data_type = argument_types[0];
-    if (is_decimal(data_type)) {
-        res.reset(create_with_decimal_type<Function>(*data_type, *data_type, 
argument_types));
+    if (data_type->is_nullable()) {
+        auto no_null_argument_types = remove_nullable(argument_types);
+        if (is_decimal(no_null_argument_types[0])) {
+            
res.reset(create_with_decimal_type_null<Function>(no_null_argument_types, 
parameters,
+                                                              
*no_null_argument_types[0],
+                                                              
no_null_argument_types));
+        } else {
+            
res.reset(create_with_numeric_type_null<Function>(no_null_argument_types, 
parameters,
+                                                              
no_null_argument_types));
+        }
     } else {
-        res.reset(create_with_numeric_type<Function>(*data_type, 
argument_types));
+        if (is_decimal(data_type)) {
+            res.reset(create_with_decimal_type<Function>(*data_type, 
*data_type, argument_types));
+        } else {
+            res.reset(create_with_numeric_type<Function>(*data_type, 
argument_types));
+        }
     }
 
     if (!res) {
@@ -84,6 +94,8 @@ AggregateFunctionPtr 
create_aggregate_function_sum_reader(const std::string& nam
 
 void register_aggregate_function_sum(AggregateFunctionSimpleFactory& factory) {
     factory.register_function("sum", 
create_aggregate_function_sum<AggregateFunctionSumSimple>);
+    factory.register_function("sum", 
create_aggregate_function_sum<AggregateFunctionSumSimple>,
+                              true);
 }
 
 } // namespace doris::vectorized
diff --git a/be/src/vec/aggregate_functions/helpers.h 
b/be/src/vec/aggregate_functions/helpers.h
index 36e11f7011..0970a860a6 100644
--- a/be/src/vec/aggregate_functions/helpers.h
+++ b/be/src/vec/aggregate_functions/helpers.h
@@ -21,8 +21,10 @@
 #pragma once
 
 #include "vec/aggregate_functions/aggregate_function.h"
+#include "vec/aggregate_functions/aggregate_function_null.h"
 #include "vec/data_types/data_type.h"
 
+// TODO: Should we support decimal in numeric types?
 #define FOR_NUMERIC_TYPES(M) \
     M(UInt8)                 \
     M(UInt16)                \
@@ -36,6 +38,12 @@
     M(Float32)               \
     M(Float64)
 
+#define FOR_DECIMAL_TYPES(M) \
+    M(Decimal32)             \
+    M(Decimal64)             \
+    M(Decimal128)            \
+    M(Decimal128I)
+
 namespace doris::vectorized {
 
 /** Create an aggregate function with a numeric type in the template 
parameter, depending on the type of the argument.
@@ -49,12 +57,20 @@ static IAggregateFunction* create_with_numeric_type(const 
IDataType& argument_ty
         return new 
AggregateFunctionTemplate<TYPE>(std::forward<TArgs>(args)...);
     FOR_NUMERIC_TYPES(DISPATCH)
 #undef DISPATCH
-    if (which.idx == TypeIndex::Enum8) {
-        return new 
AggregateFunctionTemplate<Int8>(std::forward<TArgs>(args)...);
-    }
-    if (which.idx == TypeIndex::Enum16) {
-        return new 
AggregateFunctionTemplate<Int16>(std::forward<TArgs>(args)...);
-    }
+    return nullptr;
+}
+
+template <template <typename> class AggregateFunctionTemplate, typename... 
TArgs>
+static IAggregateFunction* create_with_numeric_type_null(const DataTypes& 
argument_types,
+                                                         const Array& params, 
TArgs&&... args) {
+    WhichDataType which(argument_types[0]);
+#define DISPATCH(TYPE)                                                         
                    \
+    if (which.idx == TypeIndex::TYPE)                                          
                    \
+        return new 
AggregateFunctionNullUnaryInline<AggregateFunctionTemplate<TYPE>, true>(        
\
+                new 
AggregateFunctionTemplate<TYPE>(std::forward<TArgs>(args)...), argument_types, \
+                params);
+    FOR_NUMERIC_TYPES(DISPATCH)
+#undef DISPATCH
     return nullptr;
 }
 
@@ -68,12 +84,6 @@ static IAggregateFunction* create_with_numeric_type(const 
IDataType& argument_ty
         return new AggregateFunctionTemplate<TYPE, 
bool_param>(std::forward<TArgs>(args)...);
     FOR_NUMERIC_TYPES(DISPATCH)
 #undef DISPATCH
-    if (which.idx == TypeIndex::Enum8) {
-        return new AggregateFunctionTemplate<Int8, 
bool_param>(std::forward<TArgs>(args)...);
-    }
-    if (which.idx == TypeIndex::Enum16) {
-        return new AggregateFunctionTemplate<Int16, 
bool_param>(std::forward<TArgs>(args)...);
-    }
     return nullptr;
 }
 
@@ -87,12 +97,6 @@ static IAggregateFunction* create_with_numeric_type(const 
IDataType& argument_ty
         return new AggregateFunctionTemplate<TYPE, 
Data>(std::forward<TArgs>(args)...);
     FOR_NUMERIC_TYPES(DISPATCH)
 #undef DISPATCH
-    if (which.idx == TypeIndex::Enum8) {
-        return new AggregateFunctionTemplate<Int8, 
Data>(std::forward<TArgs>(args)...);
-    }
-    if (which.idx == TypeIndex::Enum16) {
-        return new AggregateFunctionTemplate<Int16, 
Data>(std::forward<TArgs>(args)...);
-    }
     return nullptr;
 }
 
@@ -106,12 +110,6 @@ static IAggregateFunction* create_with_numeric_type(const 
IDataType& argument_ty
         return new AggregateFunctionTemplate<TYPE, 
Data<TYPE>>(std::forward<TArgs>(args)...);
     FOR_NUMERIC_TYPES(DISPATCH)
 #undef DISPATCH
-    if (which.idx == TypeIndex::Enum8) {
-        return new AggregateFunctionTemplate<Int8, 
Data<Int8>>(std::forward<TArgs>(args)...);
-    }
-    if (which.idx == TypeIndex::Enum16) {
-        return new AggregateFunctionTemplate<Int16, 
Data<Int16>>(std::forward<TArgs>(args)...);
-    }
     return nullptr;
 }
 
@@ -125,70 +123,32 @@ static IAggregateFunction* create_with_numeric_type(const 
IDataType& argument_ty
         return new 
AggregateFunctionTemplate<Data<TYPE>>(std::forward<TArgs>(args)...);
     FOR_NUMERIC_TYPES(DISPATCH)
 #undef DISPATCH
-    // if (which.idx == TypeIndex::Enum8) return new 
AggregateFunctionTemplate<Data<Int8>>(std::forward<TArgs>(args)...);
-    // if (which.idx == TypeIndex::Enum16) return new 
AggregateFunctionTemplate<Data<Int16>>(std::forward<TArgs>(args)...);
-    return nullptr;
-}
-
-template <template <typename, typename> class AggregateFunctionTemplate,
-          template <typename> class Data, typename... TArgs>
-static IAggregateFunction* create_with_unsigned_integer_type(const IDataType& 
argument_type,
-                                                             TArgs&&... args) {
-    WhichDataType which(argument_type);
-    if (which.idx == TypeIndex::UInt8) {
-        return new AggregateFunctionTemplate<UInt8, 
Data<UInt8>>(std::forward<TArgs>(args)...);
-    }
-    if (which.idx == TypeIndex::UInt16) {
-        return new AggregateFunctionTemplate<UInt16, 
Data<UInt16>>(std::forward<TArgs>(args)...);
-    }
-    if (which.idx == TypeIndex::UInt32) {
-        return new AggregateFunctionTemplate<UInt32, 
Data<UInt32>>(std::forward<TArgs>(args)...);
-    }
-    if (which.idx == TypeIndex::UInt64) {
-        return new AggregateFunctionTemplate<UInt64, 
Data<UInt64>>(std::forward<TArgs>(args)...);
-    }
     return nullptr;
 }
 
 template <template <typename> class AggregateFunctionTemplate, typename... 
TArgs>
-static IAggregateFunction* create_with_numeric_based_type(const IDataType& 
argument_type,
-                                                          TArgs&&... args) {
-    IAggregateFunction* f = 
create_with_numeric_type<AggregateFunctionTemplate>(
-            argument_type, std::forward<TArgs>(args)...);
-    if (f) {
-        return f;
-    }
-
-    /// expects that DataTypeDate based on UInt16, DataTypeDateTime based on 
UInt32 and UUID based on UInt128
+static IAggregateFunction* create_with_decimal_type(const IDataType& 
argument_type,
+                                                    TArgs&&... args) {
     WhichDataType which(argument_type);
-    if (which.idx == TypeIndex::Date) {
-        return new 
AggregateFunctionTemplate<UInt16>(std::forward<TArgs>(args)...);
-    }
-    if (which.idx == TypeIndex::DateTime) {
-        return new 
AggregateFunctionTemplate<UInt32>(std::forward<TArgs>(args)...);
-    }
-    if (which.idx == TypeIndex::UUID) {
-        return new 
AggregateFunctionTemplate<UInt128>(std::forward<TArgs>(args)...);
-    }
+#define DISPATCH(TYPE)                \
+    if (which.idx == TypeIndex::TYPE) \
+        return new 
AggregateFunctionTemplate<TYPE>(std::forward<TArgs>(args)...);
+    FOR_DECIMAL_TYPES(DISPATCH)
+#undef DISPATCH
     return nullptr;
 }
 
 template <template <typename> class AggregateFunctionTemplate, typename... 
TArgs>
-static IAggregateFunction* create_with_decimal_type(const IDataType& 
argument_type,
-                                                    TArgs&&... args) {
-    WhichDataType which(argument_type);
-    if (which.idx == TypeIndex::Decimal32) {
-        return new 
AggregateFunctionTemplate<Decimal32>(std::forward<TArgs>(args)...);
-    }
-    if (which.idx == TypeIndex::Decimal64) {
-        return new 
AggregateFunctionTemplate<Decimal64>(std::forward<TArgs>(args)...);
-    }
-    if (which.idx == TypeIndex::Decimal128) {
-        return new 
AggregateFunctionTemplate<Decimal128>(std::forward<TArgs>(args)...);
-    }
-    if (which.idx == TypeIndex::Decimal128I) {
-        return new 
AggregateFunctionTemplate<Decimal128I>(std::forward<TArgs>(args)...);
-    }
+static IAggregateFunction* create_with_decimal_type_null(const DataTypes& 
argument_types,
+                                                         const Array& params, 
TArgs&&... args) {
+    WhichDataType which(argument_types[0]);
+#define DISPATCH(TYPE)                                                         
                    \
+    if (which.idx == TypeIndex::TYPE)                                          
                    \
+        return new 
AggregateFunctionNullUnaryInline<AggregateFunctionTemplate<TYPE>, true>(        
\
+                new 
AggregateFunctionTemplate<TYPE>(std::forward<TArgs>(args)...), argument_types, \
+                params);
+    FOR_DECIMAL_TYPES(DISPATCH)
+#undef DISPATCH
     return nullptr;
 }
 
@@ -197,18 +157,11 @@ template <template <typename, typename> class 
AggregateFunctionTemplate, typenam
 static IAggregateFunction* create_with_decimal_type(const IDataType& 
argument_type,
                                                     TArgs&&... args) {
     WhichDataType which(argument_type);
-    if (which.idx == TypeIndex::Decimal32) {
-        return new AggregateFunctionTemplate<Decimal32, 
Data>(std::forward<TArgs>(args)...);
-    }
-    if (which.idx == TypeIndex::Decimal64) {
-        return new AggregateFunctionTemplate<Decimal64, 
Data>(std::forward<TArgs>(args)...);
-    }
-    if (which.idx == TypeIndex::Decimal128) {
-        return new AggregateFunctionTemplate<Decimal128, 
Data>(std::forward<TArgs>(args)...);
-    }
-    if (which.idx == TypeIndex::Decimal128I) {
-        return new AggregateFunctionTemplate<Decimal128I, 
Data>(std::forward<TArgs>(args)...);
-    }
+#define DISPATCH(TYPE)                \
+    if (which.idx == TypeIndex::TYPE) \
+        return new AggregateFunctionTemplate<TYPE, 
Data>(std::forward<TArgs>(args)...);
+    FOR_DECIMAL_TYPES(DISPATCH)
+#undef DISPATCH
     return nullptr;
 }
 
@@ -224,12 +177,6 @@ static IAggregateFunction* 
create_with_two_numeric_types_second(const IDataType&
         return new AggregateFunctionTemplate<FirstType, 
TYPE>(std::forward<TArgs>(args)...);
     FOR_NUMERIC_TYPES(DISPATCH)
 #undef DISPATCH
-    if (which.idx == TypeIndex::Enum8) {
-        return new AggregateFunctionTemplate<FirstType, 
Int8>(std::forward<TArgs>(args)...);
-    }
-    if (which.idx == TypeIndex::Enum16) {
-        return new AggregateFunctionTemplate<FirstType, 
Int16>(std::forward<TArgs>(args)...);
-    }
     return nullptr;
 }
 
@@ -244,14 +191,6 @@ static IAggregateFunction* 
create_with_two_numeric_types(const IDataType& first_
                 second_type, std::forward<TArgs>(args)...);
     FOR_NUMERIC_TYPES(DISPATCH)
 #undef DISPATCH
-    if (which.idx == TypeIndex::Enum8) {
-        return create_with_two_numeric_types_second<Int8, 
AggregateFunctionTemplate>(
-                second_type, std::forward<TArgs>(args)...);
-    }
-    if (which.idx == TypeIndex::Enum16) {
-        return create_with_two_numeric_types_second<Int16, 
AggregateFunctionTemplate>(
-                second_type, std::forward<TArgs>(args)...);
-    }
     return nullptr;
 }
 
diff --git a/be/src/vec/data_types/data_type_nullable.cpp 
b/be/src/vec/data_types/data_type_nullable.cpp
index 6f69145504..e86cf77a79 100644
--- a/be/src/vec/data_types/data_type_nullable.cpp
+++ b/be/src/vec/data_types/data_type_nullable.cpp
@@ -158,4 +158,16 @@ DataTypePtr remove_nullable(const DataTypePtr& type) {
     return type;
 }
 
+DataTypes remove_nullable(const DataTypes& types) {
+    DataTypes no_null_types;
+    for (auto& type : types) {
+        if (type->is_nullable()) {
+            no_null_types.push_back(static_cast<const 
DataTypeNullable&>(*type).get_nested_type());
+        } else {
+            no_null_types.push_back(type);
+        }
+    }
+    return no_null_types;
+}
+
 } // namespace doris::vectorized
diff --git a/be/src/vec/data_types/data_type_nullable.h 
b/be/src/vec/data_types/data_type_nullable.h
index 32488ca35e..d8e6bf22b2 100644
--- a/be/src/vec/data_types/data_type_nullable.h
+++ b/be/src/vec/data_types/data_type_nullable.h
@@ -93,5 +93,6 @@ private:
 
 DataTypePtr make_nullable(const DataTypePtr& type);
 DataTypePtr remove_nullable(const DataTypePtr& type);
+DataTypes remove_nullable(const DataTypes& types);
 
 } // namespace doris::vectorized


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to