This is an automated email from the ASF dual-hosted git repository.

morningman pushed a commit to branch branch-2.1
in repository https://gitbox.apache.org/repos/asf/doris.git

The following commit(s) were added to refs/heads/branch-2.1 by this push:
     new f751ca4e045 [branch-2.1](functions) fix BE crash for functions random_bytes and mask_first_n/mask_last_n (#36003)
f751ca4e045 is described below

commit f751ca4e045a069bca01be77b2bf3f6c4ba0e200
Author: zclllyybb <zhaochan...@selectdb.com>
AuthorDate: Fri Jun 7 10:30:41 2024 +0800

    [branch-2.1](functions) fix BE crash for functions random_bytes and mask_first_n/mask_last_n (#36003)
    
    pick #35884
---
 be/src/vec/functions/function_string.h             | 31 +++++++++++-----------
 .../expressions/functions/scalar/MaskFirstN.java   |  8 ++++++
 .../expressions/functions/scalar/MaskLastN.java    |  8 ++++++
 .../correctness_p0/test_mask_function.groovy       | 17 ++++++++++++
 .../nereids_function_p0/scalar_function/R.groovy   |  4 +++
 5 files changed, 52 insertions(+), 16 deletions(-)

diff --git a/be/src/vec/functions/function_string.h 
b/be/src/vec/functions/function_string.h
index fbaed751c7d..31c6cbb5ecb 100644
--- a/be/src/vec/functions/function_string.h
+++ b/be/src/vec/functions/function_string.h
@@ -792,10 +792,7 @@ public:
 
     Status execute_impl(FunctionContext* context, Block& block, const 
ColumnNumbers& arguments,
                         size_t result, size_t input_rows_count) const override 
{
-        DCHECK_GE(arguments.size(), 1);
-        DCHECK_LE(arguments.size(), 2);
-
-        int n = -1;
+        int n = -1; // means unassigned
 
         auto res = ColumnString::create();
         auto col = 
block.get_by_position(arguments[0]).column->convert_to_full_column_if_const();
@@ -803,17 +800,20 @@ public:
 
         if (arguments.size() == 2) {
             const auto& col = *block.get_by_position(arguments[1]).column;
+            // the 2nd arg is const. checked in fe.
+            if (col.get_int(0) < 0) [[unlikely]] {
+                return Status::InvalidArgument(
+                        "function {} only accept non-negative input for 2nd 
argument but got {}",
+                        name, col.get_int(0));
+            }
             n = col.get_int(0);
-        } else if (arguments.size() > 2) {
-            return Status::InvalidArgument(
-                    fmt::format("too many arguments for function {}", 
get_name()));
         }
 
-        if (n == -1) {
+        if (n == -1) { // no 2nd arg, just mask all
             FunctionMask::vector_mask(source_column, *res, 
FunctionMask::DEFAULT_UPPER_MASK,
                                       FunctionMask::DEFAULT_LOWER_MASK,
                                       FunctionMask::DEFAULT_NUMBER_MASK);
-        } else if (n >= 0) {
+        } else { // n >= 0
             vector(source_column, n, *res);
         }
 
@@ -2901,19 +2901,18 @@ public:
 
         ColumnPtr argument_column =
                 
block.get_by_position(arguments[0]).column->convert_to_full_column_if_const();
-        const auto* length_col = 
check_and_get_column<ColumnInt32>(argument_column.get());
-
-        if (!length_col) {
-            return Status::InternalError("Not supported input argument type");
-        }
+        const auto* length_col = assert_cast<const 
ColumnInt32*>(argument_column.get());
 
         std::vector<uint8_t> random_bytes;
         std::random_device rd;
         std::mt19937 gen(rd());
 
         for (size_t i = 0; i < input_rows_count; ++i) {
-            UInt64 length = length_col->get64(i);
-            random_bytes.resize(length);
+            if (length_col->get_element(i) < 0) [[unlikely]] {
+                return Status::InvalidArgument("argument {} of function {} at 
row {} was invalid.",
+                                               length_col->get_element(i), 
name, i);
+            }
+            random_bytes.resize(length_col->get_element(i));
 
             std::uniform_int_distribution<uint8_t> distribution(0, 255);
             for (auto& byte : random_bytes) {
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/MaskFirstN.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/MaskFirstN.java
index 81a968067c2..33e19d468e8 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/MaskFirstN.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/MaskFirstN.java
@@ -18,6 +18,7 @@
 package org.apache.doris.nereids.trees.expressions.functions.scalar;
 
 import org.apache.doris.catalog.FunctionSignature;
+import org.apache.doris.nereids.exceptions.AnalysisException;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import 
org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
 import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
@@ -65,6 +66,13 @@ public class MaskFirstN extends ScalarFunction implements 
ExplicitlyCastableSign
         return new MaskFirstN(children.get(0), children.get(1));
     }
 
+    @Override
+    public void checkLegalityAfterRewrite() {
+        if (arity() == 2 && !child(1).isLiteral()) {
+            throw new AnalysisException("mask_first_n must accept literal for 
2nd argument");
+        }
+    }
+
     @Override
     public List<FunctionSignature> getSignatures() {
         return SIGNATURES;
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/MaskLastN.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/MaskLastN.java
index cb8246f04ab..eafb85ee89b 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/MaskLastN.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/MaskLastN.java
@@ -18,6 +18,7 @@
 package org.apache.doris.nereids.trees.expressions.functions.scalar;
 
 import org.apache.doris.catalog.FunctionSignature;
+import org.apache.doris.nereids.exceptions.AnalysisException;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import 
org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
 import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
@@ -65,6 +66,13 @@ public class MaskLastN extends ScalarFunction implements 
ExplicitlyCastableSigna
         return new MaskLastN(children.get(0), children.get(1));
     }
 
+    @Override
+    public void checkLegalityAfterRewrite() {
+        if (arity() == 2 && !child(1).isLiteral()) {
+            throw new AnalysisException("mask_last_n must accept literal for 
2nd argument");
+        }
+    }
+
     @Override
     public List<FunctionSignature> getSignatures() {
         return SIGNATURES;
diff --git a/regression-test/suites/correctness_p0/test_mask_function.groovy 
b/regression-test/suites/correctness_p0/test_mask_function.groovy
index b242e72eccc..b7717ab183c 100644
--- a/regression-test/suites/correctness_p0/test_mask_function.groovy
+++ b/regression-test/suites/correctness_p0/test_mask_function.groovy
@@ -75,4 +75,21 @@ suite("test_mask_function") {
     qt_select_digital_masking """
         select digital_masking(13812345678);
     """
+
+    test {
+        sql """ select mask_last_n("12345", -100); """
+        exception "function mask_last_n only accept non-negative input for 2nd 
argument but got -100"
+    }
+    test {
+        sql """ select mask_first_n("12345", -100); """
+        exception "function mask_first_n only accept non-negative input for 
2nd argument but got -100"
+    }
+    test {
+        sql """ select mask_last_n("12345", id) from table_mask_test; """
+        exception "mask_last_n must accept literal for 2nd argument"
+    }
+    test {
+        sql """ select mask_first_n("12345", id) from table_mask_test; """
+        exception "mask_first_n must accept literal for 2nd argument"
+    }
 }
diff --git 
a/regression-test/suites/nereids_function_p0/scalar_function/R.groovy 
b/regression-test/suites/nereids_function_p0/scalar_function/R.groovy
index fa58e6d0cb2..1110ed3a47a 100644
--- a/regression-test/suites/nereids_function_p0/scalar_function/R.groovy
+++ b/regression-test/suites/nereids_function_p0/scalar_function/R.groovy
@@ -101,4 +101,8 @@ suite("nereids_scalar_fn_R") {
        qt_sql_rtrim_String_String_notnull "select rtrim(kstr, '1') from 
fn_test_not_nullable order by kstr"
        sql "SELECT random_bytes(7);"
        qt_sql_random_bytes "SELECT random_bytes(null);"
+       test {
+               sql " select random_bytes(-1); "
+               exception "argument -1 of function random_bytes at row 0 was 
invalid"
+       }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to