This is an automated email from the ASF dual-hosted git repository.

lihaopeng pushed a commit to branch vectorized
in repository https://gitbox.apache.org/repos/asf/incubator-doris.git

commit f3ce1cba69aa27f8ae0af63451bab41d5048c367
Author: HappenLee <happen...@hotmail.com>
AuthorDate: Wed Jan 12 13:20:48 2022 +0800

    [Vectorized][Enhancement] use SIMD to speed up coalesce and is_not_null function (#7722)
    
    Co-authored-by: lihaopeng <lihaop...@baidu.com>
---
 be/src/vec/functions/function_coalesce.cpp     | 210 ++++++++++++++++++++-----
 be/src/vec/functions/is_not_null.cpp           |   4 +-
 be/src/vec/functions/simple_function_factory.h |   5 +-
 be/test/vec/function/function_string_test.cpp  |  42 +++++
 4 files changed, 214 insertions(+), 47 deletions(-)

diff --git a/be/src/vec/functions/function_coalesce.cpp 
b/be/src/vec/functions/function_coalesce.cpp
index 65d544c..99b6110 100644
--- a/be/src/vec/functions/function_coalesce.cpp
+++ b/be/src/vec/functions/function_coalesce.cpp
@@ -28,6 +28,8 @@ class FunctionCoalesce : public IFunction {
 public:
     static constexpr auto name = "coalesce";
 
+    mutable FunctionBasePtr func_is_not_null;
+
     static FunctionPtr create() { return std::make_shared<FunctionCoalesce>(); 
}
 
     String get_name() const override { return name; }
@@ -41,47 +43,70 @@ public:
     size_t get_number_of_arguments() const override { return 0; }
 
     DataTypePtr get_return_type_impl(const DataTypes& arguments) const 
override {
+        DataTypePtr res; // type of the first non-nullable argument, if any
         for (const auto& arg : arguments) {
             if (!arg->is_nullable()) {
-                return arg;
+                res = arg;
+                break;
             }
         }
-        return arguments[0];
+        // If no argument is non-nullable, fall back to the first argument's type.
+        res = res ? res : arguments[0];
+        // Cache is_not_null_pred for execute_impl, which calls func_is_not_null unchecked.
+        const ColumnsWithTypeAndName is_not_null_col{
+                {nullptr, make_nullable(res), ""}
+        };
+        func_is_not_null = SimpleFunctionFactory::instance().
+                get_function("is_not_null_pred", is_not_null_col, 
std::make_shared<DataTypeUInt8>());
+        // NOTE(review): unsynchronized write from a const method; confirm no concurrent callers.
+        return res;
     }
 
     Status execute_impl(FunctionContext* context, Block& block, const 
ColumnNumbers& arguments,
                         size_t result, size_t input_rows_count) override {
         DCHECK_GE(arguments.size(), 1);
+        DataTypePtr result_type = block.get_by_position(result).type;
         ColumnNumbers filtered_args;
         filtered_args.reserve(arguments.size());
-        for (const auto& arg : arguments) {
-            const auto& type = block.get_by_position(arg).type;
-            if (type->only_null()) {
-                continue;
-            }
-            filtered_args.push_back(arg);
-            if (!type->is_nullable()) {
-                break;
+
+        for (size_t i = 0; i < arguments.size(); ++i) {
+            const auto& arg_type = block.get_by_position(arguments[i]).type;
+            filtered_args.push_back(arguments[i]);
+            if (!arg_type->is_nullable()) {
+                if (i == 0) { //if the first column not null, return it's 
directly
+                    block.get_by_position(result).column = 
block.get_by_position(arguments[0]).column;
+                    return Status::OK();
+                } else {
+                    break;
+                }
             }
         }
 
         size_t remaining_rows = input_rows_count;
         size_t argument_size = filtered_args.size();
-        std::vector<int> record_idx(input_rows_count, -1); //used to save 
column idx
+        std::vector<uint32_t> record_idx(input_rows_count, 0); //used to save 
column idx, record the result data of each row from which column
+        std::vector<uint8_t> filled_flags(input_rows_count, 0); //used to save 
filled flag, in order to check current row whether have filled data
+
         MutableColumnPtr result_column;
+        if (!result_type->is_nullable()) {
+            result_column = result_type->create_column();
+        } else {
+            result_column = remove_nullable(result_type)->create_column();
+        }
 
-        DataTypePtr type = block.get_by_position(result).type;
-        if (!type->is_nullable()) {
-            result_column = type->create_column();
+        // because now the string types does not support random position 
writing,
+        // so insert into result data have two methods, one is for string 
types, one is for others type remaining
+        bool is_string_result = result_column->is_column_string();
+        if (is_string_result) {
+            result_column->reserve(input_rows_count);
         } else {
-            result_column = remove_nullable(type)->create_column();
+            result_column->resize(input_rows_count);
         }
 
-        result_column->reserve(input_rows_count);
         auto return_type = std::make_shared<DataTypeUInt8>();
-        auto null_map = ColumnUInt8::create(input_rows_count, 1);
-        auto& null_map_data = null_map->get_data();
-        ColumnPtr argument_columns[argument_size];
+        auto null_map = ColumnUInt8::create(input_rows_count, 1);  //if 
null_map_data==1, the current row should be null
+        auto* __restrict null_map_data = null_map->get_data().data();
+        ColumnPtr argument_columns[argument_size]; //use to save nested_column 
if is nullable column
 
         for (size_t i = 0; i < argument_size; ++i) {
             block.get_by_position(filtered_args[i]).column =
@@ -93,40 +118,69 @@ public:
             }
         }
 
+        Block temporary_block {
+            ColumnsWithTypeAndName {
+                    block.get_by_position(filtered_args[0]),
+                    {nullptr, std::make_shared<DataTypeUInt8>(), ""}
+            }
+        };
+
         for (size_t i = 0; i < argument_size && remaining_rows; ++i) {
-            const ColumnsWithTypeAndName is_not_null_col 
{block.get_by_position(filtered_args[i])};
-            Block temporary_block(is_not_null_col);
-            temporary_block.insert(ColumnWithTypeAndName {nullptr, 
return_type, ""});
-            auto func_is_not_null = 
SimpleFunctionFactory::instance().get_function("is_not_null_pred", 
is_not_null_col, return_type);
+            temporary_block.get_by_position(0).column = 
block.get_by_position(filtered_args[i]).column;
             func_is_not_null->execute(context, temporary_block, {0}, 1, 
input_rows_count);
 
-            auto res_column = 
std::move(*temporary_block.get_by_position(1).column->convert_to_full_column_if_const()).mutate();
+            auto res_column = 
(*temporary_block.get_by_position(1).column->convert_to_full_column_if_const()).mutate();
             auto& res_map = 
assert_cast<ColumnVector<UInt8>*>(res_column.get())->get_data();
-            auto* res = res_map.data();
-
-            //TODO: if need to imporve performance in the feature, here it's 
could SIMD
-            for (size_t j = 0; j < input_rows_count && remaining_rows; ++j) {
-                if (res[j] && null_map_data[j]) {
-                    null_map_data[j] = 0;
-                    remaining_rows--;
-                    record_idx[j] = i;
+            auto* __restrict res = res_map.data();
+
+            // Here it's SIMD thought the compiler automatically
+            // true: res[j]==1 && null_map_data[j]==1, false: others
+            // if true: remaining_rows--; record_idx[j]=column_idx; 
null_map_data[j]=0, so the current row could fill result
+            for (size_t j = 0; j < input_rows_count; ++j) {
+                remaining_rows -= (res[j] & null_map_data[j]);
+                record_idx[j] += (res[j] & null_map_data[j]) * i;
+                null_map_data[j] -= (res[j] & null_map_data[j]);
+            }
+
+            if (remaining_rows == 0) {
+                //check whether all result data from the same column
+                size_t is_same_column_count = 0;
+                const auto data = record_idx[0];
+                for (size_t row = 0; row < input_rows_count; ++row) {
+                    is_same_column_count += (record_idx[row] == data);
+                }
+
+                if (is_same_column_count == input_rows_count) {
+                    if (result_type->is_nullable()) {
+                        block.get_by_position(result).column = 
make_nullable(argument_columns[i], false);
+                    } else {
+                        block.get_by_position(result).column = 
argument_columns[i];
+                    }
+                    return Status::OK();
                 }
             }
+
+            if (!is_string_result) {
+                //if not string type, could check one column firstly,
+                //and then fill the not null value in result column,
+                //this method may result in higher CPU cache
+                filled_result_column(result_type, result_column, 
argument_columns[i], null_map_data,
+                                     filled_flags.data(), input_rows_count);
+            }
         }
-        //TODO: According to the record results, fill in result one by one, 
-        //that's fill in result use different methods for different types,
-        //because now the string types does not support random position 
writing,
-        //But other type could, so could check one column, and fill the not 
null value in result column,
-        //and then next column, this method may result in higher CPU cache
-        for (int row = 0; row < input_rows_count; ++row) {
-            if (record_idx[row] == -1) {
-                result_column->insert_default();
-            } else {
-                
result_column->insert_from(*argument_columns[record_idx[row]].get(), row);
+
+        if (is_string_result) {
+            //if string type,  should according to the record results, fill in 
result one by one, 
+            for (size_t row = 0; row < input_rows_count; ++row) {
+                if (null_map_data[row]) { //should be null
+                    result_column->insert_default();
+                } else {
+                    
result_column->insert_from(*argument_columns[record_idx[row]].get(), row);
+                }
             }
         }
 
-        if (type->is_nullable()) {
+        if (result_type->is_nullable()) {
             block.replace_by_position(result, 
ColumnNullable::create(std::move(result_column), std::move(null_map)));
         } else {
             block.replace_by_position(result, std::move(result_column));
@@ -134,6 +188,76 @@ public:
 
         return Status::OK();
     }
+
+    template <typename ColumnType>
+    Status insert_result_data(MutableColumnPtr& result_column, ColumnPtr& argument_column,
+                              const UInt8* __restrict null_map_data, UInt8* __restrict filled_flag,
+                              const size_t input_rows_count) {
+        // For every row with null_map_data[row]==0 and filled_flag[row]==0, copy the value of
+        // argument_column into result_column and set filled_flag[row]. Always returns OK;
+        // ColumnType must be the concrete column type of both result_column and argument_column.
+        auto* __restrict result_raw_data =
+                assert_cast<ColumnType*>(result_column.get())->get_data().data();
+        auto* __restrict column_raw_data =
+                assert_cast<const ColumnType*>(argument_column.get())->get_data().data();
+
+        // Branch-free select instead of `+=`: no reliance on resize() zero-filling, no Decimal math.
+        for (size_t row = 0; row < input_rows_count; ++row) {
+            const UInt8 fill = !(null_map_data[row] | filled_flag[row]);
+            result_raw_data[row] = fill ? column_raw_data[row] : result_raw_data[row];
+            filled_flag[row] |= fill;
+        }
+        return Status::OK();
+    }
+
+    // Copies argument_column values into result_column via insert_result_data, dispatching on
+    // the non-nullable part of data_type. Any other type returns NotSupported and copies nothing.
+    // NOTE(review): execute_impl currently discards this Status, so unsupported types go unnoticed.
+    // TODO: this function is same as case when, should be replaced by macro
+    Status filled_result_column(const DataTypePtr& data_type, MutableColumnPtr& result_column,
+                                ColumnPtr& argument_column, UInt8* __restrict null_map_data,
+                                UInt8* __restrict filled_flag, const size_t input_rows_count) {
+        WhichDataType which(remove_nullable(data_type));
+        if (which.is_uint8()) {
+            return insert_result_data<ColumnUInt8>(result_column, argument_column, null_map_data,
+                                                   filled_flag, input_rows_count);
+        } else if (which.is_uint16()) {
+            return insert_result_data<ColumnUInt16>(result_column, argument_column, null_map_data,
+                                                    filled_flag, input_rows_count);
+        } else if (which.is_uint32()) {
+            return insert_result_data<ColumnUInt32>(result_column, argument_column, null_map_data,
+                                                    filled_flag, input_rows_count);
+        } else if (which.is_uint64()) {
+            return insert_result_data<ColumnUInt64>(result_column, argument_column, null_map_data,
+                                                    filled_flag, input_rows_count);
+        } else if (which.is_int8()) {
+            return insert_result_data<ColumnInt8>(result_column, argument_column, null_map_data,
+                                                  filled_flag, input_rows_count);
+        } else if (which.is_int16()) {
+            return insert_result_data<ColumnInt16>(result_column, argument_column, null_map_data,
+                                                   filled_flag, input_rows_count);
+        } else if (which.is_int32()) {
+            return insert_result_data<ColumnInt32>(result_column, argument_column, null_map_data,
+                                                   filled_flag, input_rows_count);
+        } else if (which.is_int64()) {
+            return insert_result_data<ColumnInt64>(result_column, argument_column, null_map_data,
+                                                   filled_flag, input_rows_count);
+        } else if (which.is_date_or_datetime()) {
+            return insert_result_data<ColumnVector<DateTime>>(
+                    result_column, argument_column, null_map_data, filled_flag, input_rows_count);
+        } else if (which.is_float32()) {
+            return insert_result_data<ColumnFloat32>(result_column, argument_column, null_map_data,
+                                                     filled_flag, input_rows_count);
+        } else if (which.is_float64()) {
+            return insert_result_data<ColumnFloat64>(result_column, argument_column, null_map_data,
+                                                     filled_flag, input_rows_count);
+        } else if (which.is_decimal()) {
+            return insert_result_data<ColumnDecimal<Decimal128>>(
+                    result_column, argument_column, null_map_data, filled_flag, input_rows_count);
+        } else {
+            return Status::NotSupported(fmt::format("Unexpected type {} of argument of function {}",
+                                                    data_type->get_name(), get_name()));
+        }
+    }
 };
 
 void register_function_coalesce(SimpleFunctionFactory& factory) {
diff --git a/be/src/vec/functions/is_not_null.cpp 
b/be/src/vec/functions/is_not_null.cpp
index 1bfeaa6..346dac1 100644
--- a/be/src/vec/functions/is_not_null.cpp
+++ b/be/src/vec/functions/is_not_null.cpp
@@ -52,8 +52,8 @@ public:
         if (auto* nullable = 
check_and_get_column<ColumnNullable>(*elem.column)) {
             /// Return the negated null map.
             auto res_column = ColumnUInt8::create(input_rows_count);
-            const auto& src_data = nullable->get_null_map_data();
-            auto& res_data = assert_cast<ColumnUInt8&>(*res_column).get_data();
+            const auto* __restrict src_data = 
nullable->get_null_map_data().data();
+            auto* __restrict res_data = 
assert_cast<ColumnUInt8&>(*res_column).get_data().data();
 
             for (size_t i = 0; i < input_rows_count; ++i) {
                 res_data[i] = !src_data[i];
diff --git a/be/src/vec/functions/simple_function_factory.h 
b/be/src/vec/functions/simple_function_factory.h
index 1d064bb..d757920 100644
--- a/be/src/vec/functions/simple_function_factory.h
+++ b/be/src/vec/functions/simple_function_factory.h
@@ -65,10 +65,11 @@ void register_function_like(SimpleFunctionFactory& factory);
 void register_function_regexp(SimpleFunctionFactory& factory);
 void register_function_random(SimpleFunctionFactory& factory);
 void register_function_coalesce(SimpleFunctionFactory& factory);
+
 class SimpleFunctionFactory {
     using Creator = std::function<FunctionBuilderPtr()>;
-    using FunctionCreators = std::unordered_map<std::string, Creator>;
-    using FunctionIsVariadic = std::unordered_set<std::string>;
+    using FunctionCreators = phmap::flat_hash_map<std::string, Creator>;
+    using FunctionIsVariadic = phmap::flat_hash_set<std::string>;
 
 public:
     void register_function(const std::string& name, Creator ptr) {
diff --git a/be/test/vec/function/function_string_test.cpp 
b/be/test/vec/function/function_string_test.cpp
index bd2f9a7..9f1f5d7 100644
--- a/be/test/vec/function/function_string_test.cpp
+++ b/be/test/vec/function/function_string_test.cpp
@@ -681,6 +681,48 @@ TEST(function_string_test, function_unhex_test) {
     vectorized::check_function<vectorized::DataTypeString, true>(func_name, 
input_types, data_set);
 }
 
+TEST(function_string_test, function_coalesce_test) {
+    std::string func_name = "coalesce";
+    { // Int32: results come from different argument columns, plus an all-NULL row.
+        std::vector<std::any> input_types = {vectorized::TypeIndex::Int32,
+                                             vectorized::TypeIndex::Int32,
+                                             vectorized::TypeIndex::Int32};
+        DataSet data_set = {{{(int32_t)1, Null(), (int32_t)10}, {(int32_t)1}},
+                            {{Null(), (int32_t)2, (int32_t)20}, {(int32_t)2}},
+                            {{Null(), Null(), (int32_t)3}, {(int32_t)3}},
+                            {{Null(), Null(), Null()}, Null()}};
+        vectorized::check_function<vectorized::DataTypeInt32, true>(func_name, input_types,
+                                                                    data_set);
+    }
+
+    { // String: the first argument is non-NULL in every row.
+        std::vector<std::any> input_types = {vectorized::TypeIndex::String,
+                                             vectorized::TypeIndex::String,
+                                             vectorized::TypeIndex::Int32};
+        DataSet data_set = {
+                {{std::string("qwer"), Null(), (int32_t)1}, {std::string("qwer")}},
+                {{std::string("asdf"), Null(), (int32_t)2}, {std::string("asdf")}},
+                {{std::string("zxcv"), Null(), (int32_t)3}, {std::string("zxcv")}},
+                {{std::string("vbnm"), Null(), (int32_t)4}, {std::string("vbnm")}},
+        };
+        vectorized::check_function<vectorized::DataTypeString, true>(func_name, input_types,
+                                                                     data_set);
+    }
+
+    { // String: results come from the second or the third argument.
+        std::vector<std::any> input_types = {vectorized::TypeIndex::String,
+                                             vectorized::TypeIndex::String,
+                                             vectorized::TypeIndex::String};
+        DataSet data_set = {
+                {{Null(), std::string("abc"), std::string("hij")}, {std::string("abc")}},
+                {{Null(), std::string("def"), std::string("klm")}, {std::string("def")}},
+                {{Null(), std::string(""), std::string("xyz")}, {std::string("")}},
+                {{Null(), Null(), std::string("uvw")}, {std::string("uvw")}}};
+        vectorized::check_function<vectorized::DataTypeString, true>(func_name, input_types,
+                                                                     data_set);
+    }
+}
+
 } // namespace doris
 
 int main(int argc, char** argv) {

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to