This is an automated email from the ASF dual-hosted git repository.

weixiang pushed a commit to branch quantile_state_vec
in repository https://gitbox.apache.org/repos/asf/doris.git

commit b36fdd697c33302529320a5a9573f777181ad285
Author: spaces-x <weixian...@meituan.com>
AuthorDate: Sun Mar 12 22:50:44 2023 +0800

    Fix QuantileUnion handling of nullable input columns
---
 .../aggregate_function_quantile_state.cpp          | 11 +++++++---
 .../aggregate_function_quantile_state.h            | 25 ++++++++++++++++------
 2 files changed, 26 insertions(+), 10 deletions(-)

diff --git 
a/be/src/vec/aggregate_functions/aggregate_function_quantile_state.cpp 
b/be/src/vec/aggregate_functions/aggregate_function_quantile_state.cpp
index 28684b2230..4c8ec27296 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_quantile_state.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_quantile_state.cpp
@@ -24,9 +24,14 @@ namespace doris::vectorized {
 AggregateFunctionPtr create_aggregate_function_quantile_state_union(const 
std::string& name,
                                                                     const 
DataTypes& argument_types,
                                                                     const bool 
result_is_nullable) {
-    return std::make_shared<
-            
AggregateFunctionQuantileStateOp<AggregateFunctionQuantileStateUnionOp, 
double>>(
-            argument_types);
+    const bool arg_is_nullable = argument_types[0]->is_nullable();
+    if (arg_is_nullable) {
+        return std::make_shared<AggregateFunctionQuantileStateOp<
+                true, AggregateFunctionQuantileStateUnionOp, 
double>>(argument_types);
+    } else {
+        return std::make_shared<AggregateFunctionQuantileStateOp<
+                false, AggregateFunctionQuantileStateUnionOp, 
double>>(argument_types);
+    }
 }
 
 void 
register_aggregate_function_quantile_state(AggregateFunctionSimpleFactory& 
factory) {
diff --git a/be/src/vec/aggregate_functions/aggregate_function_quantile_state.h 
b/be/src/vec/aggregate_functions/aggregate_function_quantile_state.h
index 5ee1d603da..6b07f79648 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_quantile_state.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_quantile_state.h
@@ -83,10 +83,11 @@ struct AggregateFunctionQuantileStateData {
     DataType& get() { return value; }
 };
 
-template <typename Op, typename InternalType>
+template <bool arg_is_nullable, typename Op, typename InternalType>
 class AggregateFunctionQuantileStateOp final
-        : public 
IAggregateFunctionDataHelper<AggregateFunctionQuantileStateData<Op, 
InternalType>,
-                                              
AggregateFunctionQuantileStateOp<Op, InternalType>> {
+        : public IAggregateFunctionDataHelper<
+                  AggregateFunctionQuantileStateData<Op, InternalType>,
+                  AggregateFunctionQuantileStateOp<arg_is_nullable, Op, 
InternalType>> {
 public:
     using ResultDataType = QuantileState<InternalType>;
     using ColVecType = ColumnQuantileState<InternalType>;
@@ -95,8 +96,9 @@ public:
     String get_name() const override { return Op::name; }
 
     AggregateFunctionQuantileStateOp(const DataTypes& argument_types_)
-            : 
IAggregateFunctionDataHelper<AggregateFunctionQuantileStateData<Op, 
InternalType>,
-                                           
AggregateFunctionQuantileStateOp<Op, InternalType>>(
+            : IAggregateFunctionDataHelper<
+                      AggregateFunctionQuantileStateData<Op, InternalType>,
+                      AggregateFunctionQuantileStateOp<arg_is_nullable, Op, 
InternalType>>(
                       argument_types_) {}
 
     DataTypePtr get_return_type() const override {
@@ -105,8 +107,17 @@ public:
 
     void add(AggregateDataPtr __restrict place, const IColumn** columns, 
size_t row_num,
              Arena*) const override {
-        const auto& column = static_cast<const ColVecType&>(*columns[0]);
-        this->data(place).add(column.get_data()[row_num]);
+        if constexpr (arg_is_nullable) {
+            auto& nullable_column = assert_cast<const 
ColumnNullable&>(*columns[0]);
+            if (!nullable_column.is_null_at(row_num)) {
+                const auto& column =
+                        static_cast<const 
ColVecType&>(nullable_column.get_nested_column());
+                this->data(place).add(column.get_data()[row_num]);
+            }
+        } else {
+            const auto& column = static_cast<const ColVecType&>(*columns[0]);
+            this->data(place).add(column.get_data()[row_num]);
+        }
     }
 
     void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to