This is an automated email from the ASF dual-hosted git repository.

morningman pushed a commit to branch dev-1.0.1
in repository https://gitbox.apache.org/repos/asf/incubator-doris.git

commit d6be70c13043e42e8f3581cfbe5291b504afc376
Author: HappenLee <happen...@hotmail.com>
AuthorDate: Thu Mar 24 11:39:29 2022 +0800

    [fix] Fix coredump of stddev function (#8543)
    
    This is only a temporary fix; its performance is not ideal. Eventually,
    we need to refactor the `stddev` functions and delete the
    `insert_to_null_default()` interface.
---
 .../aggregate_function_stddev.cpp                  |  7 +++--
 .../aggregate_function_stddev.h                    | 32 ++++++++++++++++------
 .../apache/doris/catalog/AggregateFunction.java    |  6 +++-
 .../java/org/apache/doris/catalog/FunctionSet.java |  1 +
 4 files changed, 34 insertions(+), 12 deletions(-)

diff --git a/be/src/vec/aggregate_functions/aggregate_function_stddev.cpp 
b/be/src/vec/aggregate_functions/aggregate_function_stddev.cpp
index f1794d6..2b06423 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_stddev.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_stddev.cpp
@@ -90,11 +90,14 @@ AggregateFunctionPtr 
create_aggregate_function_stddev_pop(const std::string& nam
 
 void 
register_aggregate_function_stddev_variance(AggregateFunctionSimpleFactory& 
factory) {
     factory.register_function("variance_samp", 
create_aggregate_function_variance_samp<false>);
-    factory.register_function("variance", 
create_aggregate_function_variance_pop<false>);
+    factory.register_function("variance_samp", 
create_aggregate_function_variance_samp<false>, true);
+    factory.register_function("stddev_samp", 
create_aggregate_function_stddev_samp<true>);
+    factory.register_function("stddev_samp", 
create_aggregate_function_stddev_samp<true>, true);
     factory.register_alias("variance_samp", "var_samp");
+
+    factory.register_function("variance", 
create_aggregate_function_variance_pop<false>);
     factory.register_alias("variance", "var_pop");
     factory.register_alias("variance", "variance_pop");
-    factory.register_function("stddev_samp", 
create_aggregate_function_stddev_samp<true>);
     factory.register_function("stddev", 
create_aggregate_function_stddev_pop<true>);
     factory.register_alias("stddev", "stddev_pop");
 }
diff --git a/be/src/vec/aggregate_functions/aggregate_function_stddev.h 
b/be/src/vec/aggregate_functions/aggregate_function_stddev.h
index 82e8718..50c4064 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_stddev.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_stddev.h
@@ -69,7 +69,7 @@ struct BaseData {
     }
 
     static const DataTypePtr get_return_type() {
-        return make_nullable(std::make_shared<DataTypeNumber<Float64>>());
+        return std::make_shared<DataTypeNumber<Float64>>();
     }
 
     void merge(const BaseData& rhs) {
@@ -83,7 +83,7 @@ struct BaseData {
         count = sum_count;
     }
 
-    void add(const IColumn** columns, size_t row_num) {
+    virtual void add(const IColumn** columns, size_t row_num) {
         const auto& sources = static_cast<const ColumnVector<T>&>(*columns[0]);
         double source_data = sources.get_data()[row_num];
 
@@ -145,7 +145,7 @@ struct BaseDatadecimal {
     }
 
     static const DataTypePtr get_return_type() {
-        return make_nullable(std::make_shared<DataTypeDecimal<Decimal128>>(27, 
9));
+        return std::make_shared<DataTypeDecimal<Decimal128>>(27, 9);
     }
 
     void merge(const BaseDatadecimal& rhs) {
@@ -164,7 +164,7 @@ struct BaseDatadecimal {
         count += rhs.count;
     }
 
-    void add(const IColumn** columns, size_t row_num) {
+    virtual void add(const IColumn** columns, size_t row_num) {
         DecimalV2Value source_data = DecimalV2Value();
         const auto& sources = static_cast<const 
ColumnDecimal<Decimal128>&>(*columns[0]);
         source_data = (DecimalV2Value)sources.get_data()[row_num];
@@ -191,14 +191,12 @@ struct PopData : Data {
     using ColVecResult = std::conditional_t<IsDecimalNumber<T>, 
ColumnDecimal<Decimal128>,
                                             ColumnVector<Float64>>;
     void insert_result_into(IColumn& to) const {
-        ColumnNullable& nullable_column = assert_cast<ColumnNullable&>(to);
-        auto& col = 
static_cast<ColVecResult&>(nullable_column.get_nested_column());
+        auto& col = assert_cast<ColVecResult&>(to);
         if constexpr (IsDecimalNumber<T>) {
             col.get_data().push_back(this->get_pop_result().value());
         } else {
             col.get_data().push_back(this->get_pop_result());
         }
-        nullable_column.get_null_map_data().push_back(0);
     }
 };
 
@@ -220,6 +218,24 @@ struct SampData : Data {
             nullable_column.get_null_map_data().push_back(0);
         }
     }
+
+    static const DataTypePtr get_return_type() {
+        return make_nullable(Data::get_return_type());
+    }
+
+    void add(const IColumn** columns, size_t row_num) override {
+        if (columns[0]->is_nullable()) {
+            const auto& nullable_column = assert_cast<const 
ColumnNullable&>(*columns[0]);
+            if (!nullable_column.is_null_at(row_num)) {
+                const IColumn* new_columns[1];
+                new_columns[0] = &nullable_column.get_nested_column();
+                Data::add(new_columns, row_num);
+            }
+        } else {
+            Data::add(columns, row_num);
+        }
+    }
+
 };
 
 template <typename Data>
@@ -252,8 +268,6 @@ public:
 
     String get_name() const override { return Data::name(); }
 
-    bool insert_to_null_default() const override { return false; }
-
     DataTypePtr get_return_type() const override { return 
Data::get_return_type(); }
 
     void add(AggregateDataPtr __restrict place, const IColumn** columns, 
size_t row_num,
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java 
b/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java
index da4d201..da8dc10 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java
@@ -51,6 +51,9 @@ public class AggregateFunction extends Function {
     public static ImmutableSet<String> 
NOT_NULLABLE_AGGREGATE_FUNCTION_NAME_SET =
             ImmutableSet.of("row_number", "rank", "dense_rank", 
"hll_union_agg", "hll_union", "bitmap_union", "bitmap_intersect", 
FunctionSet.COUNT, "ndv", FunctionSet.BITMAP_UNION_INT, 
FunctionSet.BITMAP_UNION_COUNT, "ndv_no_finalize");
 
+    public static ImmutableSet<String> 
ALWAYS_NULLABLE_AGGREGATE_FUNCTION_NAME_SET =
+            ImmutableSet.of("stddev_samp", "variance_samp", "var_samp");
+
     // Set if different from retType_, null otherwise.
     private Type intermediateType;
 
@@ -150,7 +153,8 @@ public class AggregateFunction extends Function {
                              String removeFnSymbol, String finalizeFnSymbol, 
boolean vectorized) {
         // only `count` is always not nullable, other aggregate function is 
always nullable
         super(fnName, argTypes, retType, hasVarArgs, vectorized,
-                
AggregateFunction.NOT_NULLABLE_AGGREGATE_FUNCTION_NAME_SET.contains(fnName.getFunction())
 ? NullableMode.ALWAYS_NOT_NULLABLE : NullableMode.DEPEND_ON_ARGUMENT);
+                
AggregateFunction.NOT_NULLABLE_AGGREGATE_FUNCTION_NAME_SET.contains(fnName.getFunction())
 ? NullableMode.ALWAYS_NOT_NULLABLE :
+                
AggregateFunction.ALWAYS_NULLABLE_AGGREGATE_FUNCTION_NAME_SET.contains(fnName.getFunction())
 ? NullableMode.ALWAYS_NULLABLE : NullableMode.DEPEND_ON_ARGUMENT);
         setLocation(location);
         this.intermediateType = (intermediateType.equals(retType)) ? null : 
intermediateType;
         this.updateFnSymbol = updateFnSymbol;
diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java 
b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java
index 9d49cc5..afe7983 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java
@@ -1665,6 +1665,7 @@ public class 
FunctionSet<min_initIN9doris_udf12DecimalV2ValEEEvPNS2_15FunctionCo
                         null,
                         prefix + STDDEV_POP_FINALIZE_SYMBOL.get(t),
                         false, false, false));
+
                 //vec stddev stddev_samp stddev_pop
                 addBuiltin(AggregateFunction.createBuiltin("stddev",
                         Lists.newArrayList(t), STDDEV_RETTYPE_SYMBOL.get(t), t,

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to