This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch branch-2.1
in repository https://gitbox.apache.org/repos/asf/doris.git

commit 657a29fd9e7af885b14df7b50dc7ee9d1442199d
Author: Gabriel <gabrielleeb...@gmail.com>
AuthorDate: Thu Apr 18 14:27:40 2024 +0800

    [refactor](partitioner) refine get channel id logics (#33765)
---
 be/src/pipeline/exec/exchange_sink_operator.cpp         |  4 ++--
 be/src/pipeline/exec/exchange_sink_operator.h           |  3 +--
 .../exec/partitioned_hash_join_probe_operator.cpp       |  2 +-
 .../exec/partitioned_hash_join_sink_operator.cpp        |  4 ++--
 .../pipeline_x/local_exchange/local_exchanger.cpp       |  2 +-
 be/src/vec/runtime/partitioner.h                        | 17 +++++++++++++++--
 be/src/vec/sink/vdata_stream_sender.cpp                 |  4 ++--
 7 files changed, 24 insertions(+), 12 deletions(-)

diff --git a/be/src/pipeline/exec/exchange_sink_operator.cpp b/be/src/pipeline/exec/exchange_sink_operator.cpp
index 2c37f24eac4..580e8e525d6 100644
--- a/be/src/pipeline/exec/exchange_sink_operator.cpp
+++ b/be/src/pipeline/exec/exchange_sink_operator.cpp
@@ -493,11 +493,11 @@ Status ExchangeSinkOperatorX::sink(RuntimeState* state, vectorized::Block* block
         if (_part_type == TPartitionType::HASH_PARTITIONED) {
             RETURN_IF_ERROR(channel_add_rows(
                     state, local_state.channels, local_state._partition_count,
-                    (uint32_t*)local_state._partitioner->get_channel_ids(), rows, block, eos));
+                    local_state._partitioner->get_channel_ids().get<uint32_t>(), rows, block, eos));
         } else {
             RETURN_IF_ERROR(channel_add_rows(
                     state, local_state.channel_shared_ptrs, local_state._partition_count,
-                    (uint32_t*)local_state._partitioner->get_channel_ids(), rows, block, eos));
+                    local_state._partitioner->get_channel_ids().get<uint32_t>(), rows, block, eos));
         }
     } else if (_part_type == TPartitionType::TABLET_SINK_SHUFFLE_PARTITIONED) {
         // check out of limit
diff --git a/be/src/pipeline/exec/exchange_sink_operator.h b/be/src/pipeline/exec/exchange_sink_operator.h
index 9c40242cd03..f275365c0e8 100644
--- a/be/src/pipeline/exec/exchange_sink_operator.h
+++ b/be/src/pipeline/exec/exchange_sink_operator.h
@@ -76,8 +76,7 @@ private:
                 : _partitioner(partitioner) {}
 
         int get_partition(vectorized::Block* block, int position) {
-            uint32_t* partition_ids = (uint32_t*)_partitioner->get_channel_ids();
-            return partition_ids[position];
+            return _partitioner->get_channel_ids().get<uint32_t>()[position];
         }
 
     private:
diff --git a/be/src/pipeline/exec/partitioned_hash_join_probe_operator.cpp b/be/src/pipeline/exec/partitioned_hash_join_probe_operator.cpp
index 78dcaf1e6c5..0f57a03fc64 100644
--- a/be/src/pipeline/exec/partitioned_hash_join_probe_operator.cpp
+++ b/be/src/pipeline/exec/partitioned_hash_join_probe_operator.cpp
@@ -535,7 +535,7 @@ Status PartitionedHashJoinProbeOperatorX::push(RuntimeState* state, vectorized::
     }
 
     std::vector<uint32_t> partition_indexes[_partition_count];
-    auto* channel_ids = reinterpret_cast<uint32_t*>(local_state._partitioner->get_channel_ids());
+    auto* channel_ids = local_state._partitioner->get_channel_ids().get<uint32_t>();
     for (uint32_t i = 0; i != rows; ++i) {
         partition_indexes[channel_ids[i]].emplace_back(i);
     }
diff --git a/be/src/pipeline/exec/partitioned_hash_join_sink_operator.cpp b/be/src/pipeline/exec/partitioned_hash_join_sink_operator.cpp
index c9d61757461..d0ca832630e 100644
--- a/be/src/pipeline/exec/partitioned_hash_join_sink_operator.cpp
+++ b/be/src/pipeline/exec/partitioned_hash_join_sink_operator.cpp
@@ -145,7 +145,7 @@ Status PartitionedHashJoinSinkLocalState::_revoke_unpartitioned_block(RuntimeSta
         }
         auto& p = _parent->cast<PartitionedHashJoinSinkOperatorX>();
         SCOPED_TIMER(_partition_shuffle_timer);
-        auto* channel_ids = reinterpret_cast<uint32_t*>(_partitioner->get_channel_ids());
+        auto* channel_ids = _partitioner->get_channel_ids().get<uint32_t>();
 
         auto& partitioned_blocks = _shared_state->partitioned_build_blocks;
         std::vector<uint32_t> partition_indices;
@@ -293,7 +293,7 @@ Status PartitionedHashJoinSinkLocalState::_partition_block(RuntimeState* state,
 
     auto& p = _parent->cast<PartitionedHashJoinSinkOperatorX>();
     SCOPED_TIMER(_partition_shuffle_timer);
-    auto* channel_ids = reinterpret_cast<uint32_t*>(_partitioner->get_channel_ids());
+    auto* channel_ids = _partitioner->get_channel_ids().get<uint32_t>();
     std::vector<uint32_t> partition_indexes[p._partition_count];
     DCHECK_LT(begin, end);
     for (size_t i = begin; i != end; ++i) {
diff --git a/be/src/pipeline/pipeline_x/local_exchange/local_exchanger.cpp b/be/src/pipeline/pipeline_x/local_exchange/local_exchanger.cpp
index da395fefdd5..0837a1212b9 100644
--- a/be/src/pipeline/pipeline_x/local_exchange/local_exchanger.cpp
+++ b/be/src/pipeline/pipeline_x/local_exchange/local_exchanger.cpp
@@ -32,7 +32,7 @@ Status ShuffleExchanger::sink(RuntimeState* state, vectorized::Block* in_block,
     {
         SCOPED_TIMER(local_state._distribute_timer);
         RETURN_IF_ERROR(_split_rows(state,
-                                    (const uint32_t*)local_state._partitioner->get_channel_ids(),
+                                    local_state._partitioner->get_channel_ids().get<uint32_t>(),
                                     in_block, eos, local_state));
     }
 
diff --git a/be/src/vec/runtime/partitioner.h b/be/src/vec/runtime/partitioner.h
index 66ed8809d7c..8d715a41285 100644
--- a/be/src/vec/runtime/partitioner.h
+++ b/be/src/vec/runtime/partitioner.h
@@ -26,6 +26,17 @@ class MemTracker;
 
 namespace vectorized {
 
+struct ChannelField {
+    const void* channel_id;
+    const uint32_t len;
+
+    template <typename T>
+    const T* get() const {
+        CHECK_EQ(sizeof(T), len) << " sizeof(T): " << sizeof(T) << " len: " << len;
+        return reinterpret_cast<const T*>(channel_id);
+    }
+};
+
 class PartitionerBase {
 public:
     PartitionerBase(size_t partition_count) : _partition_count(partition_count) {}
@@ -40,7 +51,7 @@ public:
     virtual Status do_partitioning(RuntimeState* state, Block* block,
                                    MemTracker* mem_tracker) const = 0;
 
-    virtual void* get_channel_ids() const = 0;
+    virtual ChannelField get_channel_ids() const = 0;
 
     virtual Status clone(RuntimeState* state, std::unique_ptr<PartitionerBase>& partitioner) = 0;
 
@@ -67,7 +78,9 @@ public:
     Status do_partitioning(RuntimeState* state, Block* block,
                            MemTracker* mem_tracker) const override;
 
-    void* get_channel_ids() const override { return _hash_vals.data(); }
+    ChannelField get_channel_ids() const override {
+        return {_hash_vals.data(), sizeof(HashValueType)};
+    }
 
 protected:
     Status _get_partition_column_result(Block* block, std::vector<int>& result) const {
diff --git a/be/src/vec/sink/vdata_stream_sender.cpp b/be/src/vec/sink/vdata_stream_sender.cpp
index ce6a5317fd4..69b7054f500 100644
--- a/be/src/vec/sink/vdata_stream_sender.cpp
+++ b/be/src/vec/sink/vdata_stream_sender.cpp
@@ -739,11 +739,11 @@ Status VDataStreamSender::send(RuntimeState* state, Block* block, bool eos) {
         }
         if (_part_type == TPartitionType::HASH_PARTITIONED) {
             RETURN_IF_ERROR(channel_add_rows(state, _channels, _partition_count,
-                                             (uint64_t*)_partitioner->get_channel_ids(), rows,
+                                             _partitioner->get_channel_ids().get<uint64_t>(), rows,
                                              block, _enable_pipeline_exec ? eos : false));
         } else {
             RETURN_IF_ERROR(channel_add_rows(state, _channel_shared_ptrs, _partition_count,
-                                             (uint32_t*)_partitioner->get_channel_ids(), rows,
+                                             _partitioner->get_channel_ids().get<uint32_t>(), rows,
                                              block, _enable_pipeline_exec ? eos : false));
         }
     } else if (_part_type == TPartitionType::TABLET_SINK_SHUFFLE_PARTITIONED) {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to