This is an automated email from the ASF dual-hosted git repository. yiguolei pushed a commit to branch branch-4.1 in repository https://gitbox.apache.org/repos/asf/doris.git
commit 7b8c01a62d0a4279fd6bf4647eebb85930b9911d Author: HappenLee <[email protected]> AuthorDate: Fri Mar 13 19:44:36 2026 +0800 [Opt](exec) cherry pick the opt code #60137 #59492 #59446 #58728 (#61282) ### What problem does this PR solve? cherry pick the opt code #60137 #59492 #59446 #58728 Problem Summary: ### Release note None ### Check List (For Author) - Test <!-- At least one of them must be included. --> - [ ] Regression test - [ ] Unit Test - [ ] Manual test (add detailed scripts or steps below) - [ ] No need to test or manual test. Explain why: - [ ] This is a refactor/code format and no logic has been changed. - [ ] Previous test can cover this change. - [ ] No code files have been changed. - [ ] Other reason <!-- Add your reason? --> - Behavior changed: - [ ] No. - [ ] Yes. <!-- Explain the behavior change --> - Does this need documentation? - [ ] No. - [ ] Yes. <!-- Add document PR link here. eg: https://github.com/apache/doris-website/pull/1214 --> ### Check List (For Reviewer who merge this PR) - [ ] Confirm the release note - [ ] Confirm test cases - [ ] Confirm document - [ ] Add branch pick label <!-- Add branch pick label that this PR should merge into --> --- be/src/olap/comparison_predicate.h | 15 +- .../exec/streaming_aggregation_operator.cpp | 301 ++++++++++++++++++--- .../pipeline/exec/streaming_aggregation_operator.h | 78 +++++- be/src/vec/exec/scan/scanner.cpp | 35 ++- be/src/vec/exec/scan/scanner.h | 12 + be/src/vec/exprs/vcompound_pred.h | 67 +++-- .../operator/streaming_agg_operator_test.cpp | 20 +- .../glue/translator/PhysicalPlanTranslator.java | 3 +- .../nereids_tpch_p0/tpch/push_topn_to_agg.groovy | 5 +- 9 files changed, 454 insertions(+), 82 deletions(-) diff --git a/be/src/olap/comparison_predicate.h b/be/src/olap/comparison_predicate.h index 6992112b63f..1ef691b1283 100644 --- a/be/src/olap/comparison_predicate.h +++ b/be/src/olap/comparison_predicate.h @@ -380,8 +380,8 @@ public: } template <bool is_and> - void __attribute__((flatten)) - _evaluate_vec_internal(const vectorized::IColumn& column, uint16_t size, bool* flags) const { + void __attribute__((flatten)) _evaluate_vec_internal(const vectorized::IColumn& column, + uint16_t size, bool* flags) const { uint16_t current_evaluated_rows = 0; uint16_t current_passed_rows = 0; if (_can_ignore()) { @@ -579,9 +579,10 @@ private: } template <bool is_nullable, bool is_and, typename TArray, typename TValue> - void __attribute__((flatten)) - _base_loop_vec(uint16_t size, bool* __restrict bflags, const uint8_t* __restrict null_map, - const TArray* __restrict data_array, const TValue& value) const { + void __attribute__((flatten)) _base_loop_vec(uint16_t size, bool* __restrict bflags, + const uint8_t* __restrict null_map, + const TArray* __restrict data_array, + const TValue& value) const { //uint8_t helps compiler to generate vectorized code auto* flags = reinterpret_cast<uint8_t*>(bflags); if constexpr (is_and) { @@ -696,8 +697,8 @@ private: } } - int32_t __attribute__((flatten)) - _find_code_from_dictionary_column(const vectorized::ColumnDictI32& column) const { + int32_t __attribute__((flatten)) _find_code_from_dictionary_column( + const vectorized::ColumnDictI32& column) const { static_assert(is_string_type(Type), "Only string type predicate can use dictionary column."); int32_t code = 0; diff --git a/be/src/pipeline/exec/streaming_aggregation_operator.cpp b/be/src/pipeline/exec/streaming_aggregation_operator.cpp index 6c0506e600c..6c0c412f819 100644 --- a/be/src/pipeline/exec/streaming_aggregation_operator.cpp +++ b/be/src/pipeline/exec/streaming_aggregation_operator.cpp @@ -99,6 +99,8 @@ Status StreamingAggLocalState::init(RuntimeState* state, LocalStateInfo& info) { _insert_values_to_column_timer = ADD_TIMER(Base::custom_profile(), "InsertValuesToColumnTime"); _deserialize_data_timer = ADD_TIMER(Base::custom_profile(), "DeserializeAndMergeTime"); _hash_table_compute_timer = ADD_TIMER(Base::custom_profile(), "HashTableComputeTime"); + _hash_table_limit_compute_timer = + ADD_TIMER(Base::custom_profile(), "HashTableLimitComputeTime"); _hash_table_emplace_timer = ADD_TIMER(Base::custom_profile(), "HashTableEmplaceTime"); _hash_table_input_counter = ADD_COUNTER(Base::custom_profile(), "HashTableInputCount", TUnit::UNIT); @@ -152,16 +154,10 @@ Status StreamingAggLocalState::open(RuntimeState* state) { }}, _agg_data->method_variant); - if (p._is_merge || p._needs_finalize) { - return Status::InvalidArgument( - "StreamingAggLocalState only support no merge and no finalize, " - "but got is_merge={}, needs_finalize={}", - p._is_merge, p._needs_finalize); - } - - _should_limit_output = p._limit != -1 && // has limit - (!p._have_conjuncts) && // no having conjunct - p._needs_finalize; // agg's finalize step + limit = p._sort_limit; + do_sort_limit = p._do_sort_limit; + null_directions = p._null_directions; + order_directions = p._order_directions; return Status::OK(); } @@ -316,23 +312,22 @@ bool StreamingAggLocalState::_should_not_do_pre_agg(size_t rows) { const auto spill_streaming_agg_mem_limit = p._spill_streaming_agg_mem_limit; const bool used_too_much_memory = spill_streaming_agg_mem_limit > 0 && _memory_usage() > spill_streaming_agg_mem_limit; - std::visit( - vectorized::Overload { - [&](std::monostate& arg) { - throw doris::Exception(ErrorCode::INTERNAL_ERROR, "uninited hash table"); - }, - [&](auto& agg_method) { - auto& hash_tbl = *agg_method.hash_table; - /// If too much memory is used during the pre-aggregation stage, - /// it is better to output the data directly without performing further aggregation. - // do not try to do agg, just init and serialize directly return the out_block - if (used_too_much_memory || (hash_tbl.add_elem_size_overflow(rows) && - !_should_expand_preagg_hash_tables())) { - SCOPED_TIMER(_streaming_agg_timer); - ret_flag = true; - } - }}, - _agg_data->method_variant); + std::visit(vectorized::Overload { + [&](std::monostate& arg) { + throw doris::Exception(ErrorCode::INTERNAL_ERROR, "uninited hash table"); + }, + [&](auto& agg_method) { + auto& hash_tbl = *agg_method.hash_table; + /// If too much memory is used during the pre-aggregation stage, + /// it is better to output the data directly without performing further aggregation. + // do not try to do agg, just init and serialize directly return the out_block + if (used_too_much_memory || (hash_tbl.add_elem_size_overflow(rows) && + !_should_expand_preagg_hash_tables())) { + SCOPED_TIMER(_streaming_agg_timer); + ret_flag = true; + } + }}, + _agg_data->method_variant); return ret_flag; } @@ -363,6 +358,30 @@ Status StreamingAggLocalState::_pre_agg_with_serialized_key(doris::vectorized::B _places.resize(rows); if (_should_not_do_pre_agg(rows)) { + if (limit > 0) { + DCHECK(do_sort_limit); + if (need_do_sort_limit == -1) { + const size_t hash_table_size = _get_hash_table_size(); + need_do_sort_limit = hash_table_size >= limit ? 1 : 0; + if (need_do_sort_limit == 1) { + build_limit_heap(hash_table_size); + } + } + + if (need_do_sort_limit == 1) { + if (_do_limit_filter(rows, key_columns)) { + bool need_filter = std::find(need_computes.begin(), need_computes.end(), 1) != + need_computes.end(); + if (need_filter) { + _add_limit_heap_top(key_columns, rows); + vectorized::Block::filter_block_internal(in_block, need_computes); + rows = (uint32_t)in_block->rows(); + } else { + return Status::OK(); + } + } + } + } bool mem_reuse = p._make_nullable_keys.empty() && out_block->mem_reuse(); std::vector<vectorized::DataTypePtr> data_types; @@ -404,12 +423,23 @@ Status StreamingAggLocalState::_pre_agg_with_serialized_key(doris::vectorized::B } } } else { - _emplace_into_hash_table(_places.data(), key_columns, rows); + bool need_agg = true; + if (need_do_sort_limit != 1) { + _emplace_into_hash_table(_places.data(), key_columns, rows); + } else { + need_agg = _emplace_into_hash_table_limit(_places.data(), in_block, key_columns, rows); + } - for (int i = 0; i < _aggregate_evaluators.size(); ++i) { - RETURN_IF_ERROR(_aggregate_evaluators[i]->execute_batch_add( - in_block, p._offsets_of_aggregate_states[i], _places.data(), _agg_arena_pool, - _should_expand_hash_table)); + if (need_agg) { + for (int i = 0; i < _aggregate_evaluators.size(); ++i) { + RETURN_IF_ERROR(_aggregate_evaluators[i]->execute_batch_add( + in_block, p._offsets_of_aggregate_states[i], _places.data(), + _agg_arena_pool, _should_expand_hash_table)); + } + if (limit > 0 && need_do_sort_limit == -1 && _get_hash_table_size() >= limit) { + need_do_sort_limit = 1; + build_limit_heap(_get_hash_table_size()); + } } } @@ -561,6 +591,183 @@ void StreamingAggLocalState::_destroy_agg_status(vectorized::AggregateDataPtr da } } +vectorized::MutableColumns StreamingAggLocalState::_get_keys_hash_table() { + return std::visit( + vectorized::Overload { + [&](std::monostate& arg) { + throw doris::Exception(ErrorCode::INTERNAL_ERROR, "uninited hash table"); + return vectorized::MutableColumns(); + }, + [&](auto&& agg_method) -> vectorized::MutableColumns { + vectorized::MutableColumns key_columns; + for (int i = 0; i < _probe_expr_ctxs.size(); ++i) { + key_columns.emplace_back( + _probe_expr_ctxs[i]->root()->data_type()->create_column()); + } + auto& data = *agg_method.hash_table; + bool has_null_key = data.has_null_key_data(); + const auto size = data.size() - has_null_key; + using KeyType = std::decay_t<decltype(agg_method)>::Key; + std::vector<KeyType> keys(size); + + uint32_t num_rows = 0; + auto iter = _aggregate_data_container->begin(); + { + while (iter != _aggregate_data_container->end()) { + keys[num_rows] = iter.get_key<KeyType>(); + ++iter; + ++num_rows; + } + } + agg_method.insert_keys_into_columns(keys, key_columns, num_rows); + if (has_null_key) { + key_columns[0]->insert_data(nullptr, 0); + } + return key_columns; + }}, + _agg_data->method_variant); +} + +void StreamingAggLocalState::build_limit_heap(size_t hash_table_size) { + limit_columns = _get_keys_hash_table(); + for (size_t i = 0; i < hash_table_size; ++i) { + limit_heap.emplace(i, limit_columns, order_directions, null_directions); + } + while (hash_table_size > limit) { + limit_heap.pop(); + hash_table_size--; + } + limit_columns_min = limit_heap.top()._row_id; +} + +void StreamingAggLocalState::_add_limit_heap_top(vectorized::ColumnRawPtrs& key_columns, + size_t rows) { + for (int i = 0; i < rows; ++i) { + if (cmp_res[i] == 1 && need_computes[i]) { + for (int j = 0; j < key_columns.size(); ++j) { + limit_columns[j]->insert_from(*key_columns[j], i); + } + limit_heap.emplace(limit_columns[0]->size() - 1, limit_columns, order_directions, + null_directions); + limit_heap.pop(); + limit_columns_min = limit_heap.top()._row_id; + break; + } + } +} + +void StreamingAggLocalState::_refresh_limit_heap(size_t i, vectorized::ColumnRawPtrs& key_columns) { + for (int j = 0; j < key_columns.size(); ++j) { + limit_columns[j]->insert_from(*key_columns[j], i); + } + limit_heap.emplace(limit_columns[0]->size() - 1, limit_columns, order_directions, + null_directions); + limit_heap.pop(); + limit_columns_min = limit_heap.top()._row_id; +} + +bool StreamingAggLocalState::_emplace_into_hash_table_limit(vectorized::AggregateDataPtr* places, + vectorized::Block* block, + vectorized::ColumnRawPtrs& key_columns, + uint32_t num_rows) { + return std::visit( + vectorized::Overload { + [&](std::monostate& arg) { + throw doris::Exception(ErrorCode::INTERNAL_ERROR, "uninited hash table"); + return true; + }, + [&](auto&& agg_method) -> bool { + SCOPED_TIMER(_hash_table_compute_timer); + using HashMethodType = std::decay_t<decltype(agg_method)>; + using AggState = typename HashMethodType::State; + + bool need_filter = _do_limit_filter(num_rows, key_columns); + if (auto need_agg = + std::find(need_computes.begin(), need_computes.end(), 1); + need_agg != need_computes.end()) { + if (need_filter) { + vectorized::Block::filter_block_internal(block, need_computes); + num_rows = (uint32_t)block->rows(); + } + + AggState state(key_columns); + agg_method.init_serialized_keys(key_columns, num_rows); + size_t i = 0; + + auto creator = [&](const auto& ctor, auto& key, auto& origin) { + try { + HashMethodType::try_presis_key_and_origin(key, origin, + _agg_arena_pool); + auto mapped = _aggregate_data_container->append_data(origin); + auto st = _create_agg_status(mapped); + if (!st) { + throw Exception(st.code(), st.to_string()); + } + ctor(key, mapped); + _refresh_limit_heap(i, key_columns); + } catch (...) { + // Exception-safety - if it can not allocate memory or create status, + // the destructors will not be called. + ctor(key, nullptr); + throw; + } + }; + + auto creator_for_null_key = [&](auto& mapped) { + mapped = _agg_arena_pool.aligned_alloc( + Base::_parent->template cast<StreamingAggOperatorX>() + ._total_size_of_aggregate_states, + Base::_parent->template cast<StreamingAggOperatorX>() + ._align_aggregate_states); + auto st = _create_agg_status(mapped); + if (!st) { + throw Exception(st.code(), st.to_string()); + } + _refresh_limit_heap(i, key_columns); + }; + + SCOPED_TIMER(_hash_table_emplace_timer); + for (i = 0; i < num_rows; ++i) { + places[i] = *agg_method.lazy_emplace(state, i, creator, + creator_for_null_key); + } + COUNTER_UPDATE(_hash_table_input_counter, num_rows); + return true; + } + return false; + }}, + _agg_data->method_variant); +} + +bool StreamingAggLocalState::_do_limit_filter(size_t num_rows, + vectorized::ColumnRawPtrs& key_columns) { + SCOPED_TIMER(_hash_table_limit_compute_timer); + if (num_rows) { + cmp_res.resize(num_rows); + need_computes.resize(num_rows); + memset(need_computes.data(), 0, need_computes.size()); + memset(cmp_res.data(), 0, cmp_res.size()); + + const auto key_size = null_directions.size(); + for (int i = 0; i < key_size; i++) { + key_columns[i]->compare_internal(limit_columns_min, *limit_columns[i], + null_directions[i], order_directions[i], cmp_res, + need_computes.data()); + } + + auto set_computes_arr = [](auto* __restrict res, auto* __restrict computes, size_t rows) { + for (size_t i = 0; i < rows; ++i) { + computes[i] = computes[i] == res[i]; + } + }; + set_computes_arr(cmp_res.data(), need_computes.data(), num_rows); + + return std::find(need_computes.begin(), need_computes.end(), 0) != need_computes.end(); + } + + return false; +} + void StreamingAggLocalState::_emplace_into_hash_table(vectorized::AggregateDataPtr* places, vectorized::ColumnRawPtrs& key_columns, const uint32_t num_rows) { @@ -615,7 +822,6 @@ StreamingAggOperatorX::StreamingAggOperatorX(ObjectPool* pool, int operator_id, _intermediate_tuple_id(tnode.agg_node.intermediate_tuple_id), _output_tuple_id(tnode.agg_node.output_tuple_id), _needs_finalize(tnode.agg_node.need_finalize), - _is_merge(false), _is_first_phase(tnode.agg_node.__isset.is_first_phase && tnode.agg_node.is_first_phase), _have_conjuncts(tnode.__isset.vconjunct && !tnode.vconjunct.nodes.empty()), _agg_fn_output_row_descriptor(descs, tnode.row_tuples, tnode.nullable_tuples) {} @@ -673,8 +879,33 @@ Status StreamingAggOperatorX::init(const TPlanNode& tnode, RuntimeState* state) } const auto& agg_functions = tnode.agg_node.aggregate_functions; - _is_merge = std::any_of(agg_functions.cbegin(), agg_functions.cend(), - [](const auto& e) { return e.nodes[0].agg_expr.is_merge_agg; }); + auto is_merge = std::any_of(agg_functions.cbegin(), agg_functions.cend(), + [](const auto& e) { return e.nodes[0].agg_expr.is_merge_agg; }); + if (is_merge || _needs_finalize) { + return Status::InvalidArgument( + "StreamingAggLocalState only support no merge and no finalize, " + "but got is_merge={}, needs_finalize={}", + is_merge, _needs_finalize); + } + + // Handle sort limit + if (tnode.agg_node.__isset.agg_sort_info_by_group_key) { + _sort_limit = _limit; + _limit = -1; + _do_sort_limit = true; + const auto& agg_sort_info = tnode.agg_node.agg_sort_info_by_group_key; + DCHECK_EQ(agg_sort_info.nulls_first.size(), agg_sort_info.is_asc_order.size()); + + const size_t order_by_key_size = agg_sort_info.is_asc_order.size(); + _order_directions.resize(order_by_key_size); + _null_directions.resize(order_by_key_size); + for (int i = 0; i < order_by_key_size; ++i) { + _order_directions[i] = agg_sort_info.is_asc_order[i] ? 1 : -1; + _null_directions[i] = + agg_sort_info.nulls_first[i] ? -_order_directions[i] : _order_directions[i]; + } + } + _op_name = "STREAMING_AGGREGATION_OPERATOR"; return Status::OK(); } diff --git a/be/src/pipeline/exec/streaming_aggregation_operator.h b/be/src/pipeline/exec/streaming_aggregation_operator.h index d5b09c7eb25..a9e8cc54ba8 100644 --- a/be/src/pipeline/exec/streaming_aggregation_operator.h +++ b/be/src/pipeline/exec/streaming_aggregation_operator.h @@ -48,6 +48,7 @@ public: Status do_pre_agg(RuntimeState* state, vectorized::Block* input_block, vectorized::Block* output_block); void make_nullable_output_key(vectorized::Block* block); + void build_limit_heap(size_t hash_table_size); private: friend class StreamingAggOperatorX; @@ -55,6 +56,10 @@ private: friend class StatefulOperatorX; size_t _memory_usage() const; + void _add_limit_heap_top(vectorized::ColumnRawPtrs& key_columns, size_t rows); + bool _do_limit_filter(size_t num_rows, vectorized::ColumnRawPtrs& key_columns); + void _refresh_limit_heap(size_t i, vectorized::ColumnRawPtrs& key_columns); + Status _pre_agg_with_serialized_key(doris::vectorized::Block* in_block, doris::vectorized::Block* out_block); bool _should_expand_preagg_hash_tables(); @@ -68,11 +73,15 @@ private: bool* eos); void _emplace_into_hash_table(vectorized::AggregateDataPtr* places, vectorized::ColumnRawPtrs& key_columns, const uint32_t num_rows); + bool _emplace_into_hash_table_limit(vectorized::AggregateDataPtr* places, + vectorized::Block* block, + vectorized::ColumnRawPtrs& key_columns, uint32_t num_rows); Status _create_agg_status(vectorized::AggregateDataPtr data); size_t _get_hash_table_size(); RuntimeProfile::Counter* _streaming_agg_timer = nullptr; RuntimeProfile::Counter* _hash_table_compute_timer = nullptr; + RuntimeProfile::Counter* _hash_table_limit_compute_timer = nullptr; RuntimeProfile::Counter* _hash_table_emplace_timer = nullptr; RuntimeProfile::Counter* _hash_table_input_counter = nullptr; RuntimeProfile::Counter* _build_timer = nullptr; @@ -95,10 +104,70 @@ private: // group by k1,k2 vectorized::VExprContextSPtrs _probe_expr_ctxs; std::unique_ptr<AggregateDataContainer> _aggregate_data_container = nullptr; - bool _should_limit_output = false; bool _reach_limit = false; size_t _input_num_rows = 0; + int64_t limit = -1; + int need_do_sort_limit = -1; + bool do_sort_limit = false; + vectorized::MutableColumns limit_columns; + int limit_columns_min = -1; + vectorized::PaddedPODArray<uint8_t> need_computes; + std::vector<uint8_t> cmp_res; + std::vector<int> order_directions; + std::vector<int> null_directions; + + struct HeapLimitCursor { + HeapLimitCursor(int row_id, vectorized::MutableColumns& limit_columns, + std::vector<int>& order_directions, std::vector<int>& null_directions) + : _row_id(row_id), + _limit_columns(limit_columns), + _order_directions(order_directions), + _null_directions(null_directions) {} + + HeapLimitCursor(const HeapLimitCursor& other) = default; + + HeapLimitCursor(HeapLimitCursor&& other) noexcept + : _row_id(other._row_id), + _limit_columns(other._limit_columns), + _order_directions(other._order_directions), + _null_directions(other._null_directions) {} + + HeapLimitCursor& operator=(const HeapLimitCursor& other) noexcept { + _row_id = other._row_id; + return *this; + } + + HeapLimitCursor& operator=(HeapLimitCursor&& other) noexcept { + _row_id = other._row_id; + return *this; + } + + bool operator<(const HeapLimitCursor& rhs) const { + for (int i = 0; i < _limit_columns.size(); ++i) { + const auto& _limit_column = _limit_columns[i]; + auto res = _limit_column->compare_at(_row_id, rhs._row_id, *_limit_column, + _null_directions[i]) * + _order_directions[i]; + if (res < 0) { + return true; + } else if (res > 0) { + return false; + } + } + return false; + } + + int _row_id; + vectorized::MutableColumns& _limit_columns; + std::vector<int>& _order_directions; + std::vector<int>& _null_directions; + }; + + std::priority_queue<HeapLimitCursor> limit_heap; + + vectorized::MutableColumns _get_keys_hash_table(); + vectorized::PODArray<vectorized::AggregateDataPtr> _places; std::vector<char> _deserialize_buffer; @@ -185,7 +254,6 @@ private: TupleId _output_tuple_id; TupleDescriptor* _output_tuple_desc = nullptr; bool _needs_finalize; - bool _is_merge; const bool _is_first_phase; size_t _align_aggregate_states = 1; /// The offset to the n-th aggregate function in a row of aggregate functions. @@ -202,6 +270,12 @@ private: std::vector<size_t> _make_nullable_keys; bool _have_conjuncts; RowDescriptor _agg_fn_output_row_descriptor; + // For sort limit + bool _do_sort_limit = false; + int64_t _sort_limit = -1; + std::vector<int> _order_directions; + std::vector<int> _null_directions; + std::vector<TExpr> _partition_exprs; }; diff --git a/be/src/vec/exec/scan/scanner.cpp b/be/src/vec/exec/scan/scanner.cpp index 483aa545eec..aad88f1b41b 100644 --- a/be/src/vec/exec/scan/scanner.cpp +++ b/be/src/vec/exec/scan/scanner.cpp @@ -79,8 +79,39 @@ Status Scanner::init(RuntimeState* state, const VExprContextSPtrs& conjuncts) { Status Scanner::get_block_after_projects(RuntimeState* state, vectorized::Block* block, bool* eos) { auto& row_descriptor = _local_state->_parent->row_descriptor(); if (_output_row_descriptor) { - _origin_block.clear_column_data(row_descriptor.num_materialized_slots()); - RETURN_IF_ERROR(get_block(state, &_origin_block, eos)); + if (_alreay_eos) { + *eos = true; + _padding_block.swap(_origin_block); + } else { + _origin_block.clear_column_data(row_descriptor.num_materialized_slots()); + const auto min_batch_size = std::max(state->batch_size() / 2, 1); + while (_padding_block.rows() < min_batch_size && !*eos) { + RETURN_IF_ERROR(get_block(state, &_origin_block, eos)); + if (_origin_block.rows() >= min_batch_size) { + break; + } + + if (_origin_block.rows() + _padding_block.rows() <= state->batch_size()) { + RETURN_IF_ERROR(_merge_padding_block()); + _origin_block.clear_column_data(row_descriptor.num_materialized_slots()); + } else { + if (_origin_block.rows() < _padding_block.rows()) { + _padding_block.swap(_origin_block); + } + break; + } + } + } + + // first output the origin block change eos = false, next time output padding block + // set the eos to true + if (*eos && !_padding_block.empty() && !_origin_block.empty()) { + _alreay_eos = true; + *eos = false; + } + if (_origin_block.empty() && !_padding_block.empty()) { + _padding_block.swap(_origin_block); + } return _do_projections(&_origin_block, block); } else { return get_block(state, block, eos); diff --git a/be/src/vec/exec/scan/scanner.h b/be/src/vec/exec/scan/scanner.h index dec0349c8fc..9840eac1fd8 100644 --- a/be/src/vec/exec/scan/scanner.h +++ b/be/src/vec/exec/scan/scanner.h @@ -107,6 +107,16 @@ protected: // Subclass should implement this to return data. virtual Status _get_block_impl(RuntimeState* state, Block* block, bool* eof) = 0; + Status _merge_padding_block() { + if (_padding_block.empty()) { + _padding_block.swap(_origin_block); + } else if (_origin_block.rows()) { + RETURN_IF_ERROR( + MutableBlock::build_mutable_block(&_padding_block).merge(_origin_block)); + } + return Status::OK(); + } + // Update the counters before closing this scanner virtual void _collect_profile_before_close(); @@ -209,6 +219,8 @@ protected: // Used in common subexpression elimination to compute intermediate results. std::vector<vectorized::VExprContextSPtrs> _intermediate_projections; vectorized::Block _origin_block; + vectorized::Block _padding_block; + bool _alreay_eos = false; VExprContextSPtrs _common_expr_ctxs_push_down; diff --git a/be/src/vec/exprs/vcompound_pred.h b/be/src/vec/exprs/vcompound_pred.h index c3925786261..2d39319cedc 100644 --- a/be/src/vec/exprs/vcompound_pred.h +++ b/be/src/vec/exprs/vcompound_pred.h @@ -240,15 +240,7 @@ public: result_column = std::move(col_res); } - if constexpr (is_and_op) { - for (size_t i = 0; i < size; ++i) { - lhs_data_column[i] &= rhs_data_column[i]; - } - } else { - for (size_t i = 0; i < size; ++i) { - lhs_data_column[i] |= rhs_data_column[i]; - } - } + do_not_null_pred<is_and_op>(lhs_data_column, rhs_data_column, size); }; auto vector_vector_null = [&]<bool is_and_op>() { auto col_res = ColumnUInt8::create(size); @@ -265,19 +257,9 @@ public: auto* __restrict lhs_data_column_tmp = lhs_data_column; auto* __restrict rhs_data_column_tmp = rhs_data_column; - if constexpr (is_and_op) { - for (size_t i = 0; i < size; ++i) { - res_nulls[i] = apply_and_null(lhs_data_column_tmp[i], lhs_null_map_tmp[i], - rhs_data_column_tmp[i], rhs_null_map_tmp[i]); - res_datas[i] = lhs_data_column_tmp[i] & rhs_data_column_tmp[i]; - } - } else { - for (size_t i = 0; i < size; ++i) { - res_nulls[i] = apply_or_null(lhs_data_column_tmp[i], lhs_null_map_tmp[i], - rhs_data_column_tmp[i], rhs_null_map_tmp[i]); - res_datas[i] = lhs_data_column_tmp[i] | rhs_data_column_tmp[i]; - } - } + do_null_pred<is_and_op>(lhs_data_column_tmp, lhs_null_map_tmp, rhs_data_column_tmp, + rhs_null_map_tmp, res_datas, res_nulls, size); + result_column = ColumnNullable::create(std::move(col_res), std::move(col_nulls)); }; @@ -358,6 +340,47 @@ private: return (l_null & r_null) | (r_null & (r_null ^ a)) | (l_null & (l_null ^ b)); } + template <bool is_and> + void static do_not_null_pred(uint8_t* __restrict lhs, uint8_t* __restrict rhs, size_t size) { +#ifdef NDEBUG +#if defined(__clang__) +#pragma clang loop vectorize(enable) +#elif defined(__GNUC__) && (__GNUC__ >= 5) +#pragma GCC ivdep +#endif +#endif + for (size_t i = 0; i < size; ++i) { + if constexpr (is_and) { + lhs[i] &= rhs[i]; + } else { + lhs[i] |= rhs[i]; + } + } + } + + template <bool is_and> + void static do_null_pred(uint8_t* __restrict lhs_data, uint8_t* __restrict lhs_null, + uint8_t* __restrict rhs_data, uint8_t* __restrict rhs_null, + uint8_t* __restrict res_data, uint8_t* __restrict res_null, + size_t size) { +#ifdef NDEBUG +#if defined(__clang__) +#pragma clang loop vectorize(enable) +#elif defined(__GNUC__) && (__GNUC__ >= 5) +#pragma GCC ivdep +#endif +#endif + for (size_t i = 0; i < size; ++i) { + if constexpr (is_and) { + res_null[i] = apply_and_null(lhs_data[i], lhs_null[i], rhs_data[i], rhs_null[i]); + res_data[i] = lhs_data[i] & rhs_data[i]; + } else { + res_null[i] = apply_or_null(lhs_data[i], lhs_null[i], rhs_data[i], rhs_null[i]); + res_data[i] = lhs_data[i] | rhs_data[i]; + } + } + } + bool _has_const_child() const { return std::ranges::any_of(_children, [](const VExprSPtr& arg) -> bool { return arg->is_constant(); }); diff --git a/be/test/pipeline/operator/streaming_agg_operator_test.cpp b/be/test/pipeline/operator/streaming_agg_operator_test.cpp index 91ca56572be..664984db34a 100644 --- a/be/test/pipeline/operator/streaming_agg_operator_test.cpp +++ b/be/test/pipeline/operator/streaming_agg_operator_test.cpp @@ -109,7 +109,6 @@ TEST_F(StreamingAggOperatorTest, test1) { false)); op->_pool = &pool; op->_needs_finalize = false; - op->_is_merge = false; EXPECT_TRUE(op->set_child(child_op)); @@ -157,7 +156,9 @@ TEST_F(StreamingAggOperatorTest, test1) { EXPECT_TRUE(op->need_more_input_data(state.get())); } - { EXPECT_TRUE(local_state->close(state.get()).ok()); } + { + EXPECT_TRUE(local_state->close(state.get()).ok()); + } } TEST_F(StreamingAggOperatorTest, test2) { @@ -166,7 +167,6 @@ TEST_F(StreamingAggOperatorTest, test2) { false)); op->_pool = &pool; op->_needs_finalize = false; - op->_is_merge = false; EXPECT_TRUE(op->set_child(child_op)); @@ -234,7 +234,9 @@ TEST_F(StreamingAggOperatorTest, test2) { EXPECT_EQ(block.rows(), 3); } - { EXPECT_TRUE(local_state->close(state.get()).ok()); } + { + EXPECT_TRUE(local_state->close(state.get()).ok()); + } } TEST_F(StreamingAggOperatorTest, test3) { @@ -243,7 +245,6 @@ TEST_F(StreamingAggOperatorTest, test3) { false)); op->_pool = &pool; op->_needs_finalize = false; - op->_is_merge = false; EXPECT_TRUE(op->set_child(child_op)); @@ -314,7 +315,9 @@ TEST_F(StreamingAggOperatorTest, test3) { EXPECT_EQ(block.rows(), 3); } - { EXPECT_TRUE(local_state->close(state.get()).ok()); } + { + EXPECT_TRUE(local_state->close(state.get()).ok()); + } } TEST_F(StreamingAggOperatorTest, test4) { @@ -323,7 +326,6 @@ TEST_F(StreamingAggOperatorTest, test4) { std::make_shared<DataTypeBitMap>(), false)); op->_pool = &pool; op->_needs_finalize = false; - op->_is_merge = false; EXPECT_TRUE(op->set_child(child_op)); @@ -406,7 +408,9 @@ TEST_F(StreamingAggOperatorTest, test4) { // << "Expected: " << res_block.dump_data() << ", but got: " << block.dump_data(); } - { EXPECT_TRUE(local_state->close(state.get()).ok()); } + { + EXPECT_TRUE(local_state->close(state.get()).ok()); + } } } // namespace doris::pipeline diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java index cf40d270365..33d15407902 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java @@ -346,8 +346,7 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla if (upstreamFragment.getPlanRoot() instanceof AggregationNode && upstream instanceof PhysicalHashAggregate) { PhysicalHashAggregate<?> hashAggregate = (PhysicalHashAggregate<?>) upstream; if (hashAggregate.getAggPhase() == AggPhase.LOCAL - && hashAggregate.getAggMode() == AggMode.INPUT_TO_BUFFER - && hashAggregate.getTopnPushInfo() == null) { + && hashAggregate.getAggMode() == AggMode.INPUT_TO_BUFFER) { AggregationNode aggregationNode = (AggregationNode) upstreamFragment.getPlanRoot(); aggregationNode.setUseStreamingPreagg(hashAggregate.isMaybeUsingStream()); } diff --git a/regression-test/suites/nereids_tpch_p0/tpch/push_topn_to_agg.groovy b/regression-test/suites/nereids_tpch_p0/tpch/push_topn_to_agg.groovy index 06975eef5ea..5e694b4781d 100644 --- a/regression-test/suites/nereids_tpch_p0/tpch/push_topn_to_agg.groovy +++ b/regression-test/suites/nereids_tpch_p0/tpch/push_topn_to_agg.groovy @@ -32,7 +32,6 @@ suite("push_topn_to_agg") { explain{ sql "select o_custkey, sum(o_shippriority) from orders group by o_custkey limit 4;" multiContains ("sortByGroupKey:true", 2) - notContains("STREAMING") } // when apply this opt, trun off STREAMING @@ -40,14 +39,12 @@ suite("push_topn_to_agg") { explain{ sql "select sum(c_custkey), c_name from customer group by c_name limit 6;" multiContains ("sortByGroupKey:true", 2) - notContains("STREAMING") } // topn -> agg explain{ sql "select o_custkey, sum(o_shippriority) from orders group by o_custkey order by o_custkey limit 8;" multiContains ("sortByGroupKey:true", 2) - notContains("STREAMING") } // order keys are part of group keys, @@ -185,4 +182,4 @@ suite("push_topn_to_agg") { | planed with unknown column statistics | +--------------------------------------------------------------------------------+ **/ -} \ No newline at end of file +} --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
