This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new d8f911e0746 [Fix]set workload group for load channel (#40749)
d8f911e0746 is described below

commit d8f911e0746048f0dbbfd23a68dcc16bda150614
Author: wangbo <wan...@apache.org>
AuthorDate: Thu Sep 12 23:13:25 2024 +0800

    [Fix]set workload group for load channel (#40749)
    
    ## Proposed changes
    Set the workload group for the load channel so that its memory is managed
    by the workload group.
---
 be/src/runtime/load_channel.cpp           | 36 ++++++++++++++++++++++++-------
 be/src/runtime/load_channel.h             |  3 ++-
 be/src/runtime/load_channel_mgr.cpp       |  6 +++++-
 be/src/runtime/thread_context.h           |  4 ++++
 be/src/vec/sink/writer/vtablet_writer.cpp | 15 +++++++++++++
 be/src/vec/sink/writer/vtablet_writer.h   |  2 ++
 gensrc/proto/internal_service.proto       |  1 +
 7 files changed, 57 insertions(+), 10 deletions(-)

diff --git a/be/src/runtime/load_channel.cpp b/be/src/runtime/load_channel.cpp
index 99f0a0b3d5b..f8c11639719 100644
--- a/be/src/runtime/load_channel.cpp
+++ b/be/src/runtime/load_channel.cpp
@@ -36,7 +36,8 @@ namespace doris {
 bvar::Adder<int64_t> g_loadchannel_cnt("loadchannel_cnt");
 
 LoadChannel::LoadChannel(const UniqueId& load_id, int64_t timeout_s, bool 
is_high_priority,
-                         std::string sender_ip, int64_t backend_id, bool 
enable_profile)
+                         std::string sender_ip, int64_t backend_id, bool 
enable_profile,
+                         int64_t wg_id)
         : _load_id(load_id),
           _timeout_s(timeout_s),
           _is_high_priority(is_high_priority),
@@ -46,16 +47,29 @@ LoadChannel::LoadChannel(const UniqueId& load_id, int64_t 
timeout_s, bool is_hig
     std::shared_ptr<QueryContext> query_context =
             
ExecEnv::GetInstance()->fragment_mgr()->get_or_erase_query_ctx_with_lock(
                     _load_id.to_thrift());
+    std::shared_ptr<MemTrackerLimiter> mem_tracker = nullptr;
+    WorkloadGroupPtr wg_ptr = nullptr;
+
     if (query_context != nullptr) {
-        _query_thread_context = {_load_id.to_thrift(), 
query_context->query_mem_tracker,
-                                 query_context->workload_group()};
+        mem_tracker = query_context->query_mem_tracker;
+        wg_ptr = query_context->workload_group();
     } else {
-        _query_thread_context = {
-                _load_id.to_thrift(),
-                MemTrackerLimiter::create_shared(
-                        MemTrackerLimiter::Type::LOAD,
-                        fmt::format("(FromLoadChannel)Load#Id={}", 
_load_id.to_string()))};
+        // when memtable on sink is not enabled, load can not find queryctx
+        mem_tracker = MemTrackerLimiter::create_shared(
+                MemTrackerLimiter::Type::LOAD,
+                fmt::format("(FromLoadChannel)Load#Id={}", 
_load_id.to_string()));
+        if (wg_id > 0) {
+            WorkloadGroupPtr workload_group_ptr =
+                    
ExecEnv::GetInstance()->workload_group_mgr()->get_task_group_by_id(wg_id);
+            if (workload_group_ptr) {
+                wg_ptr = workload_group_ptr;
+                wg_ptr->add_mem_tracker_limiter(mem_tracker);
+                _need_release_memtracker = true;
+            }
+        }
     }
+    _query_thread_context = {_load_id.to_thrift(), mem_tracker, wg_ptr};
+
     g_loadchannel_cnt << 1;
     // _last_updated_time should be set before being inserted to
     // _load_channels in load_channel_mgr, or it may be erased
@@ -71,6 +85,12 @@ LoadChannel::~LoadChannel() {
         rows_str << ", index id: " << entry.first << ", total_received_rows: " 
<< entry.second.first
                  << ", num_rows_filtered: " << entry.second.second;
     }
+    if (_need_release_memtracker) {
+        WorkloadGroupPtr wg_ptr = 
_query_thread_context.get_workload_group_ptr();
+        if (wg_ptr) {
+            
wg_ptr->remove_mem_tracker_limiter(_query_thread_context.get_memory_tracker());
+        }
+    }
     LOG(INFO) << "load channel removed"
               << " load_id=" << _load_id << ", is high priority=" << 
_is_high_priority
               << ", sender_ip=" << _sender_ip << rows_str.str();
diff --git a/be/src/runtime/load_channel.h b/be/src/runtime/load_channel.h
index 791e996574a..6fad8c536ec 100644
--- a/be/src/runtime/load_channel.h
+++ b/be/src/runtime/load_channel.h
@@ -46,7 +46,7 @@ class BaseTabletsChannel;
 class LoadChannel {
 public:
     LoadChannel(const UniqueId& load_id, int64_t timeout_s, bool 
is_high_priority,
-                std::string sender_ip, int64_t backend_id, bool 
enable_profile);
+                std::string sender_ip, int64_t backend_id, bool 
enable_profile, int64_t wg_id);
     ~LoadChannel();
 
     // open a new load channel if not exist
@@ -127,6 +127,7 @@ private:
     int64_t _backend_id;
 
     bool _enable_profile;
+    bool _need_release_memtracker = false;
 };
 
 inline std::ostream& operator<<(std::ostream& os, LoadChannel& load_channel) {
diff --git a/be/src/runtime/load_channel_mgr.cpp 
b/be/src/runtime/load_channel_mgr.cpp
index d31ce1d9a7e..c53cade466b 100644
--- a/be/src/runtime/load_channel_mgr.cpp
+++ b/be/src/runtime/load_channel_mgr.cpp
@@ -94,9 +94,13 @@ Status LoadChannelMgr::open(const PTabletWriterOpenRequest& 
params) {
             int64_t channel_timeout_s = 
calc_channel_timeout_s(timeout_in_req_s);
             bool is_high_priority = (params.has_is_high_priority() && 
params.is_high_priority());
 
+            int64_t wg_id = -1;
+            if (params.has_workload_group_id()) {
+                wg_id = params.workload_group_id();
+            }
             channel.reset(new LoadChannel(load_id, channel_timeout_s, 
is_high_priority,
                                           params.sender_ip(), 
params.backend_id(),
-                                          params.enable_profile()));
+                                          params.enable_profile(), wg_id));
             _load_channels.insert({load_id, channel});
         }
     }
