This is an automated email from the ASF dual-hosted git repository.

gabriellee pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new f8361470bbd [opt](function) Optimize the trim function for single-char inputs (#36497)
f8361470bbd is described below

commit f8361470bbd87148a46cdafc3b4834dae50df71b
Author: Mryange <59914473+mrya...@users.noreply.github.com>
AuthorDate: Fri Jun 28 15:57:54 2024 +0800

    [opt](function) Optimize the trim function for single-char inputs (#36497)
    
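    When the characters to remove are a single character, trim/ltrim/rtrim now
    take a dedicated fast path. Where AVX2 is available, it compares 32 bytes at
    a time, then finishes with a byte-by-byte scan. Multi-character patterns keep
    the memcmp-based loop. Trimmed results are appended directly into a result
    buffer that is reserved up front.
    
    Before:
    ```
    mysql [test]>select count(ltrim(str,"1")) from stringDb2;
    +------------------------+
    | count(ltrim(str, '1')) |
    +------------------------+
    |               64000000 |
    +------------------------+
    1 row in set (7.79 sec)
    ```
    
    After (about 10x faster):
    ```
    mysql [test]>select count(ltrim(str,"1")) from stringDb2;
    +------------------------+
    | count(ltrim(str, '1')) |
    +------------------------+
    |               64000000 |
    +------------------------+
    1 row in set (0.73 sec)
    ```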
---
 be/src/util/simd/vstring_function.h                | 196 ++++++---------------
 be/src/vec/functions/function_string.cpp           |  54 +++---
 .../correctness/test_trim_new_parameters.groovy    |   3 +
 3 files changed, 92 insertions(+), 161 deletions(-)

diff --git a/be/src/util/simd/vstring_function.h b/be/src/util/simd/vstring_function.h
index dac964b1b94..4fff59a01df 100644
--- a/be/src/util/simd/vstring_function.h
+++ b/be/src/util/simd/vstring_function.h
@@ -17,6 +17,7 @@
 
 #pragma once
 
+#include <immintrin.h>
 #include <unistd.h>
 
 #include <array>
@@ -100,169 +101,86 @@ public:
     /// n equals to 16 chars length
     static constexpr auto REGISTER_SIZE = sizeof(__m128i);
 #endif
-public:
-    static StringRef rtrim(const StringRef& str) {
-        if (str.size == 0) {
-            return str;
-        }
-        auto begin = 0;
-        int64_t end = str.size - 1;
-#if defined(__SSE2__) || defined(__aarch64__)
-        char blank = ' ';
-        const auto pattern = _mm_set1_epi8(blank);
-        while (end - begin + 1 >= REGISTER_SIZE) {
-            const auto v_haystack = _mm_loadu_si128(
-                    reinterpret_cast<const __m128i*>(str.data + end + 1 - 
REGISTER_SIZE));
-            const auto v_against_pattern = _mm_cmpeq_epi8(v_haystack, pattern);
-            const auto mask = _mm_movemask_epi8(v_against_pattern);
-            int offset = __builtin_clz(~(mask << REGISTER_SIZE));
-            /// means not found
-            if (offset == 0) {
-                return StringRef(str.data + begin, end - begin + 1);
-            } else {
-                end -= offset;
-            }
-        }
-#endif
-        while (end >= begin && str.data[end] == ' ') {
-            --end;
-        }
-        if (end < 0) {
-            return StringRef("");
-        }
-        return StringRef(str.data + begin, end - begin + 1);
-    }
-
-    static StringRef ltrim(const StringRef& str) {
-        if (str.size == 0) {
-            return str;
-        }
-        auto begin = 0;
-        auto end = str.size - 1;
-#if defined(__SSE2__) || defined(__aarch64__)
-        char blank = ' ';
-        const auto pattern = _mm_set1_epi8(blank);
-        while (end - begin + 1 >= REGISTER_SIZE) {
-            const auto v_haystack =
-                    _mm_loadu_si128(reinterpret_cast<const __m128i*>(str.data 
+ begin));
-            const auto v_against_pattern = _mm_cmpeq_epi8(v_haystack, pattern);
-            const auto mask = _mm_movemask_epi8(v_against_pattern) ^ 0xffff;
-            /// zero means not found
-            if (mask == 0) {
-                begin += REGISTER_SIZE;
-            } else {
-                const auto offset = __builtin_ctz(mask);
-                begin += offset;
-                return StringRef(str.data + begin, end - begin + 1);
-            }
-        }
-#endif
-        while (begin <= end && str.data[begin] == ' ') {
-            ++begin;
-        }
-        return StringRef(str.data + begin, end - begin + 1);
-    }
 
-    static StringRef trim(const StringRef& str) {
-        if (str.size == 0) {
-            return str;
+    template <bool trim_single>
+    static inline const unsigned char* rtrim(const unsigned char* begin, const 
unsigned char* end,
+                                             const StringRef& remove_str) {
+        if (remove_str.size == 0) {
+            return end;
         }
-        return rtrim(ltrim(str));
-    }
+        const auto* p = end;
 
-    static StringRef rtrim(const StringRef& str, const StringRef& rhs) {
-        if (str.size == 0 || rhs.size == 0) {
-            return str;
-        }
-        if (rhs.size == 1) {
-            auto begin = 0;
-            int64_t end = str.size - 1;
-            const char blank = rhs.data[0];
-#if defined(__SSE2__) || defined(__aarch64__)
-            const auto pattern = _mm_set1_epi8(blank);
-            while (end - begin + 1 >= REGISTER_SIZE) {
-                const auto v_haystack = _mm_loadu_si128(
-                        reinterpret_cast<const __m128i*>(str.data + end + 1 - 
REGISTER_SIZE));
-                const auto v_against_pattern = _mm_cmpeq_epi8(v_haystack, 
pattern);
-                const auto mask = _mm_movemask_epi8(v_against_pattern);
-                int offset = __builtin_clz(~(mask << REGISTER_SIZE));
-                /// means not found
-                if (offset == 0) {
-                    return StringRef(str.data + begin, end - begin + 1);
-                } else {
-                    end -= offset;
+        if constexpr (trim_single) {
+            const auto ch = remove_str.data[0];
+#if defined(__AVX2__) || defined(__aarch64__)
+            constexpr auto AVX2_BYTES = sizeof(__m256i);
+            const auto size = end - begin;
+            const auto* const avx2_begin = end - size / AVX2_BYTES * 
AVX2_BYTES;
+            const auto spaces = _mm256_set1_epi8(ch);
+            for (p = end - AVX2_BYTES; p >= avx2_begin; p -= AVX2_BYTES) {
+                uint32_t masks = _mm256_movemask_epi8(
+                        _mm256_cmpeq_epi8(_mm256_loadu_si256((__m256i*)p), 
spaces));
+                if ((~masks)) {
+                    break;
                 }
             }
+            p += AVX2_BYTES;
 #endif
-            while (end >= begin && str.data[end] == blank) {
-                --end;
-            }
-            if (end < 0) {
-                return StringRef("");
+            for (; (p - 1) >= begin && *(p - 1) == ch; p--) {
             }
-            return StringRef(str.data + begin, end - begin + 1);
+            return p;
         }
-        auto begin = 0;
-        auto end = str.size - 1;
-        const auto rhs_size = rhs.size;
-        while (end - begin + 1 >= rhs_size) {
-            if (memcmp(str.data + end - rhs_size + 1, rhs.data, rhs_size) == 
0) {
-                end -= rhs.size;
+
+        const auto remove_size = remove_str.size;
+        const auto* const remove_data = remove_str.data;
+        while (p - begin >= remove_size) {
+            if (memcmp(p - remove_size, remove_data, remove_size) == 0) {
+                p -= remove_str.size;
             } else {
                 break;
             }
         }
-        return StringRef(str.data + begin, end - begin + 1);
+        return p;
     }
 
-    static StringRef ltrim(const StringRef& str, const StringRef& rhs) {
-        if (str.size == 0 || rhs.size == 0) {
-            return str;
+    template <bool trim_single>
+    static inline const unsigned char* ltrim(const unsigned char* begin, const 
unsigned char* end,
+                                             const StringRef& remove_str) {
+        if (remove_str.size == 0) {
+            return begin;
         }
-        if (str.size == 1) {
-            auto begin = 0;
-            auto end = str.size - 1;
-            const char blank = rhs.data[0];
-#if defined(__SSE2__) || defined(__aarch64__)
-            const auto pattern = _mm_set1_epi8(blank);
-            while (end - begin + 1 >= REGISTER_SIZE) {
-                const auto v_haystack =
-                        _mm_loadu_si128(reinterpret_cast<const 
__m128i*>(str.data + begin));
-                const auto v_against_pattern = _mm_cmpeq_epi8(v_haystack, 
pattern);
-                const auto mask = _mm_movemask_epi8(v_against_pattern) ^ 
0xffff;
-                /// zero means not found
-                if (mask == 0) {
-                    begin += REGISTER_SIZE;
-                } else {
-                    const auto offset = __builtin_ctz(mask);
-                    begin += offset;
-                    return StringRef(str.data + begin, end - begin + 1);
+        const auto* p = begin;
+
+        if constexpr (trim_single) {
+            const auto ch = remove_str.data[0];
+#if defined(__AVX2__) || defined(__aarch64__)
+            constexpr auto AVX2_BYTES = sizeof(__m256i);
+            const auto size = end - begin;
+            const auto* const avx2_end = begin + size / AVX2_BYTES * 
AVX2_BYTES;
+            const auto spaces = _mm256_set1_epi8(ch);
+            for (; p < avx2_end; p += AVX2_BYTES) {
+                uint32_t masks = _mm256_movemask_epi8(
+                        _mm256_cmpeq_epi8(_mm256_loadu_si256((__m256i*)p), 
spaces));
+                if ((~masks)) {
+                    break;
                 }
             }
 #endif
-            while (begin <= end && str.data[begin] == blank) {
-                ++begin;
+            for (; p < end && *p == ch; ++p) {
             }
-            return StringRef(str.data + begin, end - begin + 1);
+            return p;
         }
-        auto begin = 0;
-        auto end = str.size - 1;
-        const auto rhs_size = rhs.size;
-        while (end - begin + 1 >= rhs_size) {
-            if (memcmp(str.data + begin, rhs.data, rhs_size) == 0) {
-                begin += rhs.size;
+
+        const auto remove_size = remove_str.size;
+        const auto* const remove_data = remove_str.data;
+        while (end - p >= remove_size) {
+            if (memcmp(p, remove_data, remove_size) == 0) {
+                p += remove_str.size;
             } else {
                 break;
             }
         }
-        return StringRef(str.data + begin, end - begin + 1);
-    }
-
-    static StringRef trim(const StringRef& str, const StringRef& rhs) {
-        if (str.size == 0 || rhs.size == 0) {
-            return str;
-        }
-        return rtrim(ltrim(str, rhs), rhs);
+        return p;
     }
 
     // Gcc will do auto simd in this function
diff --git a/be/src/vec/functions/function_string.cpp b/be/src/vec/functions/function_string.cpp
index d4dae54612c..9216ad1b9c8 100644
--- a/be/src/vec/functions/function_string.cpp
+++ b/be/src/vec/functions/function_string.cpp
@@ -485,25 +485,29 @@ struct NameLTrim {
 struct NameRTrim {
     static constexpr auto name = "rtrim";
 };
-template <bool is_ltrim, bool is_rtrim>
+template <bool is_ltrim, bool is_rtrim, bool trim_single>
 struct TrimUtil {
     static Status vector(const ColumnString::Chars& str_data,
-                         const ColumnString::Offsets& str_offsets, const 
StringRef& rhs,
+                         const ColumnString::Offsets& str_offsets, const 
StringRef& remove_str,
                          ColumnString::Chars& res_data, ColumnString::Offsets& 
res_offsets) {
-        size_t offset_size = str_offsets.size();
-        res_offsets.resize(str_offsets.size());
+        const size_t offset_size = str_offsets.size();
+        res_offsets.resize(offset_size);
+        res_data.reserve(str_data.size());
         for (size_t i = 0; i < offset_size; ++i) {
-            const char* raw_str = reinterpret_cast<const 
char*>(&str_data[str_offsets[i - 1]]);
-            ColumnString::Offset size = str_offsets[i] - str_offsets[i - 1];
-            StringRef str(raw_str, size);
+            const auto* str_begin = str_data.data() + str_offsets[i - 1];
+            const auto* str_end = str_data.data() + str_offsets[i];
+
             if constexpr (is_ltrim) {
-                str = simd::VStringFunctions::ltrim(str, rhs);
+                str_begin =
+                        simd::VStringFunctions::ltrim<trim_single>(str_begin, 
str_end, remove_str);
             }
             if constexpr (is_rtrim) {
-                str = simd::VStringFunctions::rtrim(str, rhs);
+                str_end =
+                        simd::VStringFunctions::rtrim<trim_single>(str_begin, 
str_end, remove_str);
             }
-            StringOP::push_value_string(std::string_view((char*)str.data, 
str.size), i, res_data,
-                                        res_offsets);
+
+            res_data.insert_assume_reserved(str_begin, str_end);
+            res_offsets[i] = res_data.size();
         }
         return Status::OK();
     }
@@ -521,9 +525,9 @@ struct Trim1Impl {
         if (const auto* col = assert_cast<const ColumnString*>(column.get())) {
             auto col_res = ColumnString::create();
             char blank[] = " ";
-            StringRef rhs(blank, 1);
-            RETURN_IF_ERROR((TrimUtil<is_ltrim, is_rtrim>::vector(
-                    col->get_chars(), col->get_offsets(), rhs, 
col_res->get_chars(),
+            const StringRef remove_str(blank, 1);
+            RETURN_IF_ERROR((TrimUtil<is_ltrim, is_rtrim, true>::vector(
+                    col->get_chars(), col->get_offsets(), remove_str, 
col_res->get_chars(),
                     col_res->get_offsets())));
             block.replace_by_position(result, std::move(col_res));
         } else {
@@ -550,15 +554,21 @@ struct Trim2Impl {
         const auto& rcol =
                 assert_cast<const 
ColumnConst*>(block.get_by_position(arguments[1]).column.get())
                         ->get_data_column_ptr();
-        if (auto col = assert_cast<const ColumnString*>(column.get())) {
-            if (auto col_right = assert_cast<const ColumnString*>(rcol.get())) 
{
+        if (const auto* col = assert_cast<const ColumnString*>(column.get())) {
+            if (const auto* col_right = assert_cast<const 
ColumnString*>(rcol.get())) {
                 auto col_res = ColumnString::create();
-                const char* raw_rhs = reinterpret_cast<const 
char*>(&(col_right->get_chars()[0]));
-                ColumnString::Offset rhs_size = col_right->get_offsets()[0];
-                StringRef rhs(raw_rhs, rhs_size);
-                RETURN_IF_ERROR((TrimUtil<is_ltrim, is_rtrim>::vector(
-                        col->get_chars(), col->get_offsets(), rhs, 
col_res->get_chars(),
-                        col_res->get_offsets())));
+                const auto* remove_str_raw = col_right->get_chars().data();
+                const ColumnString::Offset remove_str_size = 
col_right->get_offsets()[0];
+                const StringRef remove_str(remove_str_raw, remove_str_size);
+                if (remove_str.size == 1) {
+                    RETURN_IF_ERROR((TrimUtil<is_ltrim, is_rtrim, 
true>::vector(
+                            col->get_chars(), col->get_offsets(), remove_str, 
col_res->get_chars(),
+                            col_res->get_offsets())));
+                } else {
+                    RETURN_IF_ERROR((TrimUtil<is_ltrim, is_rtrim, 
false>::vector(
+                            col->get_chars(), col->get_offsets(), remove_str, 
col_res->get_chars(),
+                            col_res->get_offsets())));
+                }
                 block.replace_by_position(result, std::move(col_res));
             } else {
                 return Status::RuntimeError("Illegal column {} of argument of 
function {}",
diff --git a/regression-test/suites/correctness/test_trim_new_parameters.groovy b/regression-test/suites/correctness/test_trim_new_parameters.groovy
index 3209eb7aae7..17ac4a0c65e 100644
--- a/regression-test/suites/correctness/test_trim_new_parameters.groovy
+++ b/regression-test/suites/correctness/test_trim_new_parameters.groovy
@@ -67,4 +67,7 @@ suite("test_trim_new_parameters") {
 
     rtrim = sql "select rtrim('bcTTTabcabc','abc')"
     assertEquals(rtrim[0][0], 'bcTTT')   
+
+    trim_one = sql "select 
trim('aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaabaaaaaaaaaaabcTTTabcabcaaaaaaaaaaaaaaaaaaaaaaaaaabaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa','a')"
+    assertEquals(trim_one[0][0], 
'baaaaaaaaaaabcTTTabcabcaaaaaaaaaaaaaaaaaaaaaaaaaab')  
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to