This is an automated email from the ASF dual-hosted git repository. morningman pushed a commit to branch branch-2.1 in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/branch-2.1 by this push: new f751ca4e045 [branch-2.1](functions) fix be crash for function random_bytes and mask_first/last_n (#36003) f751ca4e045 is described below commit f751ca4e045a069bca01be77b2bf3f6c4ba0e200 Author: zclllyybb <zhaochan...@selectdb.com> AuthorDate: Fri Jun 7 10:30:41 2024 +0800 [branch-2.1](functions) fix be crash for function random_bytes and mask_first/last_n (#36003) pick #35884 --- be/src/vec/functions/function_string.h | 31 +++++++++++----------- .../expressions/functions/scalar/MaskFirstN.java | 8 ++++++ .../expressions/functions/scalar/MaskLastN.java | 8 ++++++ .../correctness_p0/test_mask_function.groovy | 17 ++++++++++++ .../nereids_function_p0/scalar_function/R.groovy | 4 +++ 5 files changed, 52 insertions(+), 16 deletions(-) diff --git a/be/src/vec/functions/function_string.h b/be/src/vec/functions/function_string.h index fbaed751c7d..31c6cbb5ecb 100644 --- a/be/src/vec/functions/function_string.h +++ b/be/src/vec/functions/function_string.h @@ -792,10 +792,7 @@ public: Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, size_t result, size_t input_rows_count) const override { - DCHECK_GE(arguments.size(), 1); - DCHECK_LE(arguments.size(), 2); - - int n = -1; + int n = -1; // means unassigned auto res = ColumnString::create(); auto col = block.get_by_position(arguments[0]).column->convert_to_full_column_if_const(); @@ -803,17 +800,20 @@ public: if (arguments.size() == 2) { const auto& col = *block.get_by_position(arguments[1]).column; + // the 2nd arg is const. checked in fe. 
+ if (col.get_int(0) < 0) [[unlikely]] { + return Status::InvalidArgument( + "function {} only accept non-negative input for 2nd argument but got {}", + name, col.get_int(0)); + } n = col.get_int(0); - } else if (arguments.size() > 2) { - return Status::InvalidArgument( - fmt::format("too many arguments for function {}", get_name())); } - if (n == -1) { + if (n == -1) { // no 2nd arg, just mask all FunctionMask::vector_mask(source_column, *res, FunctionMask::DEFAULT_UPPER_MASK, FunctionMask::DEFAULT_LOWER_MASK, FunctionMask::DEFAULT_NUMBER_MASK); - } else if (n >= 0) { + } else { // n >= 0 vector(source_column, n, *res); } @@ -2901,19 +2901,18 @@ public: ColumnPtr argument_column = block.get_by_position(arguments[0]).column->convert_to_full_column_if_const(); - const auto* length_col = check_and_get_column<ColumnInt32>(argument_column.get()); - - if (!length_col) { - return Status::InternalError("Not supported input argument type"); - } + const auto* length_col = assert_cast<const ColumnInt32*>(argument_column.get()); std::vector<uint8_t> random_bytes; std::random_device rd; std::mt19937 gen(rd()); for (size_t i = 0; i < input_rows_count; ++i) { - UInt64 length = length_col->get64(i); - random_bytes.resize(length); + if (length_col->get_element(i) < 0) [[unlikely]] { + return Status::InvalidArgument("argument {} of function {} at row {} was invalid.", + length_col->get_element(i), name, i); + } + random_bytes.resize(length_col->get_element(i)); std::uniform_int_distribution<uint8_t> distribution(0, 255); for (auto& byte : random_bytes) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/MaskFirstN.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/MaskFirstN.java index 81a968067c2..33e19d468e8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/MaskFirstN.java +++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/MaskFirstN.java @@ -18,6 +18,7 @@ package org.apache.doris.nereids.trees.expressions.functions.scalar; import org.apache.doris.catalog.FunctionSignature; +import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; @@ -65,6 +66,13 @@ public class MaskFirstN extends ScalarFunction implements ExplicitlyCastableSign return new MaskFirstN(children.get(0), children.get(1)); } + @Override + public void checkLegalityAfterRewrite() { + if (arity() == 2 && !child(1).isLiteral()) { + throw new AnalysisException("mask_first_n must accept literal for 2nd argument"); + } + } + @Override public List<FunctionSignature> getSignatures() { return SIGNATURES; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/MaskLastN.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/MaskLastN.java index cb8246f04ab..eafb85ee89b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/MaskLastN.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/MaskLastN.java @@ -18,6 +18,7 @@ package org.apache.doris.nereids.trees.expressions.functions.scalar; import org.apache.doris.catalog.FunctionSignature; +import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; @@ -65,6 +66,13 @@ public class MaskLastN extends ScalarFunction implements ExplicitlyCastableSigna return new 
MaskLastN(children.get(0), children.get(1)); } + @Override + public void checkLegalityAfterRewrite() { + if (arity() == 2 && !child(1).isLiteral()) { + throw new AnalysisException("mask_last_n must accept literal for 2nd argument"); + } + } + @Override public List<FunctionSignature> getSignatures() { return SIGNATURES; diff --git a/regression-test/suites/correctness_p0/test_mask_function.groovy b/regression-test/suites/correctness_p0/test_mask_function.groovy index b242e72eccc..b7717ab183c 100644 --- a/regression-test/suites/correctness_p0/test_mask_function.groovy +++ b/regression-test/suites/correctness_p0/test_mask_function.groovy @@ -75,4 +75,21 @@ suite("test_mask_function") { qt_select_digital_masking """ select digital_masking(13812345678); """ + + test { + sql """ select mask_last_n("12345", -100); """ + exception "function mask_last_n only accept non-negative input for 2nd argument but got -100" + } + test { + sql """ select mask_first_n("12345", -100); """ + exception "function mask_first_n only accept non-negative input for 2nd argument but got -100" + } + test { + sql """ select mask_last_n("12345", id) from table_mask_test; """ + exception "mask_last_n must accept literal for 2nd argument" + } + test { + sql """ select mask_first_n("12345", id) from table_mask_test; """ + exception "mask_first_n must accept literal for 2nd argument" + } } diff --git a/regression-test/suites/nereids_function_p0/scalar_function/R.groovy b/regression-test/suites/nereids_function_p0/scalar_function/R.groovy index fa58e6d0cb2..1110ed3a47a 100644 --- a/regression-test/suites/nereids_function_p0/scalar_function/R.groovy +++ b/regression-test/suites/nereids_function_p0/scalar_function/R.groovy @@ -101,4 +101,8 @@ suite("nereids_scalar_fn_R") { qt_sql_rtrim_String_String_notnull "select rtrim(kstr, '1') from fn_test_not_nullable order by kstr" sql "SELECT random_bytes(7);" qt_sql_random_bytes "SELECT random_bytes(null);" + test { + sql " select random_bytes(-1); " + exception 
"argument -1 of function random_bytes at row 0 was invalid" + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org