This is an automated email from the ASF dual-hosted git repository. morningman pushed a commit to branch dev-1.0.1 in repository https://gitbox.apache.org/repos/asf/incubator-doris.git
commit d6be70c13043e42e8f3581cfbe5291b504afc376 Author: HappenLee <happen...@hotmail.com> AuthorDate: Thu Mar 24 11:39:29 2022 +0800 [fix] Fix coredump of stddev function (#8543) This is only a temporary fix, and its performance is not ideal. Eventually, we need to refactor the `stddev` functions and remove the `insert_to_null_default()` interface. --- .../aggregate_function_stddev.cpp | 7 +++-- .../aggregate_function_stddev.h | 32 ++++++++++++++++------ .../apache/doris/catalog/AggregateFunction.java | 6 +++- .../java/org/apache/doris/catalog/FunctionSet.java | 1 + 4 files changed, 34 insertions(+), 12 deletions(-) diff --git a/be/src/vec/aggregate_functions/aggregate_function_stddev.cpp b/be/src/vec/aggregate_functions/aggregate_function_stddev.cpp index f1794d6..2b06423 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_stddev.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_stddev.cpp @@ -90,11 +90,14 @@ AggregateFunctionPtr create_aggregate_function_stddev_pop(const std::string& nam void register_aggregate_function_stddev_variance(AggregateFunctionSimpleFactory& factory) { factory.register_function("variance_samp", create_aggregate_function_variance_samp<false>); - factory.register_function("variance", create_aggregate_function_variance_pop<false>); + factory.register_function("variance_samp", create_aggregate_function_variance_samp<false>, true); + factory.register_function("stddev_samp", create_aggregate_function_stddev_samp<true>); + factory.register_function("stddev_samp", create_aggregate_function_stddev_samp<true>, true); factory.register_alias("variance_samp", "var_samp"); + + factory.register_function("variance", create_aggregate_function_variance_pop<false>); factory.register_alias("variance", "var_pop"); factory.register_alias("variance", "variance_pop"); - factory.register_function("stddev_samp", create_aggregate_function_stddev_samp<true>); factory.register_function("stddev", create_aggregate_function_stddev_pop<true>); 
factory.register_alias("stddev", "stddev_pop"); } diff --git a/be/src/vec/aggregate_functions/aggregate_function_stddev.h b/be/src/vec/aggregate_functions/aggregate_function_stddev.h index 82e8718..50c4064 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_stddev.h +++ b/be/src/vec/aggregate_functions/aggregate_function_stddev.h @@ -69,7 +69,7 @@ struct BaseData { } static const DataTypePtr get_return_type() { - return make_nullable(std::make_shared<DataTypeNumber<Float64>>()); + return std::make_shared<DataTypeNumber<Float64>>(); } void merge(const BaseData& rhs) { @@ -83,7 +83,7 @@ struct BaseData { count = sum_count; } - void add(const IColumn** columns, size_t row_num) { + virtual void add(const IColumn** columns, size_t row_num) { const auto& sources = static_cast<const ColumnVector<T>&>(*columns[0]); double source_data = sources.get_data()[row_num]; @@ -145,7 +145,7 @@ struct BaseDatadecimal { } static const DataTypePtr get_return_type() { - return make_nullable(std::make_shared<DataTypeDecimal<Decimal128>>(27, 9)); + return std::make_shared<DataTypeDecimal<Decimal128>>(27, 9); } void merge(const BaseDatadecimal& rhs) { @@ -164,7 +164,7 @@ struct BaseDatadecimal { count += rhs.count; } - void add(const IColumn** columns, size_t row_num) { + virtual void add(const IColumn** columns, size_t row_num) { DecimalV2Value source_data = DecimalV2Value(); const auto& sources = static_cast<const ColumnDecimal<Decimal128>&>(*columns[0]); source_data = (DecimalV2Value)sources.get_data()[row_num]; @@ -191,14 +191,12 @@ struct PopData : Data { using ColVecResult = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<Decimal128>, ColumnVector<Float64>>; void insert_result_into(IColumn& to) const { - ColumnNullable& nullable_column = assert_cast<ColumnNullable&>(to); - auto& col = static_cast<ColVecResult&>(nullable_column.get_nested_column()); + auto& col = assert_cast<ColVecResult&>(to); if constexpr (IsDecimalNumber<T>) { 
col.get_data().push_back(this->get_pop_result().value()); } else { col.get_data().push_back(this->get_pop_result()); } - nullable_column.get_null_map_data().push_back(0); } }; @@ -220,6 +218,24 @@ struct SampData : Data { nullable_column.get_null_map_data().push_back(0); } } + + static const DataTypePtr get_return_type() { + return make_nullable(Data::get_return_type()); + } + + void add(const IColumn** columns, size_t row_num) override { + if (columns[0]->is_nullable()) { + const auto& nullable_column = assert_cast<const ColumnNullable&>(*columns[0]); + if (!nullable_column.is_null_at(row_num)) { + const IColumn* new_columns[1]; + new_columns[0] = &nullable_column.get_nested_column(); + Data::add(new_columns, row_num); + } + } else { + Data::add(columns, row_num); + } + } + }; template <typename Data> @@ -252,8 +268,6 @@ public: String get_name() const override { return Data::name(); } - bool insert_to_null_default() const override { return false; } - DataTypePtr get_return_type() const override { return Data::get_return_type(); } void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java index da4d201..da8dc10 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java @@ -51,6 +51,9 @@ public class AggregateFunction extends Function { public static ImmutableSet<String> NOT_NULLABLE_AGGREGATE_FUNCTION_NAME_SET = ImmutableSet.of("row_number", "rank", "dense_rank", "hll_union_agg", "hll_union", "bitmap_union", "bitmap_intersect", FunctionSet.COUNT, "ndv", FunctionSet.BITMAP_UNION_INT, FunctionSet.BITMAP_UNION_COUNT, "ndv_no_finalize"); + public static ImmutableSet<String> ALWAYS_NULLABLE_AGGREGATE_FUNCTION_NAME_SET = + ImmutableSet.of("stddev_samp", "variance_samp", "var_samp"); 
+ // Set if different from retType_, null otherwise. private Type intermediateType; @@ -150,7 +153,8 @@ public class AggregateFunction extends Function { String removeFnSymbol, String finalizeFnSymbol, boolean vectorized) { // only `count` is always not nullable, other aggregate function is always nullable super(fnName, argTypes, retType, hasVarArgs, vectorized, - AggregateFunction.NOT_NULLABLE_AGGREGATE_FUNCTION_NAME_SET.contains(fnName.getFunction()) ? NullableMode.ALWAYS_NOT_NULLABLE : NullableMode.DEPEND_ON_ARGUMENT); + AggregateFunction.NOT_NULLABLE_AGGREGATE_FUNCTION_NAME_SET.contains(fnName.getFunction()) ? NullableMode.ALWAYS_NOT_NULLABLE : + AggregateFunction.ALWAYS_NULLABLE_AGGREGATE_FUNCTION_NAME_SET.contains(fnName.getFunction()) ? NullableMode.ALWAYS_NULLABLE : NullableMode.DEPEND_ON_ARGUMENT); setLocation(location); this.intermediateType = (intermediateType.equals(retType)) ? null : intermediateType; this.updateFnSymbol = updateFnSymbol; diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java index 9d49cc5..afe7983 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java @@ -1665,6 +1665,7 @@ public class FunctionSet<min_initIN9doris_udf12DecimalV2ValEEEvPNS2_15FunctionCo null, prefix + STDDEV_POP_FINALIZE_SYMBOL.get(t), false, false, false)); + //vec stddev stddev_samp stddev_pop addBuiltin(AggregateFunction.createBuiltin("stddev", Lists.newArrayList(t), STDDEV_RETTYPE_SYMBOL.get(t), t, --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org