This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch branch-2.1
in repository https://gitbox.apache.org/repos/asf/doris.git

commit 48aaaa80058a86ce93eefed28803123dfb19a88b
Author: koarz <[email protected]>
AuthorDate: Sun Feb 4 22:19:30 2024 +0800

    [Enhancement](function) change function REPEAT nullable mode (#30743)
---
 be/src/vec/functions/function_string.cpp           |   1 +
 be/src/vec/functions/function_string.h             | 123 ++++++++++++++++++---
 .../trees/expressions/functions/scalar/Repeat.java |   4 +-
 gensrc/script/doris_builtins_functions.py          |   4 +-
 4 files changed, 112 insertions(+), 20 deletions(-)

diff --git a/be/src/vec/functions/function_string.cpp 
b/be/src/vec/functions/function_string.cpp
index 6965139a1c8..c5ce208d26f 100644
--- a/be/src/vec/functions/function_string.cpp
+++ b/be/src/vec/functions/function_string.cpp
@@ -1021,6 +1021,7 @@ void register_function_string(SimpleFunctionFactory& 
factory) {
     factory.register_alternative_function<FunctionLeftOld>();
     factory.register_alternative_function<FunctionRightOld>();
     factory.register_alternative_function<FunctionSubstringIndexOld>();
+    factory.register_alternative_function<FunctionStringRepeatOld>();
 
     factory.register_alias(FunctionLeft::name, "strleft");
     factory.register_alias(FunctionRight::name, "strright");
diff --git a/be/src/vec/functions/function_string.h 
b/be/src/vec/functions/function_string.h
index 6fc84074ddb..4794d28e0e0 100644
--- a/be/src/vec/functions/function_string.h
+++ b/be/src/vec/functions/function_string.h
@@ -1031,7 +1031,6 @@ public:
     DataTypePtr get_return_type_impl(const DataTypes& arguments) const 
override {
         return std::make_shared<DataTypeString>();
     }
-    bool use_default_implementation_for_nulls() const override { return true; }
 
     Status execute_impl(FunctionContext* context, Block& block, const 
ColumnNumbers& arguments,
                         size_t result, size_t input_rows_count) const override 
{
@@ -1051,7 +1050,7 @@ public:
         for (int i = 0; i < argument_size; ++i) {
             argument_columns[i] =
                     
block.get_by_position(arguments[i]).column->convert_to_full_column_if_const();
-            auto col_str = assert_cast<const 
ColumnString*>(argument_columns[i].get());
+            const auto* col_str = assert_cast<const 
ColumnString*>(argument_columns[i].get());
             offsets_list[i] = &col_str->get_offsets();
             chars_list[i] = &col_str->get_chars();
         }
@@ -1084,8 +1083,8 @@ public:
         for (size_t i = 0; i < input_rows_count; ++i) {
             int current_length = 0;
             for (size_t j = 0; j < offsets_list.size(); ++j) {
-                auto& current_offsets = *offsets_list[j];
-                auto& current_chars = *chars_list[j];
+                const auto& current_offsets = *offsets_list[j];
+                const auto& current_chars = *chars_list[j];
 
                 int size = current_offsets[i] - current_offsets[i - 1];
                 if (size > 0) {
@@ -1431,6 +1430,103 @@ public:
     String get_name() const override { return name; }
     size_t get_number_of_arguments() const override { return 2; }
 
+    DataTypePtr get_return_type_impl(const DataTypes& arguments) const 
override {
+        return std::make_shared<DataTypeString>();
+    }
+    Status execute_impl(FunctionContext* context, Block& block, const 
ColumnNumbers& arguments,
+                        size_t result, size_t input_rows_count) const override 
{
+        DCHECK_EQ(arguments.size(), 2);
+        auto res = ColumnString::create();
+
+        ColumnPtr argument_ptr[2];
+        argument_ptr[0] =
+                
block.get_by_position(arguments[0]).column->convert_to_full_column_if_const();
+        argument_ptr[1] = block.get_by_position(arguments[1]).column;
+
+        if (auto* col1 = check_and_get_column<ColumnString>(*argument_ptr[0])) 
{
+            if (auto* col2 = 
check_and_get_column<ColumnInt32>(*argument_ptr[1])) {
+                vector_vector(col1->get_chars(), col1->get_offsets(), 
col2->get_data(),
+                              res->get_chars(), res->get_offsets(),
+                              context->state()->repeat_max_num());
+                block.replace_by_position(result, std::move(res));
+                return Status::OK();
+            } else if (auto* col2_const = 
check_and_get_column<ColumnConst>(*argument_ptr[1])) {
+                
DCHECK(check_and_get_column<ColumnInt32>(col2_const->get_data_column()));
+                int repeat = 0;
+                repeat = std::min<int>(col2_const->get_int(0), 
context->state()->repeat_max_num());
+
+                if (repeat <= 0) {
+                    res->insert_many_defaults(input_rows_count);
+                } else {
+                    vector_const(col1->get_chars(), col1->get_offsets(), 
repeat, res->get_chars(),
+                                 res->get_offsets());
+                }
+                block.replace_by_position(result, std::move(res));
+                return Status::OK();
+            }
+        }
+
+        return Status::RuntimeError("repeat function get error param: {}, {}",
+                                    argument_ptr[0]->get_name(), 
argument_ptr[1]->get_name());
+    }
+
+    void vector_vector(const ColumnString::Chars& data, const 
ColumnString::Offsets& offsets,
+                       const ColumnInt32::Container& repeats, 
ColumnString::Chars& res_data,
+                       ColumnString::Offsets& res_offsets, const int 
repeat_max_num) const {
+        size_t input_row_size = offsets.size();
+
+        fmt::memory_buffer buffer;
+        res_offsets.resize(input_row_size);
+        for (ssize_t i = 0; i < input_row_size; ++i) {
+            buffer.clear();
+            const char* raw_str = reinterpret_cast<const 
char*>(&data[offsets[i - 1]]);
+            size_t size = offsets[i] - offsets[i - 1];
+            int repeat = 0;
+            repeat = std::min<int>(repeats[i], repeat_max_num);
+
+            if (repeat <= 0) {
+                StringOP::push_empty_string(i, res_data, res_offsets);
+            } else {
+                for (int j = 0; j < repeat; ++j) {
+                    buffer.append(raw_str, raw_str + size);
+                }
+                StringOP::push_value_string(std::string_view(buffer.data(), 
buffer.size()), i,
+                                            res_data, res_offsets);
+            }
+        }
+    }
+
+    // TODO: 1. use pmr::vector<char> replace fmt_buffer may speed up the code
+    //       2. abstract the `vector_vector` and `vector_const`
+    //       3. rethink we should use `DEFAULT_MAX_STRING_SIZE` to bigger here
+    void vector_const(const ColumnString::Chars& data, const 
ColumnString::Offsets& offsets,
+                      int repeat, ColumnString::Chars& res_data,
+                      ColumnString::Offsets& res_offsets) const {
+        size_t input_row_size = offsets.size();
+
+        fmt::memory_buffer buffer;
+        res_offsets.resize(input_row_size);
+        for (ssize_t i = 0; i < input_row_size; ++i) {
+            buffer.clear();
+            const char* raw_str = reinterpret_cast<const 
char*>(&data[offsets[i - 1]]);
+            size_t size = offsets[i] - offsets[i - 1];
+
+            for (int j = 0; j < repeat; ++j) {
+                buffer.append(raw_str, raw_str + size);
+            }
+            StringOP::push_value_string(std::string_view(buffer.data(), 
buffer.size()), i, res_data,
+                                        res_offsets);
+        }
+    }
+};
+
+class FunctionStringRepeatOld : public IFunction {
+public:
+    static constexpr auto name = "repeat";
+    static FunctionPtr create() { return 
std::make_shared<FunctionStringRepeatOld>(); }
+    String get_name() const override { return name; }
+    size_t get_number_of_arguments() const override { return 2; }
+
     DataTypePtr get_return_type_impl(const DataTypes& arguments) const 
override {
         return make_nullable(std::make_shared<DataTypeString>());
     }
@@ -1545,7 +1641,6 @@ public:
     DataTypePtr get_return_type_impl(const DataTypes& arguments) const 
override {
         return make_nullable(std::make_shared<DataTypeString>());
     }
-    bool use_default_implementation_for_nulls() const override { return true; }
 
     Status execute_impl(FunctionContext* context, Block& block, const 
ColumnNumbers& arguments,
                         size_t result, size_t input_rows_count) const override 
{
@@ -1688,8 +1783,6 @@ public:
         return make_nullable(std::make_shared<DataTypeString>());
     }
 
-    bool use_default_implementation_for_nulls() const override { return true; }
-
     Status execute_impl(FunctionContext* context, Block& block, const 
ColumnNumbers& arguments,
                         size_t result, size_t input_rows_count) const override 
{
         DCHECK_EQ(arguments.size(), 3);
@@ -2003,7 +2096,7 @@ public:
 class FunctionSubstringIndexOld : public IFunction {
 public:
     static constexpr auto name = "substring_index";
-    static FunctionPtr create() { return 
std::make_shared<FunctionSubstringIndex>(); }
+    static FunctionPtr create() { return 
std::make_shared<FunctionSubstringIndexOld>(); }
     String get_name() const override { return name; }
     size_t get_number_of_arguments() const override { return 3; }
 
@@ -2160,7 +2253,6 @@ public:
         return Status::OK();
     }
 };
-
 class FunctionSplitByString : public IFunction {
 public:
     static constexpr auto name = "split_by_string";
@@ -2205,17 +2297,17 @@ public:
         dest_offsets.reserve(0);
 
         NullMapType* dest_nested_null_map = nullptr;
-        ColumnNullable* dest_nullable_col = 
reinterpret_cast<ColumnNullable*>(dest_nested_column);
+        auto* dest_nullable_col = 
reinterpret_cast<ColumnNullable*>(dest_nested_column);
         dest_nested_column = dest_nullable_col->get_nested_column_ptr();
         dest_nested_null_map = 
&dest_nullable_col->get_null_map_column().get_data();
 
-        auto col_left = check_and_get_column<ColumnString>(src_column.get());
+        const auto* col_left = 
check_and_get_column<ColumnString>(src_column.get());
         if (!col_left) {
             return Status::InternalError("Left operator of function {} can not 
be {}", get_name(),
                                          src_column_type->get_name());
         }
 
-        auto col_right = 
check_and_get_column<ColumnString>(right_column.get());
+        const auto* col_right = 
check_and_get_column<ColumnString>(right_column.get());
         if (!col_right) {
             return Status::InternalError("Right operator of function {} can 
not be {}", get_name(),
                                          right_column_type->get_name());
@@ -2245,7 +2337,7 @@ private:
                                      const StringRef& delimiter_ref, IColumn& 
dest_nested_column,
                                      ColumnArray::Offsets64& dest_offsets,
                                      NullMapType* dest_nested_null_map) const {
-        ColumnString& dest_column_string = 
reinterpret_cast<ColumnString&>(dest_nested_column);
+        auto& dest_column_string = 
reinterpret_cast<ColumnString&>(dest_nested_column);
         ColumnString::Chars& column_string_chars = 
dest_column_string.get_chars();
         ColumnString::Offsets& column_string_offsets = 
dest_column_string.get_offsets();
         column_string_chars.reserve(0);
@@ -2312,7 +2404,7 @@ private:
                          const ColumnString& delimiter_column, IColumn& 
dest_nested_column,
                          ColumnArray::Offsets64& dest_offsets,
                          NullMapType* dest_nested_null_map) const {
-        ColumnString& dest_column_string = 
reinterpret_cast<ColumnString&>(dest_nested_column);
+        auto& dest_column_string = 
reinterpret_cast<ColumnString&>(dest_nested_column);
         ColumnString::Chars& column_string_chars = 
dest_column_string.get_chars();
         ColumnString::Offsets& column_string_offsets = 
dest_column_string.get_offsets();
         column_string_chars.reserve(0);
@@ -2369,7 +2461,7 @@ private:
                                       IColumn& dest_nested_column,
                                       ColumnArray::Offsets64& dest_offsets,
                                       NullMapType* dest_nested_null_map) const 
{
-        ColumnString& dest_column_string = 
reinterpret_cast<ColumnString&>(dest_nested_column);
+        auto& dest_column_string = 
reinterpret_cast<ColumnString&>(dest_nested_column);
         ColumnString::Chars& column_string_chars = 
dest_column_string.get_chars();
         ColumnString::Offsets& column_string_offsets = 
dest_column_string.get_offsets();
         column_string_chars.reserve(0);
@@ -2659,7 +2751,6 @@ public:
     DataTypePtr get_return_type_impl(const DataTypes& arguments) const 
override {
         return make_nullable(std::make_shared<DataTypeString>());
     }
-    bool use_default_implementation_for_nulls() const override { return true; }
 
     Status execute_impl(FunctionContext* context, Block& block, const 
ColumnNumbers& arguments,
                         size_t result, size_t input_rows_count) const override 
{
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Repeat.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Repeat.java
index b85a812197f..918443e8161 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Repeat.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Repeat.java
@@ -19,8 +19,8 @@ package 
org.apache.doris.nereids.trees.expressions.functions.scalar;
 
 import org.apache.doris.catalog.FunctionSignature;
 import org.apache.doris.nereids.trees.expressions.Expression;
-import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable;
 import 
org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
+import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
 import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression;
 import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
 import org.apache.doris.nereids.types.IntegerType;
@@ -35,7 +35,7 @@ import java.util.List;
  * ScalarFunction 'repeat'. This class is generated by GenerateFunction.
  */
 public class Repeat extends ScalarFunction
-        implements BinaryExpression, ExplicitlyCastableSignature, 
AlwaysNullable {
+        implements BinaryExpression, ExplicitlyCastableSignature, 
PropagateNullable {
 
     public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
             
FunctionSignature.ret(StringType.INSTANCE).args(StringType.INSTANCE, 
IntegerType.INSTANCE)
diff --git a/gensrc/script/doris_builtins_functions.py 
b/gensrc/script/doris_builtins_functions.py
index 722715da2c4..bd52ffe789d 100644
--- a/gensrc/script/doris_builtins_functions.py
+++ b/gensrc/script/doris_builtins_functions.py
@@ -1564,7 +1564,7 @@ visible_functions = {
         [['null_or_empty'], 'BOOLEAN', ['VARCHAR'], 'ALWAYS_NOT_NULLABLE'],
         [['not_null_or_empty'], 'BOOLEAN', ['VARCHAR'], 'ALWAYS_NOT_NULLABLE'],
         [['space'], 'VARCHAR', ['INT'], ''],
-        [['repeat'], 'VARCHAR', ['VARCHAR', 'INT'], 'ALWAYS_NULLABLE'],
+        [['repeat'], 'VARCHAR', ['VARCHAR', 'INT'], 'DEPEND_ON_ARGUMENT'],
         [['lpad'], 'VARCHAR', ['VARCHAR', 'INT', 'VARCHAR'], 
'ALWAYS_NULLABLE'],
         [['rpad'], 'VARCHAR', ['VARCHAR', 'INT', 'VARCHAR'], 
'ALWAYS_NULLABLE'],
         [['append_trailing_char_if_absent'], 'VARCHAR', ['VARCHAR', 
'VARCHAR'], 'ALWAYS_NULLABLE'],
@@ -1624,7 +1624,7 @@ visible_functions = {
         [['null_or_empty'], 'BOOLEAN', ['STRING'], 'ALWAYS_NOT_NULLABLE'],
         [['not_null_or_empty'], 'BOOLEAN', ['STRING'], 'ALWAYS_NOT_NULLABLE'],
         [['space'], 'STRING', ['INT'], ''],
-        [['repeat'], 'STRING', ['STRING', 'INT'], 'ALWAYS_NULLABLE'],
+        [['repeat'], 'STRING', ['STRING', 'INT'], 'DEPEND_ON_ARGUMENT'],
         [['lpad'], 'STRING', ['STRING', 'INT', 'STRING'], 'ALWAYS_NULLABLE'],
         [['rpad'], 'STRING', ['STRING', 'INT', 'STRING'], 'ALWAYS_NULLABLE'],
         [['append_trailing_char_if_absent'], 'STRING', ['STRING', 'STRING'], 
'ALWAYS_NULLABLE'],


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to