diff --git a/be/src/runtime/thread_context.h b/be/src/runtime/thread_context.h
index ea842c12028..19ebffa9354 100644
--- a/be/src/runtime/thread_context.h
+++ b/be/src/runtime/thread_context.h
@@ -402,6 +402,10 @@ public:
 #endif
     }
 
+    std::shared_ptr<MemTrackerLimiter> get_memory_tracker() { return 
query_mem_tracker; }
+
+    WorkloadGroupPtr get_workload_group_ptr() { return wg_wptr.lock(); }
+
     TUniqueId query_id;
     std::shared_ptr<MemTrackerLimiter> query_mem_tracker;
     std::weak_ptr<WorkloadGroup> wg_wptr;
diff --git a/be/src/vec/sink/writer/vtablet_writer.cpp 
b/be/src/vec/sink/writer/vtablet_writer.cpp
index b9eaf79616f..2aa16ae498f 100644
--- a/be/src/vec/sink/writer/vtablet_writer.cpp
+++ b/be/src/vec/sink/writer/vtablet_writer.cpp
@@ -64,6 +64,7 @@
 #include "runtime/descriptors.h"
 #include "runtime/exec_env.h"
 #include "runtime/memory/memory_reclamation.h"
+#include "runtime/query_context.h"
 #include "runtime/runtime_state.h"
 #include "runtime/thread_context.h"
 #include "service/backend_options.h"
@@ -383,6 +384,16 @@ Status VNodeChannel::init(RuntimeState* state) {
     // a relatively large value to improve the import performance.
     _batch_size = std::max(_batch_size, 8192);
 
+    if (_state) {
+        QueryContext* query_ctx = _state->get_query_ctx();
+        if (query_ctx) {
+            auto wg_ptr = query_ctx->workload_group();
+            if (wg_ptr) {
+                _wg_id = wg_ptr->id();
+            }
+        }
+    }
+
     _inited = true;
     return Status::OK();
 }
@@ -426,6 +437,10 @@ void VNodeChannel::_open_internal(bool is_incremental) {
     request->set_txn_expiration(_parent->_txn_expiration);
     request->set_write_file_cache(_parent->_write_file_cache);
 
+    if (_wg_id > 0) {
+        request->set_workload_group_id(_wg_id);
+    }
+
     auto open_callback = 
DummyBrpcCallback<PTabletWriterOpenResult>::create_shared();
     auto open_closure = AutoReleaseClosure<
             PTabletWriterOpenRequest,
diff --git a/be/src/vec/sink/writer/vtablet_writer.h 
b/be/src/vec/sink/writer/vtablet_writer.h
index e7a89824ba3..52aa0f6b918 100644
--- a/be/src/vec/sink/writer/vtablet_writer.h
+++ b/be/src/vec/sink/writer/vtablet_writer.h
@@ -413,6 +413,8 @@ protected:
     // send block to slave BE rely on this. dont reconstruct it.
     std::shared_ptr<WriteBlockCallback<PTabletWriterAddBlockResult>> 
_send_block_callback = nullptr;
 
+    int64_t _wg_id = -1;
+
     bool _is_incremental;
 };
 
diff --git a/gensrc/proto/internal_service.proto 
b/gensrc/proto/internal_service.proto
index 4ac4fd24f3b..0d727f7f1c0 100644
--- a/gensrc/proto/internal_service.proto
+++ b/gensrc/proto/internal_service.proto
@@ -107,6 +107,7 @@ message PTabletWriterOpenRequest {
     optional bool write_file_cache = 17;
     optional string storage_vault_id = 18;
     optional int32 sender_id = 19;
+    optional int64 workload_group_id = 20;
 };
 
 message PTabletWriterOpenResult {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to