This is an automated email from the ASF dual-hosted git repository. yiguolei pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new d8f911e0746 [Fix]set workload group for load channel (#40749) d8f911e0746 is described below commit d8f911e0746048f0dbbfd23a68dcc16bda150614 Author: wangbo <wan...@apache.org> AuthorDate: Thu Sep 12 23:13:25 2024 +0800 [Fix]set workload group for load channel (#40749) ## Proposed changes Set workload group for load channel, make its memory managed by Workload Group. --- be/src/runtime/load_channel.cpp | 36 ++++++++++++++++++++++++------- be/src/runtime/load_channel.h | 3 ++- be/src/runtime/load_channel_mgr.cpp | 6 +++++- be/src/runtime/thread_context.h | 4 ++++ be/src/vec/sink/writer/vtablet_writer.cpp | 15 +++++++++++++ be/src/vec/sink/writer/vtablet_writer.h | 2 ++ gensrc/proto/internal_service.proto | 1 + 7 files changed, 57 insertions(+), 10 deletions(-) diff --git a/be/src/runtime/load_channel.cpp b/be/src/runtime/load_channel.cpp index 99f0a0b3d5b..f8c11639719 100644 --- a/be/src/runtime/load_channel.cpp +++ b/be/src/runtime/load_channel.cpp @@ -36,7 +36,8 @@ namespace doris { bvar::Adder<int64_t> g_loadchannel_cnt("loadchannel_cnt"); LoadChannel::LoadChannel(const UniqueId& load_id, int64_t timeout_s, bool is_high_priority, - std::string sender_ip, int64_t backend_id, bool enable_profile) + std::string sender_ip, int64_t backend_id, bool enable_profile, + int64_t wg_id) : _load_id(load_id), _timeout_s(timeout_s), _is_high_priority(is_high_priority), @@ -46,16 +47,29 @@ LoadChannel::LoadChannel(const UniqueId& load_id, int64_t timeout_s, bool is_hig std::shared_ptr<QueryContext> query_context = ExecEnv::GetInstance()->fragment_mgr()->get_or_erase_query_ctx_with_lock( _load_id.to_thrift()); + std::shared_ptr<MemTrackerLimiter> mem_tracker = nullptr; + WorkloadGroupPtr wg_ptr = nullptr; + if (query_context != nullptr) { - _query_thread_context = {_load_id.to_thrift(), query_context->query_mem_tracker, - query_context->workload_group()}; + mem_tracker = query_context->query_mem_tracker; + wg_ptr = query_context->workload_group(); } else { - _query_thread_context = { - _load_id.to_thrift(), - MemTrackerLimiter::create_shared( - MemTrackerLimiter::Type::LOAD, - fmt::format("(FromLoadChannel)Load#Id={}", _load_id.to_string()))}; + // when memtable on sink is not enabled, load can not find queryctx + mem_tracker = MemTrackerLimiter::create_shared( + MemTrackerLimiter::Type::LOAD, + fmt::format("(FromLoadChannel)Load#Id={}", _load_id.to_string())); + if (wg_id > 0) { + WorkloadGroupPtr workload_group_ptr = + ExecEnv::GetInstance()->workload_group_mgr()->get_task_group_by_id(wg_id); + if (workload_group_ptr) { + wg_ptr = workload_group_ptr; + wg_ptr->add_mem_tracker_limiter(mem_tracker); + _need_release_memtracker = true; + } + } } + _query_thread_context = {_load_id.to_thrift(), mem_tracker, wg_ptr}; + g_loadchannel_cnt << 1; // _last_updated_time should be set before being inserted to // _load_channels in load_channel_mgr, or it may be erased @@ -71,6 +85,12 @@ LoadChannel::~LoadChannel() { rows_str << ", index id: " << entry.first << ", total_received_rows: " << entry.second.first << ", num_rows_filtered: " << entry.second.second; } + if (_need_release_memtracker) { + WorkloadGroupPtr wg_ptr = _query_thread_context.get_workload_group_ptr(); + if (wg_ptr) { + wg_ptr->remove_mem_tracker_limiter(_query_thread_context.get_memory_tracker()); + } + } LOG(INFO) << "load channel removed" << " load_id=" << _load_id << ", is high priority=" << _is_high_priority << ", sender_ip=" << _sender_ip << rows_str.str(); diff --git a/be/src/runtime/load_channel.h b/be/src/runtime/load_channel.h index 791e996574a..6fad8c536ec 100644 --- a/be/src/runtime/load_channel.h +++ b/be/src/runtime/load_channel.h @@ -46,7 +46,7 @@ class BaseTabletsChannel; class LoadChannel { public: LoadChannel(const UniqueId& load_id, int64_t timeout_s, bool is_high_priority, - std::string sender_ip, int64_t backend_id, bool enable_profile); + std::string sender_ip, int64_t backend_id, bool enable_profile, int64_t wg_id); ~LoadChannel(); // open a new load channel if not exist @@ -127,6 +127,7 @@ private: int64_t _backend_id; bool _enable_profile; + bool _need_release_memtracker = false; }; inline std::ostream& operator<<(std::ostream& os, LoadChannel& load_channel) { diff --git a/be/src/runtime/load_channel_mgr.cpp b/be/src/runtime/load_channel_mgr.cpp index d31ce1d9a7e..c53cade466b 100644 --- a/be/src/runtime/load_channel_mgr.cpp +++ b/be/src/runtime/load_channel_mgr.cpp @@ -94,9 +94,13 @@ Status LoadChannelMgr::open(const PTabletWriterOpenRequest& params) { int64_t channel_timeout_s = calc_channel_timeout_s(timeout_in_req_s); bool is_high_priority = (params.has_is_high_priority() && params.is_high_priority()); + int64_t wg_id = -1; + if (params.has_workload_group_id()) { + wg_id = params.workload_group_id(); + } channel.reset(new LoadChannel(load_id, channel_timeout_s, is_high_priority, params.sender_ip(), params.backend_id(), - params.enable_profile())); + params.enable_profile(), wg_id)); _load_channels.insert({load_id, channel}); } } diff --git a/be/src/runtime/thread_context.h b/be/src/runtime/thread_context.h index ea842c12028..19ebffa9354 100644 --- a/be/src/runtime/thread_context.h +++ b/be/src/runtime/thread_context.h @@ -402,6 +402,10 @@ public: #endif } + std::shared_ptr<MemTrackerLimiter> get_memory_tracker() { return query_mem_tracker; } + + WorkloadGroupPtr get_workload_group_ptr() { return wg_wptr.lock(); } + TUniqueId query_id; std::shared_ptr<MemTrackerLimiter> query_mem_tracker; std::weak_ptr<WorkloadGroup> wg_wptr; diff --git a/be/src/vec/sink/writer/vtablet_writer.cpp b/be/src/vec/sink/writer/vtablet_writer.cpp index b9eaf79616f..2aa16ae498f 100644 --- a/be/src/vec/sink/writer/vtablet_writer.cpp +++ b/be/src/vec/sink/writer/vtablet_writer.cpp @@ -64,6 +64,7 @@ #include "runtime/descriptors.h" #include "runtime/exec_env.h" #include "runtime/memory/memory_reclamation.h" +#include "runtime/query_context.h" #include "runtime/runtime_state.h" #include "runtime/thread_context.h" #include "service/backend_options.h" @@ -383,6 +384,16 @@ Status VNodeChannel::init(RuntimeState* state) { // a relatively large value to improve the import performance. _batch_size = std::max(_batch_size, 8192); + if (_state) { + QueryContext* query_ctx = _state->get_query_ctx(); + if (query_ctx) { + auto wg_ptr = query_ctx->workload_group(); + if (wg_ptr) { + _wg_id = wg_ptr->id(); + } + } + } + _inited = true; return Status::OK(); } @@ -426,6 +437,10 @@ void VNodeChannel::_open_internal(bool is_incremental) { request->set_txn_expiration(_parent->_txn_expiration); request->set_write_file_cache(_parent->_write_file_cache); + if (_wg_id > 0) { + request->set_workload_group_id(_wg_id); + } + auto open_callback = DummyBrpcCallback<PTabletWriterOpenResult>::create_shared(); auto open_closure = AutoReleaseClosure< PTabletWriterOpenRequest, diff --git a/be/src/vec/sink/writer/vtablet_writer.h b/be/src/vec/sink/writer/vtablet_writer.h index e7a89824ba3..52aa0f6b918 100644 --- a/be/src/vec/sink/writer/vtablet_writer.h +++ b/be/src/vec/sink/writer/vtablet_writer.h @@ -413,6 +413,8 @@ protected: // send block to slave BE rely on this. dont reconstruct it. std::shared_ptr<WriteBlockCallback<PTabletWriterAddBlockResult>> _send_block_callback = nullptr; + int64_t _wg_id = -1; + bool _is_incremental; }; diff --git a/gensrc/proto/internal_service.proto b/gensrc/proto/internal_service.proto index 4ac4fd24f3b..0d727f7f1c0 100644 --- a/gensrc/proto/internal_service.proto +++ b/gensrc/proto/internal_service.proto @@ -107,6 +107,7 @@ message PTabletWriterOpenRequest { optional bool write_file_cache = 17; optional string storage_vault_id = 18; optional int32 sender_id = 19; + optional int64 workload_group_id = 20; }; message PTabletWriterOpenResult { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org