This is an automated email from the ASF dual-hosted git repository. lihaopeng pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new 24ef60b491 [Opt](exec) opt aggreate function performance in nullable column 24ef60b491 is described below commit 24ef60b491801e91ff37a6b34642d47f8ab46604 Author: HappenLee <happen...@hotmail.com> AuthorDate: Thu Feb 16 22:26:12 2023 +0800 [Opt](exec) opt aggreate function performance in nullable column --- .../aggregate_functions/aggregate_function_avg.cpp | 21 +- .../aggregate_functions/aggregate_function_count.h | 3 +- .../aggregate_function_min_max.cpp | 1 - .../aggregate_function_null.cpp | 1 - .../aggregate_functions/aggregate_function_null.h | 245 +++++++++++++++++++++ .../aggregate_functions/aggregate_function_sum.cpp | 24 +- be/src/vec/aggregate_functions/helpers.h | 149 ++++--------- be/src/vec/data_types/data_type_nullable.cpp | 12 + be/src/vec/data_types/data_type_nullable.h | 1 + 9 files changed, 339 insertions(+), 118 deletions(-) diff --git a/be/src/vec/aggregate_functions/aggregate_function_avg.cpp b/be/src/vec/aggregate_functions/aggregate_function_avg.cpp index 5875b831f3..7f9295d8e7 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_avg.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_avg.cpp @@ -45,11 +45,23 @@ AggregateFunctionPtr create_aggregate_function_avg(const std::string& name, AggregateFunctionPtr res; DataTypePtr data_type = argument_types[0]; - if (is_decimal(data_type)) { - res.reset( - create_with_decimal_type<AggregateFuncAvg>(*data_type, *data_type, argument_types)); + if (data_type->is_nullable()) { + auto no_null_argument_types = remove_nullable(argument_types); + if (is_decimal(no_null_argument_types[0])) { + res.reset(create_with_decimal_type_null<AggregateFuncAvg>( + no_null_argument_types, parameters, *no_null_argument_types[0], + no_null_argument_types)); + } else { + res.reset(create_with_numeric_type_null<AggregateFuncAvg>( + no_null_argument_types, parameters, no_null_argument_types)); + } } else { - res.reset(create_with_numeric_type<AggregateFuncAvg>(*data_type, argument_types)); + if (is_decimal(data_type)) { + res.reset(create_with_decimal_type<AggregateFuncAvg>(*data_type, *data_type, + argument_types)); + } else { + res.reset(create_with_numeric_type<AggregateFuncAvg>(*data_type, argument_types)); + } } if (!res) { @@ -61,5 +73,6 @@ AggregateFunctionPtr create_aggregate_function_avg(const std::string& name, void register_aggregate_function_avg(AggregateFunctionSimpleFactory& factory) { factory.register_function("avg", create_aggregate_function_avg); + factory.register_function("avg", create_aggregate_function_avg, true); } } // namespace doris::vectorized diff --git a/be/src/vec/aggregate_functions/aggregate_function_count.h b/be/src/vec/aggregate_functions/aggregate_function_count.h index 960d4111cb..bc87e4bb10 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_count.h +++ b/be/src/vec/aggregate_functions/aggregate_function_count.h @@ -121,7 +121,8 @@ public: DataTypePtr get_serialized_type() const override { return std::make_shared<DataTypeUInt64>(); } }; -/// Simply count number of not-NULL values. +// TODO: Maybe AggregateFunctionCountNotNullUnary should be a subclass of AggregateFunctionCount +// Simply count number of not-NULL values. class AggregateFunctionCountNotNullUnary final : public IAggregateFunctionDataHelper<AggregateFunctionCountData, AggregateFunctionCountNotNullUnary> { diff --git a/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp b/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp index 83045dbd00..a01e2ce51a 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp @@ -25,7 +25,6 @@ #include "vec/aggregate_functions/helpers.h" namespace doris::vectorized { - /// min, max, any template <template <typename, bool> class AggregateFunctionTemplate, template <typename> class Data> static IAggregateFunction* create_aggregate_function_single_value(const String& name, diff --git a/be/src/vec/aggregate_functions/aggregate_function_null.cpp b/be/src/vec/aggregate_functions/aggregate_function_null.cpp index 495cefcb84..8ae2368864 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_null.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_null.cpp @@ -85,7 +85,6 @@ public: }; void register_aggregate_function_combinator_null(AggregateFunctionSimpleFactory& factory) { - // factory.registerCombinator(std::make_shared<AggregateFunctionCombinatorNull>()); AggregateFunctionCreator creator = [&](const std::string& name, const DataTypes& types, const Array& params, const bool result_is_nullable) { auto function_combinator = std::make_shared<AggregateFunctionCombinatorNull>(); diff --git a/be/src/vec/aggregate_functions/aggregate_function_null.h b/be/src/vec/aggregate_functions/aggregate_function_null.h index 86fe7734e1..69642d0deb 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_null.h +++ b/be/src/vec/aggregate_functions/aggregate_function_null.h @@ -40,6 +40,7 @@ namespace doris::vectorized { /// If all rows had NULL, the behaviour is determined by "result_is_nullable" template parameter. /// true - return NULL; false - return value from empty aggregation state of nested function. +// TODO: only keep class xxxInline after we support all aggregate function template <bool result_is_nullable, typename Derived> class AggregateFunctionNullBase : public IAggregateFunctionHelper<Derived> { protected: @@ -409,4 +410,248 @@ private: is_nullable; /// Plain array is better than std::vector due to one indirection less. }; +template <typename NestFunction, bool result_is_nullable, typename Derived> +class AggregateFunctionNullBaseInline : public IAggregateFunctionHelper<Derived> { +protected: + std::unique_ptr<NestFunction> nested_function; + size_t prefix_size; + + /** In addition to data for nested aggregate function, we keep a flag + * indicating - was there at least one non-NULL value accumulated. + * In case of no not-NULL values, the function will return NULL. + * + * We use prefix_size bytes for flag to satisfy the alignment requirement of nested state. + */ + + AggregateDataPtr nested_place(AggregateDataPtr __restrict place) const noexcept { + return place + prefix_size; + } + + ConstAggregateDataPtr nested_place(ConstAggregateDataPtr __restrict place) const noexcept { + return place + prefix_size; + } + + static void init_flag(AggregateDataPtr __restrict place) noexcept { + if constexpr (result_is_nullable) { + place[0] = false; + } + } + + static void set_flag(AggregateDataPtr __restrict place) noexcept { + if constexpr (result_is_nullable) { + place[0] = true; + } + } + + static bool get_flag(ConstAggregateDataPtr __restrict place) noexcept { + return result_is_nullable ? place[0] : true; + } + +public: + AggregateFunctionNullBaseInline(IAggregateFunction* nested_function_, + const DataTypes& arguments, const Array& params) + : IAggregateFunctionHelper<Derived>(arguments, params), + nested_function {assert_cast<NestFunction*>(nested_function_)} { + if (result_is_nullable) { + prefix_size = nested_function->align_of_data(); + } else { + prefix_size = 0; + } + } + + String get_name() const override { + /// This is just a wrapper. The function for Nullable arguments is named the same as the nested function itself. + return nested_function->get_name(); + } + + DataTypePtr get_return_type() const override { + return result_is_nullable ? make_nullable(nested_function->get_return_type()) + : nested_function->get_return_type(); + } + + void create(AggregateDataPtr __restrict place) const override { + init_flag(place); + nested_function->create(nested_place(place)); + } + + void destroy(AggregateDataPtr __restrict place) const noexcept override { + nested_function->destroy(nested_place(place)); + } + void reset(AggregateDataPtr place) const override { + init_flag(place); + nested_function->reset(nested_place(place)); + } + + bool has_trivial_destructor() const override { + return nested_function->has_trivial_destructor(); + } + + size_t size_of_data() const override { return prefix_size + nested_function->size_of_data(); } + + size_t align_of_data() const override { return nested_function->align_of_data(); } + + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, + Arena* arena) const override { + if (result_is_nullable && get_flag(rhs)) { + set_flag(place); + } + + nested_function->merge(nested_place(place), nested_place(rhs), arena); + } + + void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override { + bool flag = get_flag(place); + if (result_is_nullable) { + write_binary(flag, buf); + } + if (flag) { + nested_function->serialize(nested_place(place), buf); + } + } + + void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf, + Arena* arena) const override { + bool flag = true; + if (result_is_nullable) { + read_binary(flag, buf); + } + if (flag) { + set_flag(place); + nested_function->deserialize(nested_place(place), buf, arena); + } + } + + void deserialize_and_merge(AggregateDataPtr __restrict place, BufferReadable& buf, + Arena* arena) const override { + bool flag = true; + if (result_is_nullable) { + read_binary(flag, buf); + } + if (flag) { + set_flag(place); + nested_function->deserialize_and_merge(nested_place(place), buf, arena); + } + } + + void deserialize_and_merge_from_column(AggregateDataPtr __restrict place, const IColumn& column, + Arena* arena) const override { + size_t num_rows = column.size(); + for (size_t i = 0; i != num_rows; ++i) { + VectorBufferReader buffer_reader( + (assert_cast<const ColumnString&>(column)).get_data_at(i)); + deserialize_and_merge(place, buffer_reader, arena); + } + } + + void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { + if constexpr (result_is_nullable) { + ColumnNullable& to_concrete = assert_cast<ColumnNullable&>(to); + if (get_flag(place)) { + nested_function->insert_result_into(nested_place(place), + to_concrete.get_nested_column()); + to_concrete.get_null_map_data().push_back(0); + } else { + to_concrete.insert_default(); + } + } else { + nested_function->insert_result_into(nested_place(place), to); + } + } + + bool allocates_memory_in_arena() const override { + return nested_function->allocates_memory_in_arena(); + } + + bool is_state() const override { return nested_function->is_state(); } +}; + +/** There are two cases: for single argument and variadic. + * Code for single argument is much more efficient. + */ +template <typename NestFuction, bool result_is_nullable> +class AggregateFunctionNullUnaryInline final + : public AggregateFunctionNullBaseInline< + NestFuction, result_is_nullable, + AggregateFunctionNullUnaryInline<NestFuction, result_is_nullable>> { +public: + AggregateFunctionNullUnaryInline(IAggregateFunction* nested_function_, + const DataTypes& arguments, const Array& params) + : AggregateFunctionNullBaseInline< + NestFuction, result_is_nullable, + AggregateFunctionNullUnaryInline<NestFuction, result_is_nullable>>( + nested_function_, arguments, params) {} + + void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, + Arena* arena) const override { + const ColumnNullable* column = assert_cast<const ColumnNullable*>(columns[0]); + if (!column->is_null_at(row_num)) { + this->set_flag(place); + const IColumn* nested_column = &column->get_nested_column(); + this->nested_function->add(this->nested_place(place), &nested_column, row_num, arena); + } + } + + void add_not_nullable(AggregateDataPtr __restrict place, const IColumn** columns, + size_t row_num, Arena* arena) const { + const ColumnNullable* column = assert_cast<const ColumnNullable*>(columns[0]); + this->set_flag(place); + const IColumn* nested_column = &column->get_nested_column(); + this->nested_function->add(this->nested_place(place), &nested_column, row_num, arena); + } + + void add_batch(size_t batch_size, AggregateDataPtr* places, size_t place_offset, + const IColumn** columns, Arena* arena, bool agg_many) const override { + const ColumnNullable* column = assert_cast<const ColumnNullable*>(columns[0]); + // The overhead introduced is negligible here, just an extra memory read from NullMap + const auto* __restrict null_map_data = column->get_null_map_data().data(); + const IColumn* nested_column = &column->get_nested_column(); + for (int i = 0; i < batch_size; ++i) { + if (!null_map_data[i]) { + AggregateDataPtr __restrict place = places[i] + place_offset; + this->set_flag(place); + this->nested_function->add(this->nested_place(place), &nested_column, i, arena); + } + } + } + + void add_batch_single_place(size_t batch_size, AggregateDataPtr place, const IColumn** columns, + Arena* arena) const override { + const ColumnNullable* column = assert_cast<const ColumnNullable*>(columns[0]); + bool has_null = column->has_null(); + + if (has_null) { + for (size_t i = 0; i < batch_size; ++i) { + if (!column->is_null_at(i)) { + this->set_flag(place); + this->add(place, columns, i, arena); + } + } + } else { + this->set_flag(place); + const IColumn* nested_column = &column->get_nested_column(); + this->nested_function->add_batch_single_place(batch_size, this->nested_place(place), + &nested_column, arena); + } + } + + void add_batch_range(size_t batch_begin, size_t batch_end, AggregateDataPtr place, + const IColumn** columns, Arena* arena, bool has_null) override { + const ColumnNullable* column = assert_cast<const ColumnNullable*>(columns[0]); + + if (has_null) { + for (size_t i = batch_begin; i <= batch_end; ++i) { + if (!column->is_null_at(i)) { + this->set_flag(place); + this->add(place, columns, i, arena); + } + } + } else { + this->set_flag(place); + const IColumn* nested_column = &column->get_nested_column(); + this->nested_function->add_batch_range(batch_begin, batch_end, + this->nested_place(place), &nested_column, arena, + false); + } + } +}; } // namespace doris::vectorized diff --git a/be/src/vec/aggregate_functions/aggregate_function_sum.cpp b/be/src/vec/aggregate_functions/aggregate_function_sum.cpp index ca40e4196c..75d4d36414 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_sum.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_sum.cpp @@ -25,6 +25,7 @@ #include "common/logging.h" #include "vec/aggregate_functions/aggregate_function_simple_factory.h" #include "vec/aggregate_functions/helpers.h" +#include "vec/data_types/data_type_nullable.h" namespace doris::vectorized { @@ -45,15 +46,24 @@ AggregateFunctionPtr create_aggregate_function_sum(const std::string& name, const DataTypes& argument_types, const Array& parameters, const bool result_is_nullable) { - // assert_no_parameters(name, parameters); - // assert_unary(name, argument_types); - AggregateFunctionPtr res; DataTypePtr data_type = argument_types[0]; - if (is_decimal(data_type)) { - res.reset(create_with_decimal_type<Function>(*data_type, *data_type, argument_types)); + if (data_type->is_nullable()) { + auto no_null_argument_types = remove_nullable(argument_types); + if (is_decimal(no_null_argument_types[0])) { + res.reset(create_with_decimal_type_null<Function>(no_null_argument_types, parameters, + *no_null_argument_types[0], + no_null_argument_types)); + } else { + res.reset(create_with_numeric_type_null<Function>(no_null_argument_types, parameters, + no_null_argument_types)); + } } else { - res.reset(create_with_numeric_type<Function>(*data_type, argument_types)); + if (is_decimal(data_type)) { + res.reset(create_with_decimal_type<Function>(*data_type, *data_type, argument_types)); + } else { + res.reset(create_with_numeric_type<Function>(*data_type, argument_types)); + } } if (!res) { @@ -84,6 +94,8 @@ AggregateFunctionPtr create_aggregate_function_sum_reader(const std::string& nam void register_aggregate_function_sum(AggregateFunctionSimpleFactory& factory) { factory.register_function("sum", create_aggregate_function_sum<AggregateFunctionSumSimple>); + factory.register_function("sum", create_aggregate_function_sum<AggregateFunctionSumSimple>, + true); } } // namespace doris::vectorized diff --git a/be/src/vec/aggregate_functions/helpers.h b/be/src/vec/aggregate_functions/helpers.h index 36e11f7011..0970a860a6 100644 --- a/be/src/vec/aggregate_functions/helpers.h +++ b/be/src/vec/aggregate_functions/helpers.h @@ -21,8 +21,10 @@ #pragma once #include "vec/aggregate_functions/aggregate_function.h" +#include "vec/aggregate_functions/aggregate_function_null.h" #include "vec/data_types/data_type.h" +// TODO: Should we support decimal in numeric types? #define FOR_NUMERIC_TYPES(M) \ M(UInt8) \ M(UInt16) \ @@ -36,6 +38,12 @@ M(Float32) \ M(Float64) +#define FOR_DECIMAL_TYPES(M) \ + M(Decimal32) \ + M(Decimal64) \ + M(Decimal128) \ + M(Decimal128I) + namespace doris::vectorized { /** Create an aggregate function with a numeric type in the template parameter, depending on the type of the argument. @@ -49,12 +57,20 @@ static IAggregateFunction* create_with_numeric_type(const IDataType& argument_ty return new AggregateFunctionTemplate<TYPE>(std::forward<TArgs>(args)...); FOR_NUMERIC_TYPES(DISPATCH) #undef DISPATCH - if (which.idx == TypeIndex::Enum8) { - return new AggregateFunctionTemplate<Int8>(std::forward<TArgs>(args)...); - } - if (which.idx == TypeIndex::Enum16) { - return new AggregateFunctionTemplate<Int16>(std::forward<TArgs>(args)...); - } + return nullptr; +} + +template <template <typename> class AggregateFunctionTemplate, typename... TArgs> +static IAggregateFunction* create_with_numeric_type_null(const DataTypes& argument_types, + const Array& params, TArgs&&... args) { + WhichDataType which(argument_types[0]); +#define DISPATCH(TYPE) \ + if (which.idx == TypeIndex::TYPE) \ + return new AggregateFunctionNullUnaryInline<AggregateFunctionTemplate<TYPE>, true>( \ + new AggregateFunctionTemplate<TYPE>(std::forward<TArgs>(args)...), argument_types, \ + params); + FOR_NUMERIC_TYPES(DISPATCH) +#undef DISPATCH return nullptr; } @@ -68,12 +84,6 @@ static IAggregateFunction* create_with_numeric_type(const IDataType& argument_ty return new AggregateFunctionTemplate<TYPE, bool_param>(std::forward<TArgs>(args)...); FOR_NUMERIC_TYPES(DISPATCH) #undef DISPATCH - if (which.idx == TypeIndex::Enum8) { - return new AggregateFunctionTemplate<Int8, bool_param>(std::forward<TArgs>(args)...); - } - if (which.idx == TypeIndex::Enum16) { - return new AggregateFunctionTemplate<Int16, bool_param>(std::forward<TArgs>(args)...); - } return nullptr; } @@ -87,12 +97,6 @@ static IAggregateFunction* create_with_numeric_type(const IDataType& argument_ty return new AggregateFunctionTemplate<TYPE, Data>(std::forward<TArgs>(args)...); FOR_NUMERIC_TYPES(DISPATCH) #undef DISPATCH - if (which.idx == TypeIndex::Enum8) { - return new AggregateFunctionTemplate<Int8, Data>(std::forward<TArgs>(args)...); - } - if (which.idx == TypeIndex::Enum16) { - return new AggregateFunctionTemplate<Int16, Data>(std::forward<TArgs>(args)...); - } return nullptr; } @@ -106,12 +110,6 @@ static IAggregateFunction* create_with_numeric_type(const IDataType& argument_ty return new AggregateFunctionTemplate<TYPE, Data<TYPE>>(std::forward<TArgs>(args)...); FOR_NUMERIC_TYPES(DISPATCH) #undef DISPATCH - if (which.idx == TypeIndex::Enum8) { - return new AggregateFunctionTemplate<Int8, Data<Int8>>(std::forward<TArgs>(args)...); - } - if (which.idx == TypeIndex::Enum16) { - return new AggregateFunctionTemplate<Int16, Data<Int16>>(std::forward<TArgs>(args)...); - } return nullptr; } @@ -125,70 +123,32 @@ static IAggregateFunction* create_with_numeric_type(const IDataType& argument_ty return new AggregateFunctionTemplate<Data<TYPE>>(std::forward<TArgs>(args)...); FOR_NUMERIC_TYPES(DISPATCH) #undef DISPATCH - // if (which.idx == TypeIndex::Enum8) return new AggregateFunctionTemplate<Data<Int8>>(std::forward<TArgs>(args)...); - // if (which.idx == TypeIndex::Enum16) return new AggregateFunctionTemplate<Data<Int16>>(std::forward<TArgs>(args)...); - return nullptr; -} - -template <template <typename, typename> class AggregateFunctionTemplate, - template <typename> class Data, typename... TArgs> -static IAggregateFunction* create_with_unsigned_integer_type(const IDataType& argument_type, - TArgs&&... args) { - WhichDataType which(argument_type); - if (which.idx == TypeIndex::UInt8) { - return new AggregateFunctionTemplate<UInt8, Data<UInt8>>(std::forward<TArgs>(args)...); - } - if (which.idx == TypeIndex::UInt16) { - return new AggregateFunctionTemplate<UInt16, Data<UInt16>>(std::forward<TArgs>(args)...); - } - if (which.idx == TypeIndex::UInt32) { - return new AggregateFunctionTemplate<UInt32, Data<UInt32>>(std::forward<TArgs>(args)...); - } - if (which.idx == TypeIndex::UInt64) { - return new AggregateFunctionTemplate<UInt64, Data<UInt64>>(std::forward<TArgs>(args)...); - } return nullptr; } template <template <typename> class AggregateFunctionTemplate, typename... TArgs> -static IAggregateFunction* create_with_numeric_based_type(const IDataType& argument_type, - TArgs&&... args) { - IAggregateFunction* f = create_with_numeric_type<AggregateFunctionTemplate>( - argument_type, std::forward<TArgs>(args)...); - if (f) { - return f; - } - - /// expects that DataTypeDate based on UInt16, DataTypeDateTime based on UInt32 and UUID based on UInt128 +static IAggregateFunction* create_with_decimal_type(const IDataType& argument_type, + TArgs&&... args) { WhichDataType which(argument_type); - if (which.idx == TypeIndex::Date) { - return new AggregateFunctionTemplate<UInt16>(std::forward<TArgs>(args)...); - } - if (which.idx == TypeIndex::DateTime) { - return new AggregateFunctionTemplate<UInt32>(std::forward<TArgs>(args)...); - } - if (which.idx == TypeIndex::UUID) { - return new AggregateFunctionTemplate<UInt128>(std::forward<TArgs>(args)...); - } +#define DISPATCH(TYPE) \ + if (which.idx == TypeIndex::TYPE) \ + return new AggregateFunctionTemplate<TYPE>(std::forward<TArgs>(args)...); + FOR_DECIMAL_TYPES(DISPATCH) +#undef DISPATCH return nullptr; } template <template <typename> class AggregateFunctionTemplate, typename... TArgs> -static IAggregateFunction* create_with_decimal_type(const IDataType& argument_type, - TArgs&&... args) { - WhichDataType which(argument_type); - if (which.idx == TypeIndex::Decimal32) { - return new AggregateFunctionTemplate<Decimal32>(std::forward<TArgs>(args)...); - } - if (which.idx == TypeIndex::Decimal64) { - return new AggregateFunctionTemplate<Decimal64>(std::forward<TArgs>(args)...); - } - if (which.idx == TypeIndex::Decimal128) { - return new AggregateFunctionTemplate<Decimal128>(std::forward<TArgs>(args)...); - } - if (which.idx == TypeIndex::Decimal128I) { - return new AggregateFunctionTemplate<Decimal128I>(std::forward<TArgs>(args)...); - } +static IAggregateFunction* create_with_decimal_type_null(const DataTypes& argument_types, + const Array& params, TArgs&&... args) { + WhichDataType which(argument_types[0]); +#define DISPATCH(TYPE) \ + if (which.idx == TypeIndex::TYPE) \ + return new AggregateFunctionNullUnaryInline<AggregateFunctionTemplate<TYPE>, true>( \ + new AggregateFunctionTemplate<TYPE>(std::forward<TArgs>(args)...), argument_types, \ + params); + FOR_DECIMAL_TYPES(DISPATCH) +#undef DISPATCH return nullptr; } @@ -197,18 +157,11 @@ template <template <typename, typename> class AggregateFunctionTemplate, typenam static IAggregateFunction* create_with_decimal_type(const IDataType& argument_type, TArgs&&... args) { WhichDataType which(argument_type); - if (which.idx == TypeIndex::Decimal32) { - return new AggregateFunctionTemplate<Decimal32, Data>(std::forward<TArgs>(args)...); - } - if (which.idx == TypeIndex::Decimal64) { - return new AggregateFunctionTemplate<Decimal64, Data>(std::forward<TArgs>(args)...); - } - if (which.idx == TypeIndex::Decimal128) { - return new AggregateFunctionTemplate<Decimal128, Data>(std::forward<TArgs>(args)...); - } - if (which.idx == TypeIndex::Decimal128I) { - return new AggregateFunctionTemplate<Decimal128I, Data>(std::forward<TArgs>(args)...); - } +#define DISPATCH(TYPE) \ + if (which.idx == TypeIndex::TYPE) \ + return new AggregateFunctionTemplate<TYPE, Data>(std::forward<TArgs>(args)...); + FOR_DECIMAL_TYPES(DISPATCH) +#undef DISPATCH return nullptr; } @@ -224,12 +177,6 @@ static IAggregateFunction* create_with_two_numeric_types_second(const IDataType& return new AggregateFunctionTemplate<FirstType, TYPE>(std::forward<TArgs>(args)...); FOR_NUMERIC_TYPES(DISPATCH) #undef DISPATCH - if (which.idx == TypeIndex::Enum8) { - return new AggregateFunctionTemplate<FirstType, Int8>(std::forward<TArgs>(args)...); - } - if (which.idx == TypeIndex::Enum16) { - return new AggregateFunctionTemplate<FirstType, Int16>(std::forward<TArgs>(args)...); - } return nullptr; } @@ -244,14 +191,6 @@ static IAggregateFunction* create_with_two_numeric_types(const IDataType& first_ second_type, std::forward<TArgs>(args)...); FOR_NUMERIC_TYPES(DISPATCH) #undef DISPATCH - if (which.idx == TypeIndex::Enum8) { - return create_with_two_numeric_types_second<Int8, AggregateFunctionTemplate>( - second_type, std::forward<TArgs>(args)...); - } - if (which.idx == TypeIndex::Enum16) { - return create_with_two_numeric_types_second<Int16, AggregateFunctionTemplate>( - second_type, std::forward<TArgs>(args)...); - } return nullptr; } diff --git a/be/src/vec/data_types/data_type_nullable.cpp b/be/src/vec/data_types/data_type_nullable.cpp index 6f69145504..e86cf77a79 100644 --- a/be/src/vec/data_types/data_type_nullable.cpp +++ b/be/src/vec/data_types/data_type_nullable.cpp @@ -158,4 +158,16 @@ DataTypePtr remove_nullable(const DataTypePtr& type) { return type; } +DataTypes remove_nullable(const DataTypes& types) { + DataTypes no_null_types; + for (auto& type : types) { + if (type->is_nullable()) { + no_null_types.push_back(static_cast<const DataTypeNullable&>(*type).get_nested_type()); + } else { + no_null_types.push_back(type); + } + } + return no_null_types; +} + } // namespace doris::vectorized diff --git a/be/src/vec/data_types/data_type_nullable.h b/be/src/vec/data_types/data_type_nullable.h index 32488ca35e..d8e6bf22b2 100644 --- a/be/src/vec/data_types/data_type_nullable.h +++ b/be/src/vec/data_types/data_type_nullable.h @@ -93,5 +93,6 @@ private: DataTypePtr make_nullable(const DataTypePtr& type); DataTypePtr remove_nullable(const DataTypePtr& type); +DataTypes remove_nullable(const DataTypes& types); } // namespace doris::vectorized --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org