This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new d75300f166a [fix](hash join) fix stack overflow caused by evaluate case expr on huge build block (#28851)
d75300f166a is described below

commit d75300f166acdd4c9ca0a5d662472418d4272e95
Author: TengJianPing <18241664+jackte...@users.noreply.github.com>
AuthorDate: Fri Dec 22 15:45:12 2023 +0800

    [fix](hash join) fix stack overflow caused by evaluate case expr on huge build block (#28851)
---
 be/src/pipeline/exec/hashjoin_build_sink.cpp      | 18 +++++++++-----
 be/src/pipeline/exec/hashjoin_build_sink.h        |  1 +
 be/src/vec/columns/column_vector.cpp              |  3 ++-
 be/src/vec/exec/join/vhash_join_node.cpp          | 16 +++++++-----
 be/src/vec/exec/join/vhash_join_node.h            |  1 +
 be/src/vec/functions/function_binary_arithmetic.h |  5 ++--
 be/src/vec/functions/function_case.h              | 21 ++++++++--------
 be/src/vec/functions/function_string.cpp          | 30 ++++++++++++++++++++---
 be/src/vec/functions/multiply.cpp                 |  3 ++-
 9 files changed, 69 insertions(+), 29 deletions(-)

diff --git a/be/src/pipeline/exec/hashjoin_build_sink.cpp 
b/be/src/pipeline/exec/hashjoin_build_sink.cpp
index c3238df035e..8cd0376a957 100644
--- a/be/src/pipeline/exec/hashjoin_build_sink.cpp
+++ b/be/src/pipeline/exec/hashjoin_build_sink.cpp
@@ -230,8 +230,6 @@ Status 
HashJoinBuildSinkLocalState::process_build_block(RuntimeState* state,
     vectorized::ColumnRawPtrs raw_ptrs(_build_expr_ctxs.size());
 
     vectorized::ColumnUInt8::MutablePtr null_map_val;
-    std::vector<int> res_col_ids(_build_expr_ctxs.size());
-    RETURN_IF_ERROR(_do_evaluate(block, _build_expr_ctxs, 
*_build_expr_call_timer, res_col_ids));
     if (p._join_op == TJoinOp::LEFT_OUTER_JOIN || p._join_op == 
TJoinOp::FULL_OUTER_JOIN) {
         _convert_block_to_null(block);
         // first row is mocked
@@ -247,7 +245,7 @@ Status 
HashJoinBuildSinkLocalState::process_build_block(RuntimeState* state,
     //  so we have to initialize this flag by the first build block.
     if (!_has_set_need_null_map_for_build) {
         _has_set_need_null_map_for_build = true;
-        _set_build_ignore_flag(block, res_col_ids);
+        _set_build_ignore_flag(block, _build_col_ids);
     }
     if (p._short_circuit_for_null_in_build_side || _build_side_ignore_null) {
         null_map_val = vectorized::ColumnUInt8::create();
@@ -255,7 +253,7 @@ Status 
HashJoinBuildSinkLocalState::process_build_block(RuntimeState* state,
     }
 
     // Get the key column that needs to be built
-    Status st = _extract_join_column(block, null_map_val, raw_ptrs, 
res_col_ids);
+    Status st = _extract_join_column(block, null_map_val, raw_ptrs, 
_build_col_ids);
 
     st = std::visit(
             Overload {[&](std::monostate& arg, auto join_op, auto 
has_null_value,
@@ -458,13 +456,21 @@ Status HashJoinBuildSinkOperatorX::sink(RuntimeState* 
state, vectorized::Block*
         if (local_state._build_side_mutable_block.empty()) {
             auto tmp_build_block = 
vectorized::VectorizedUtils::create_empty_columnswithtypename(
                     _child_x->row_desc());
+            tmp_build_block = *(tmp_build_block.create_same_struct_block(1, 
false));
+            local_state._build_col_ids.resize(_build_expr_ctxs.size());
+            RETURN_IF_ERROR(local_state._do_evaluate(tmp_build_block, 
local_state._build_expr_ctxs,
+                                                     
*local_state._build_expr_call_timer,
+                                                     
local_state._build_col_ids));
             local_state._build_side_mutable_block =
                     
vectorized::MutableBlock::build_mutable_block(&tmp_build_block);
-            RETURN_IF_ERROR(local_state._build_side_mutable_block.merge(
-                    *(tmp_build_block.create_same_struct_block(1, false))));
         }
 
         if (in_block->rows() != 0) {
+            std::vector<int> res_col_ids(_build_expr_ctxs.size());
+            RETURN_IF_ERROR(local_state._do_evaluate(*in_block, 
local_state._build_expr_ctxs,
+                                                     
*local_state._build_expr_call_timer,
+                                                     res_col_ids));
+
             SCOPED_TIMER(local_state._build_side_merge_block_timer);
             
RETURN_IF_ERROR(local_state._build_side_mutable_block.merge(*in_block));
             if (local_state._build_side_mutable_block.rows() >
diff --git a/be/src/pipeline/exec/hashjoin_build_sink.h 
b/be/src/pipeline/exec/hashjoin_build_sink.h
index b2fbec8575e..ecf0a4a3122 100644
--- a/be/src/pipeline/exec/hashjoin_build_sink.h
+++ b/be/src/pipeline/exec/hashjoin_build_sink.h
@@ -116,6 +116,7 @@ protected:
     bool _build_side_ignore_null = false;
     std::unordered_set<const vectorized::Block*> _inserted_blocks;
     std::shared_ptr<SharedHashTableDependency> _shared_hash_table_dependency;
+    std::vector<int> _build_col_ids;
 
     RuntimeProfile::Counter* _build_table_timer = nullptr;
     RuntimeProfile::Counter* _build_expr_call_timer = nullptr;
diff --git a/be/src/vec/columns/column_vector.cpp 
b/be/src/vec/columns/column_vector.cpp
index 45d9e8f70b0..ca8db58fc98 100644
--- a/be/src/vec/columns/column_vector.cpp
+++ b/be/src/vec/columns/column_vector.cpp
@@ -524,7 +524,8 @@ ColumnPtr ColumnVector<T>::replicate(const 
IColumn::Offsets& offsets) const {
     res_data.reserve(offsets.back());
 
     // vectorized this code to speed up
-    IColumn::Offset counts[size];
+    auto counts_uptr = std::unique_ptr<IColumn::Offset[]>(new 
IColumn::Offset[size]);
+    IColumn::Offset* counts = counts_uptr.get();
     for (ssize_t i = 0; i < size; ++i) {
         counts[i] = offsets[i] - offsets[i - 1];
     }
diff --git a/be/src/vec/exec/join/vhash_join_node.cpp 
b/be/src/vec/exec/join/vhash_join_node.cpp
index 1202228ae85..30f2450f458 100644
--- a/be/src/vec/exec/join/vhash_join_node.cpp
+++ b/be/src/vec/exec/join/vhash_join_node.cpp
@@ -725,12 +725,18 @@ Status HashJoinNode::sink(doris::RuntimeState* state, 
vectorized::Block* in_bloc
         if (_build_side_mutable_block.empty()) {
             auto tmp_build_block =
                     
VectorizedUtils::create_empty_columnswithtypename(child(1)->row_desc());
+            tmp_build_block = *(tmp_build_block.create_same_struct_block(1, 
false));
+            _build_col_ids.resize(_build_expr_ctxs.size());
+            RETURN_IF_ERROR(_do_evaluate(tmp_build_block, _build_expr_ctxs, 
*_build_expr_call_timer,
+                                         _build_col_ids));
             _build_side_mutable_block = 
MutableBlock::build_mutable_block(&tmp_build_block);
-            RETURN_IF_ERROR(_build_side_mutable_block.merge(
-                    *(tmp_build_block.create_same_struct_block(1, false))));
         }
 
         if (in_block->rows() != 0) {
+            std::vector<int> res_col_ids(_build_expr_ctxs.size());
+            RETURN_IF_ERROR(_do_evaluate(*in_block, _build_expr_ctxs, 
*_build_expr_call_timer,
+                                         res_col_ids));
+
             SCOPED_TIMER(_build_side_merge_block_timer);
             RETURN_IF_ERROR(_build_side_mutable_block.merge(*in_block));
             if (_build_side_mutable_block.rows() > JOIN_BUILD_SIZE_LIMIT) {
@@ -952,8 +958,6 @@ Status HashJoinNode::_process_build_block(RuntimeState* 
state, Block& block) {
     ColumnRawPtrs raw_ptrs(_build_expr_ctxs.size());
 
     ColumnUInt8::MutablePtr null_map_val;
-    std::vector<int> res_col_ids(_build_expr_ctxs.size());
-    RETURN_IF_ERROR(_do_evaluate(block, _build_expr_ctxs, 
*_build_expr_call_timer, res_col_ids));
     if (_join_op == TJoinOp::LEFT_OUTER_JOIN || _join_op == 
TJoinOp::FULL_OUTER_JOIN) {
         _convert_block_to_null(block);
         // first row is mocked
@@ -969,7 +973,7 @@ Status HashJoinNode::_process_build_block(RuntimeState* 
state, Block& block) {
     //  so we have to initialize this flag by the first build block.
     if (!_has_set_need_null_map_for_build) {
         _has_set_need_null_map_for_build = true;
-        _set_build_ignore_flag(block, res_col_ids);
+        _set_build_ignore_flag(block, _build_col_ids);
     }
     if (_short_circuit_for_null_in_build_side || _build_side_ignore_null) {
         null_map_val = ColumnUInt8::create();
@@ -977,7 +981,7 @@ Status HashJoinNode::_process_build_block(RuntimeState* 
state, Block& block) {
     }
 
     // Get the key column that needs to be built
-    Status st = _extract_join_column<true>(block, null_map_val, raw_ptrs, 
res_col_ids);
+    Status st = _extract_join_column<true>(block, null_map_val, raw_ptrs, 
_build_col_ids);
 
     st = std::visit(
             Overload {[&](std::monostate& arg, auto join_op, auto 
has_null_value,
diff --git a/be/src/vec/exec/join/vhash_join_node.h 
b/be/src/vec/exec/join/vhash_join_node.h
index 64f07af6504..8304eedb290 100644
--- a/be/src/vec/exec/join/vhash_join_node.h
+++ b/be/src/vec/exec/join/vhash_join_node.h
@@ -451,6 +451,7 @@ private:
 
     std::vector<IRuntimeFilter*> _runtime_filters;
     std::atomic_bool _probe_open_finish = false;
+    std::vector<int> _build_col_ids;
 };
 } // namespace vectorized
 } // namespace doris
diff --git a/be/src/vec/functions/function_binary_arithmetic.h 
b/be/src/vec/functions/function_binary_arithmetic.h
index 30ede75ea17..4b69561b14e 100644
--- a/be/src/vec/functions/function_binary_arithmetic.h
+++ b/be/src/vec/functions/function_binary_arithmetic.h
@@ -265,7 +265,8 @@ private:
                     make_bool_variant(need_adjust_scale && check_overflow));
 
             if (OpTraits::is_multiply && need_adjust_scale && !check_overflow) 
{
-                int8_t sig[size];
+                auto sig_uptr = std::unique_ptr<int8_t[]>(new int8_t[size]);
+                int8_t* sig = sig_uptr.get();
                 for (size_t i = 0; i < size; i++) {
                     sig[i] = sgn(c[i].value);
                 }
@@ -917,7 +918,7 @@ public:
                     if constexpr (!std::is_same_v<ResultDataType, 
InvalidType>) {
                         need_replace_null_data_to_default_ =
                                 IsDataTypeDecimal<ResultDataType> ||
-                                (name == "pow" &&
+                                (get_name() == "pow" &&
                                  std::is_floating_point_v<typename 
ResultDataType::FieldType>);
                         if constexpr (IsDataTypeDecimal<LeftDataType> &&
                                       IsDataTypeDecimal<RightDataType>) {
diff --git a/be/src/vec/functions/function_case.h 
b/be/src/vec/functions/function_case.h
index 2ecc6bd186d..26e12e7bd13 100644
--- a/be/src/vec/functions/function_case.h
+++ b/be/src/vec/functions/function_case.h
@@ -159,9 +159,9 @@ public:
         int rows_count = column_holder.rows_count;
 
         // `then` data index corresponding to each row of results, 0 
represents `else`.
-        int then_idx[rows_count];
-        int* __restrict then_idx_ptr = then_idx;
-        memset(then_idx_ptr, 0, sizeof(then_idx));
+        auto then_idx_uptr = std::unique_ptr<int[]>(new int[rows_count]);
+        int* __restrict then_idx_ptr = then_idx_uptr.get();
+        memset(then_idx_ptr, 0, rows_count * sizeof(int));
 
         for (int row_idx = 0; row_idx < column_holder.rows_count; row_idx++) {
             for (int i = 1; i < column_holder.pair_count; i++) {
@@ -189,7 +189,7 @@ public:
         }
 
         auto result_column_ptr = data_type->create_column();
-        update_result_normal<int, ColumnType, then_null>(result_column_ptr, 
then_idx,
+        update_result_normal<int, ColumnType, then_null>(result_column_ptr, 
then_idx_ptr,
                                                          column_holder);
         block.replace_by_position(result, std::move(result_column_ptr));
         return Status::OK();
@@ -206,9 +206,9 @@ public:
         int rows_count = column_holder.rows_count;
 
         // `then` data index corresponding to each row of results, 0 
represents `else`.
-        uint8_t then_idx[rows_count];
-        uint8_t* __restrict then_idx_ptr = then_idx;
-        memset(then_idx_ptr, 0, sizeof(then_idx));
+        auto then_idx_uptr = std::unique_ptr<uint8_t[]>(new 
uint8_t[rows_count]);
+        uint8_t* __restrict then_idx_ptr = then_idx_uptr.get();
+        memset(then_idx_ptr, 0, rows_count);
 
         auto case_column_ptr = column_holder.when_ptrs[0].value_or(nullptr);
 
@@ -245,13 +245,13 @@ public:
             }
         }
 
-        return execute_update_result<ColumnType, then_null>(data_type, result, 
block, then_idx,
+        return execute_update_result<ColumnType, then_null>(data_type, result, 
block, then_idx_ptr,
                                                             column_holder);
     }
 
     template <typename ColumnType, bool then_null>
     Status execute_update_result(const DataTypePtr& data_type, size_t result, 
Block& block,
-                                 uint8* then_idx, CaseWhenColumnHolder& 
column_holder) const {
+                                 const uint8* then_idx, CaseWhenColumnHolder& 
column_holder) const {
         auto result_column_ptr = data_type->create_column();
 
         if constexpr (std::is_same_v<ColumnType, ColumnString> ||
@@ -282,7 +282,8 @@ public:
     }
 
     template <typename IndexType, typename ColumnType, bool then_null>
-    void update_result_normal(MutableColumnPtr& result_column_ptr, IndexType* 
then_idx,
+    void update_result_normal(MutableColumnPtr& result_column_ptr,
+                              const IndexType* __restrict then_idx,
                               CaseWhenColumnHolder& column_holder) const {
         std::vector<uint8_t> is_consts(column_holder.then_ptrs.size());
         std::vector<ColumnPtr> raw_columns(column_holder.then_ptrs.size());
diff --git a/be/src/vec/functions/function_string.cpp 
b/be/src/vec/functions/function_string.cpp
index 6179d64e47d..7b4e043efe6 100644
--- a/be/src/vec/functions/function_string.cpp
+++ b/be/src/vec/functions/function_string.cpp
@@ -582,6 +582,7 @@ public:
     }
 };
 
+static constexpr int MAX_STACK_CIPHER_LEN = 1024 * 64;
 struct UnHexImpl {
     static constexpr auto name = "unhex";
     using ReturnType = DataTypeString;
@@ -654,8 +655,16 @@ struct UnHexImpl {
                 continue;
             }
 
+            char dst_array[MAX_STACK_CIPHER_LEN];
+            char* dst = dst_array;
+
             int cipher_len = srclen / 2;
-            char dst[cipher_len];
+            std::unique_ptr<char[]> dst_uptr;
+            if (cipher_len > MAX_STACK_CIPHER_LEN) {
+                dst_uptr.reset(new char[cipher_len]);
+                dst = dst_uptr.get();
+            }
+
             int outlen = hex_decode(source, srclen, dst);
 
             if (outlen < 0) {
@@ -725,8 +734,16 @@ struct ToBase64Impl {
                 continue;
             }
 
+            char dst_array[MAX_STACK_CIPHER_LEN];
+            char* dst = dst_array;
+
             int cipher_len = (int)(4.0 * ceil((double)srclen / 3.0));
-            char dst[cipher_len];
+            std::unique_ptr<char[]> dst_uptr;
+            if (cipher_len > MAX_STACK_CIPHER_LEN) {
+                dst_uptr.reset(new char[cipher_len]);
+                dst = dst_uptr.get();
+            }
+
             int outlen = base64_encode((const unsigned char*)source, srclen, 
(unsigned char*)dst);
 
             if (outlen < 0) {
@@ -765,8 +782,15 @@ struct FromBase64Impl {
                 continue;
             }
 
+            char dst_array[MAX_STACK_CIPHER_LEN];
+            char* dst = dst_array;
+
             int cipher_len = srclen;
-            char dst[cipher_len];
+            std::unique_ptr<char[]> dst_uptr;
+            if (cipher_len > MAX_STACK_CIPHER_LEN) {
+                dst_uptr.reset(new char[cipher_len]);
+                dst = dst_uptr.get();
+            }
             int outlen = base64_decode(source, srclen, dst);
 
             if (outlen < 0) {
diff --git a/be/src/vec/functions/multiply.cpp 
b/be/src/vec/functions/multiply.cpp
index 79653910991..0dc9f4a410c 100644
--- a/be/src/vec/functions/multiply.cpp
+++ b/be/src/vec/functions/multiply.cpp
@@ -56,7 +56,8 @@ struct MultiplyImpl {
     static void vector_vector(const ColumnDecimal128::Container::value_type* 
__restrict a,
                               const ColumnDecimal128::Container::value_type* 
__restrict b,
                               ColumnDecimal128::Container::value_type* c, 
size_t size) {
-        int8 sgn[size];
+        auto sng_uptr = std::unique_ptr<int8[]>(new int8[size]);
+        int8* sgn = sng_uptr.get();
         auto max = DecimalV2Value::get_max_decimal();
         auto min = DecimalV2Value::get_min_decimal();
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to