This is an automated email from the ASF dual-hosted git repository.
Mryange pushed a commit to branch groupjoin
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/groupjoin by this push:
new a5bff24e114 be part
a5bff24e114 is described below
commit a5bff24e114ea3cbab90ac36285aa1b975473d2c
Author: Mryange <[email protected]>
AuthorDate: Mon Jun 22 10:29:16 2026 +0800
be part
---
be/src/exec/common/groupjoin_utils.h | 133 +++++++
be/src/exec/operator/groupjoin_build_sink.cpp | 265 +++++++++++++
be/src/exec/operator/groupjoin_build_sink.h | 104 +++++
be/src/exec/operator/groupjoin_operator_utils.cpp | 415 +++++++++++++++++++
be/src/exec/operator/groupjoin_operator_utils.h | 82 ++++
be/src/exec/operator/groupjoin_probe_operator.cpp | 251 ++++++++++++
be/src/exec/operator/groupjoin_probe_operator.h | 109 +++++
be/src/exec/operator/operator.cpp | 7 +
be/src/exec/pipeline/dependency.cpp | 5 +
be/src/exec/pipeline/dependency.h | 23 ++
be/src/exec/pipeline/pipeline_fragment_context.cpp | 22 ++
.../runtime_filter_producer_helper_groupjoin.h | 94 +++++
be/src/exprs/aggregate/aggregate_function.h | 36 ++
be/src/exprs/aggregate/aggregate_function_count.h | 51 +++
.../exprs/aggregate/aggregate_function_min_max.h | 13 +
.../exprs/aggregate/aggregate_function_null_v2.h | 63 +++
be/src/exprs/aggregate/aggregate_function_sum.h | 40 ++
be/src/exprs/vectorized_agg_fn.cpp | 8 +
be/src/exprs/vectorized_agg_fn.h | 4 +
.../operator/groupjoin_operator_utils_test.cpp | 154 ++++++++
.../aggregate/aggregate_function_repeat_test.cpp | 438 +++++++++++++++++++++
21 files changed, 2317 insertions(+)
diff --git a/be/src/exec/common/groupjoin_utils.h
b/be/src/exec/common/groupjoin_utils.h
new file mode 100644
index 00000000000..c257122aed5
--- /dev/null
+++ b/be/src/exec/common/groupjoin_utils.h
@@ -0,0 +1,133 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <variant>
+#include <vector>
+
+#include "common/exception.h"
+#include "exec/common/columns_hashing.h"
+#include "exec/common/hash_table/hash_crc32_return32.h"
+#include "exec/common/hash_table/hash_key_type.h"
+#include "exec/common/hash_table/hash_map_context.h"
+#include "exec/common/hash_table/hash_map_util.h"
+#include "exec/common/hash_table/ph_hash_map.h"
+#include "exec/common/hash_table/string_hash_map.h"
+#include "exprs/aggregate/aggregate_function.h"
+
+namespace doris {
+
+struct GroupJoinEntry { // Per-key payload that the GroupJoin hash tables map to (stored by pointer).
+ uint64_t build_count = 0; // Incremented once for every build-side row that lands on this key.
+ uint64_t probe_count = 0; // NOTE(review): presumably the probe-side counterpart; its writer is not visible here.
+ AggregateDataPtr agg_states = nullptr; // nullptr until aggregate states are allocated for this entry.
+};
+
+template <typename T>
+using GroupJoinData = PHHashMap<T, GroupJoinEntry*, HashCRC32<T>>; // Key of type T -> GroupJoinEntry*, hashed with HashCRC32<T>.
+
+using GroupJoinDataWithStringKey = PHHashMap<StringRef, GroupJoinEntry*>; // Table behind the serialized-key method below.
+using GroupJoinDataWithShortStringKey = StringHashMap<GroupJoinEntry*>; // Table behind GroupJoinMethodOneString below.
+
+template <class T>
+using GroupJoinPrimaryHashTableContext = MethodOneNumber<T, GroupJoinData<T>>; // Chosen for the HashKeyType::int*_key cases.
+
+template <class Key>
+using GroupJoinFixedKeyHashTableContext = MethodKeysFixed<GroupJoinData<Key>>; // Chosen for the HashKeyType::fixed* cases.
+
+using GroupJoinSerializedHashTableContext =
MethodSerialized<GroupJoinDataWithStringKey>; // Chosen for HashKeyType::serialized.
+using GroupJoinMethodOneString =
MethodStringNoCache<GroupJoinDataWithShortStringKey>; // Chosen for HashKeyType::string_key.
+
+using GroupJoinMethodVariants = std::variant<
+ std::monostate, GroupJoinSerializedHashTableContext,
+ GroupJoinPrimaryHashTableContext<UInt8>,
GroupJoinPrimaryHashTableContext<UInt16>,
+ GroupJoinPrimaryHashTableContext<UInt32>,
GroupJoinPrimaryHashTableContext<UInt64>,
+ GroupJoinPrimaryHashTableContext<UInt128>,
GroupJoinPrimaryHashTableContext<UInt256>,
+ GroupJoinFixedKeyHashTableContext<UInt64>,
GroupJoinFixedKeyHashTableContext<UInt72>,
+ GroupJoinFixedKeyHashTableContext<UInt96>,
GroupJoinFixedKeyHashTableContext<UInt104>,
+ GroupJoinFixedKeyHashTableContext<UInt128>,
GroupJoinFixedKeyHashTableContext<UInt136>,
+ GroupJoinFixedKeyHashTableContext<UInt256>, GroupJoinMethodOneString>; // std::monostate is the "hash method not initialised yet" alternative.
+
+struct GroupJoinDataVariants
+ : public DataVariants<GroupJoinMethodVariants,
MethodSingleNullableColumn, MethodOneNumber,
+ DataWithNullKey> { // Holds the GroupJoin hash table as one alternative of GroupJoinMethodVariants.
+ void init(const std::vector<DataTypePtr>& data_types, HashKeyType type) { // Emplaces the hash method that corresponds to `type`.
+ switch (type) {
+ case HashKeyType::serialized:
+ method_variant.emplace<GroupJoinSerializedHashTableContext>();
+ break;
+ case HashKeyType::int8_key:
+ method_variant.emplace<GroupJoinPrimaryHashTableContext<UInt8>>();
+ break;
+ case HashKeyType::int16_key:
+ method_variant.emplace<GroupJoinPrimaryHashTableContext<UInt16>>();
+ break;
+ case HashKeyType::int32_key:
+ method_variant.emplace<GroupJoinPrimaryHashTableContext<UInt32>>();
+ break;
+ case HashKeyType::int64_key:
+ method_variant.emplace<GroupJoinPrimaryHashTableContext<UInt64>>();
+ break;
+ case HashKeyType::int128_key:
+
method_variant.emplace<GroupJoinPrimaryHashTableContext<UInt128>>();
+ break;
+ case HashKeyType::int256_key:
+
method_variant.emplace<GroupJoinPrimaryHashTableContext<UInt256>>();
+ break;
+ case HashKeyType::string_key:
+ method_variant.emplace<GroupJoinMethodOneString>();
+ break;
+ case HashKeyType::fixed64: // The fixed* methods are additionally constructed from get_key_sizes(data_types).
+ method_variant.emplace<GroupJoinFixedKeyHashTableContext<UInt64>>(
+ get_key_sizes(data_types));
+ break;
+ case HashKeyType::fixed72:
+ method_variant.emplace<GroupJoinFixedKeyHashTableContext<UInt72>>(
+ get_key_sizes(data_types));
+ break;
+ case HashKeyType::fixed96:
+ method_variant.emplace<GroupJoinFixedKeyHashTableContext<UInt96>>(
+ get_key_sizes(data_types));
+ break;
+ case HashKeyType::fixed104:
+ method_variant.emplace<GroupJoinFixedKeyHashTableContext<UInt104>>(
+ get_key_sizes(data_types));
+ break;
+ case HashKeyType::fixed128:
+ method_variant.emplace<GroupJoinFixedKeyHashTableContext<UInt128>>(
+ get_key_sizes(data_types));
+ break;
+ case HashKeyType::fixed136:
+ method_variant.emplace<GroupJoinFixedKeyHashTableContext<UInt136>>(
+ get_key_sizes(data_types));
+ break;
+ case HashKeyType::fixed256:
+ method_variant.emplace<GroupJoinFixedKeyHashTableContext<UInt256>>(
+ get_key_sizes(data_types));
+ break;
+ default: // Any HashKeyType without a case above is reported by throwing an INTERNAL_ERROR Exception.
+ throw Exception(ErrorCode::INTERNAL_ERROR,
+ "GroupJoinDataVariants meet invalid key type,
type={}", type);
+ }
+ }
+};
+
+using GroupJoinDataVariantsUPtr = std::unique_ptr<GroupJoinDataVariants>; // Owning pointer alias for GroupJoinDataVariants.
+
+} // namespace doris
diff --git a/be/src/exec/operator/groupjoin_build_sink.cpp
b/be/src/exec/operator/groupjoin_build_sink.cpp
new file mode 100644
index 00000000000..1ab9c9b11a3
--- /dev/null
+++ b/be/src/exec/operator/groupjoin_build_sink.cpp
@@ -0,0 +1,265 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "exec/operator/groupjoin_build_sink.h"
+
+#include <variant>
+
+#include "common/cast_set.h"
+#include "core/data_type/data_type_nullable.h"
+#include "exec/common/hash_table/hash_map_util.h"
+#include "exec/common/util.hpp"
+#include "exec/operator/groupjoin_operator_utils.h"
+#include "exprs/vectorized_agg_fn.h"
+#include "exprs/vexpr.h"
+#include "runtime/descriptors.h"
+#include "runtime/runtime_state.h"
+
+namespace doris {
+
+GroupJoinBuildSinkLocalState::GroupJoinBuildSinkLocalState(DataSinkOperatorXBase*
parent,
+ RuntimeState* state)
+ : Base(parent, state) { // Forwards to Base and creates the finish dependency exposed by finishdependency().
+ _finish_dependency = std::make_shared<CountedFinishDependency>(
+ parent->operator_id(), parent->node_id(), parent->get_name() +
"_FINISH_DEPENDENCY"); // The dependency is named after the parent operator plus this suffix.
+}
+
+Status GroupJoinBuildSinkLocalState::init(RuntimeState* state,
LocalSinkStateInfo& info) { // Clones operator-level expressions/evaluators, registers the aggregate layout, selects the hash method, optionally creates the runtime filter helper.
+ RETURN_IF_ERROR(Base::init(state, info));
+ auto& p = _parent->cast<GroupJoinBuildSinkOperatorX>();
+ _build_expr_ctxs.resize(p._build_expr_ctxs.size());
+ for (size_t i = 0; i < _build_expr_ctxs.size(); ++i) { // Each local state gets its own clone of every build key expression.
+ RETURN_IF_ERROR(p._build_expr_ctxs[i]->clone(state,
_build_expr_ctxs[i]));
+ }
+ _aggregate_evaluators.reserve(p._aggregate_evaluators.size());
+ for (auto* evaluator : p._aggregate_evaluators) {
+ _aggregate_evaluators.push_back(evaluator->clone(state, p._pool));
+ }
+ RETURN_IF_ERROR(groupjoin::register_agg_state_layout(
+ _shared_state, p._aggregate_sides, p._sizes_of_aggregate_states,
+ p._aligns_of_aggregate_states, p._aggregate_indices,
p._aggregate_evaluators)); // Note: registers the operator-level evaluators, not the clones made above.
+
DCHECK(std::holds_alternative<std::monostate>(_shared_state->data_variants->method_variant)); // The hash method is expected to be unset at this point.
+ std::vector<DataTypePtr> data_types;
+ data_types.reserve(_build_expr_ctxs.size());
+ for (const auto& ctx : _build_expr_ctxs) {
+ data_types.emplace_back(remove_nullable(ctx->root()->data_type())); // Key types are passed on with the nullable wrapper removed.
+ }
+
RETURN_IF_ERROR(init_hash_method<GroupJoinDataVariants>(_shared_state->data_variants.get(),
+ data_types, true));
+ if (!p._runtime_filter_descs.empty()) { // The producer helper exists only when the plan node carries runtime filters.
+ _runtime_filter_producer_helper =
std::make_shared<RuntimeFilterProducerHelperGroupJoin>();
+ RETURN_IF_ERROR(_runtime_filter_producer_helper->init(
+ state, _build_expr_ctxs, p._runtime_filter_descs,
p._child->row_desc()));
+ }
+ _dependency->set_ready();
+ return Status::OK();
+}
+
+Status GroupJoinBuildSinkLocalState::terminate(RuntimeState* state) { // Calls skip_process() on the runtime filter helper (if any), then Base::terminate().
+ if (_terminated) { // Already terminated: nothing left to do.
+ return Status::OK();
+ }
+ if (_runtime_filter_producer_helper) {
+ RETURN_IF_ERROR(_runtime_filter_producer_helper->skip_process(state));
+ }
+ return Base::terminate(state);
+}
+
+Status GroupJoinBuildSinkLocalState::close(RuntimeState* state, Status
exec_status) { // Publishes or skips the runtime filters, collects their profile, then closes the base state.
+ if (_closed) { // Already closed: return immediately.
+ return Status::OK();
+ }
+ try {
+ if (!_terminated && _runtime_filter_producer_helper &&
!state->is_cancelled()) { // Runtime filters are handled only for a non-terminated, non-cancelled sink that has a helper.
+ if (_runtime_filter_size_sent && exec_status.ok()) { // Publish only if sink_impl() sent the filter size and execution succeeded.
+
RETURN_IF_ERROR(_runtime_filter_producer_helper->build_and_publish(state));
+ } else {
+
RETURN_IF_ERROR(_runtime_filter_producer_helper->skip_process(state));
+ }
+ }
+ } catch (Exception& e) { // An Exception thrown by the helper is converted into an InternalError status.
+ return Status::InternalError("GroupJoin runtime filter process meet
error: {}",
+ e.to_string());
+ } // NOTE(review): every error return above leaves this function before Base::close() is reached.
+ if (_runtime_filter_producer_helper) {
+
_runtime_filter_producer_helper->collect_realtime_profile(custom_profile());
+ }
+ return Base::close(state, exec_status);
+}
+
+Status GroupJoinBuildSinkLocalState::_append_runtime_filter_columns(Block*
block) { // Hands the block to the runtime filter producer helper; does nothing when no runtime filter is configured.
+ auto& p = _parent->cast<GroupJoinBuildSinkOperatorX>();
+ if (p._runtime_filter_descs.empty()) { // Same condition init() uses to decide whether the helper is created.
+ return Status::OK();
+ }
+ RETURN_IF_ERROR(_runtime_filter_producer_helper->append_block(block));
+ return Status::OK();
+}
+
+GroupJoinBuildSinkOperatorX::GroupJoinBuildSinkOperatorX(ObjectPool* pool, int
operator_id,
+ int dest_id, const
TPlanNode& tnode,
+ const DescriptorTbl&
descs) // Copies distribution type, partition expressions and runtime filter descriptors from tnode; `descs` is not read here.
+ : Base(operator_id, tnode, dest_id),
+ _join_distribution(tnode.group_join_node.__isset.dist_type
+ ? tnode.group_join_node.dist_type
+ : TJoinDistributionType::NONE), // Falls back to NONE when dist_type is not set.
+ _pool(pool),
+ _partition_exprs(tnode.__isset.distribute_expr_lists ?
tnode.distribute_expr_lists[0]
+ :
std::vector<TExpr> {}), // NOTE(review): reads element [0] whenever the field is set; assumes the list is non-empty — confirm.
+ _runtime_filter_descs(tnode.runtime_filters) {}
+
+Status GroupJoinBuildSinkOperatorX::init(const TPlanNode& tnode, RuntimeState*
state) { // Validates the plan node, creates the build key expressions and the evaluators of BUILD-side aggregates.
+ RETURN_IF_ERROR(Base::init(tnode, state));
+ RETURN_IF_ERROR(groupjoin::validate_group_join_node(tnode));
+ for (const auto& eq_join_conjunct :
tnode.group_join_node.eq_join_conjuncts) {
+ VExprContextSPtr build_ctx;
+ RETURN_IF_ERROR(VExpr::create_expr_tree(eq_join_conjunct.right,
build_ctx)); // The right-hand side of each conjunct is the build key.
+ { // The probe expression is created in this scope only to compare its type with the build expression.
+ VExprContextSPtr probe_ctx;
+ RETURN_IF_ERROR(VExpr::create_expr_tree(eq_join_conjunct.left,
probe_ctx));
+ auto build_side_expr_type = build_ctx->root()->data_type();
+ auto probe_side_expr_type = probe_ctx->root()->data_type();
+ if (!make_nullable(build_side_expr_type)
+ ->equals(*make_nullable(probe_side_expr_type))) { // Both types are made nullable first, so a nullability difference alone is not a mismatch.
+ return Status::InternalError(
+ "GroupJoin build side type {}, not match probe side
type {}, node={}",
+ build_side_expr_type->get_name(),
probe_side_expr_type->get_name(),
+ debug_string(0));
+ }
+ }
+ _build_expr_ctxs.push_back(build_ctx);
+ }
+ const auto& group_join_node = tnode.group_join_node;
+ _aggregate_sides.resize(group_join_node.aggregate_functions.size());
+
_sizes_of_aggregate_states.assign(group_join_node.aggregate_functions.size(),
0); // Size 0 is the initial value; prepare() overwrites the BUILD-side entries.
+
_aligns_of_aggregate_states.assign(group_join_node.aggregate_functions.size(),
1); // Alignment 1 is the initial value; prepare() overwrites the BUILD-side entries.
+ _output_tuple_id = group_join_node.output_tuple_id;
+
+ TSortInfo dummy; // Default-constructed TSortInfo handed to AggFnEvaluator::create() below.
+ for (int i = 0; i < group_join_node.aggregate_functions.size(); ++i) {
+ const auto& aggregate_function =
group_join_node.aggregate_functions[i];
+ _aggregate_sides[i] = aggregate_function.input_side;
+ if (aggregate_function.input_side != TGroupJoinAggSide::BUILD) { // Aggregates of any other side get no evaluator in this operator.
+ continue;
+ }
+ AggFnEvaluator* evaluator = nullptr;
+ RETURN_IF_ERROR(AggFnEvaluator::create(_pool,
aggregate_function.aggregate_function, dummy,
+ false, false, &evaluator));
+ _aggregate_evaluators.push_back(evaluator);
+ _aggregate_indices.push_back(i); // Kept parallel to _aggregate_evaluators: position of that evaluator in the plan's aggregate list.
+ }
+ return Status::OK();
+}
+
+Status GroupJoinBuildSinkOperatorX::prepare(RuntimeState* state) { // Prepares and opens key expressions and BUILD-side evaluators; records each state's size and alignment.
+ RETURN_IF_ERROR(Base::prepare(state));
+ RETURN_IF_ERROR(VExpr::prepare(_build_expr_ctxs, state,
_child->row_desc()));
+ RETURN_IF_ERROR(VExpr::open(_build_expr_ctxs, state));
+ _output_tuple_desc =
state->desc_tbl().get_tuple_descriptor(_output_tuple_id);
+ DCHECK(_output_tuple_desc != nullptr);
+ const auto key_size = _build_expr_ctxs.size();
+ for (size_t i = 0; i < _aggregate_evaluators.size(); ++i) {
+ const auto agg_idx = _aggregate_indices[i];
+ SlotDescriptor* output_slot_desc =
_output_tuple_desc->slots()[key_size + agg_idx]; // Output slot index = number of join keys + aggregate position.
+ RETURN_IF_ERROR(_aggregate_evaluators[i]->prepare(state,
_child->row_desc(),
+ output_slot_desc,
output_slot_desc)); // The same slot descriptor is passed for both slot arguments.
+ _aggregate_evaluators[i]->set_version(state->be_exec_version());
+ _sizes_of_aggregate_states[agg_idx] =
_aggregate_evaluators[i]->function()->size_of_data();
+ _aligns_of_aggregate_states[agg_idx] =
+ _aggregate_evaluators[i]->function()->align_of_data();
+ }
+ for (auto* evaluator : _aggregate_evaluators) { // Evaluators are opened only after all of them were prepared.
+ RETURN_IF_ERROR(evaluator->open(state));
+ }
+ return Status::OK();
+}
+
+Status GroupJoinBuildSinkOperatorX::sink_impl(RuntimeState* state, Block*
in_block, bool eos) { // Folds one build block into the shared hash table and aggregate states; at eos sends the filter size and marks the dependency readable.
+ auto& local_state = get_local_state(state);
+ SCOPED_TIMER(local_state.exec_time_counter());
+ COUNTER_UPDATE(local_state.rows_input_counter(),
static_cast<int64_t>(in_block->rows()));
+ if (in_block->rows() > 0) { // Empty blocks skip all key/aggregate work.
+ const auto rows = cast_set<uint32_t>(in_block->rows());
+ RETURN_IF_ERROR(groupjoin::do_evaluate(*in_block,
local_state._build_expr_ctxs,
+
local_state._key_columns_holder));
+ RETURN_IF_ERROR(groupjoin::extract_key_columns(
+ in_block->rows(), local_state._key_columns_holder,
+ local_state._build_key_not_nullable_columns,
local_state._null_map_column));
+ const uint8_t* null_map = local_state._null_map_column
+ ?
local_state._null_map_column->get_data().data()
+ : nullptr; // nullptr means none of the key columns of this block was nullable.
+ local_state._places.resize(rows);
+ RETURN_IF_ERROR(groupjoin::add_build_counts_by_key(
+ local_state._shared_state, *local_state._shared_state->arena,
+ local_state._build_key_not_nullable_columns, rows, null_map,
_aggregate_indices,
+ local_state._places.data())); // Fills _places with one aggregate-state pointer per row.
+ for (size_t i = 0; i < local_state._aggregate_evaluators.size(); ++i) {
+ const auto offset =
+
local_state._shared_state->offsets_of_aggregate_states[_aggregate_indices[i]];
+ // If there is no nullable join key, every build row has a valid
place because build
+ // always creates/fetches a hash-table entry. If null_map exists,
rows with NULL join
+ // keys are skipped for normal inner equal join and keep
places[row] as nullptr.
+ if (null_map == nullptr) {
+
RETURN_IF_ERROR(local_state._aggregate_evaluators[i]->execute_batch_add(
+ in_block, offset, local_state._places.data(),
+ *local_state._shared_state->arena));
+ } else {
+
RETURN_IF_ERROR(local_state._aggregate_evaluators[i]->execute_batch_add_selected(
+ in_block, offset, local_state._places.data(),
+ *local_state._shared_state->arena));
+ }
+ }
+ RETURN_IF_ERROR(local_state._append_runtime_filter_columns(in_block));
+ }
+ if (eos) {
+ if (local_state._runtime_filter_producer_helper) {
+
RETURN_IF_ERROR(local_state._runtime_filter_producer_helper->send_filter_size(
+ state,
local_state._runtime_filter_producer_helper->build_rows(),
+ local_state._finish_dependency));
+ local_state._runtime_filter_size_sent = true; // close() checks this flag before calling build_and_publish().
+ }
+ local_state._dependency->set_ready_to_read();
+ }
+ return Status::OK();
+}
+
+DataDistribution GroupJoinBuildSinkOperatorX::required_data_distribution(
+ RuntimeState* state) const { // BUCKET_SHUFFLE/COLOCATE -> BUCKET_HASH_SHUFFLE, everything else -> HASH_SHUFFLE; `state` is not read.
+ // Keep the same distribution rule as hash join's non-broadcast path.
GroupJoin currently
+ // supports only inner equi join, so broadcast/null-aware special branches
are not needed.
+ return _join_distribution == TJoinDistributionType::BUCKET_SHUFFLE ||
+ _join_distribution ==
TJoinDistributionType::COLOCATE
+ ? DataDistribution(ExchangeType::BUCKET_HASH_SHUFFLE,
_partition_exprs)
+ : DataDistribution(ExchangeType::HASH_SHUFFLE,
_partition_exprs);
+}
+
+bool GroupJoinBuildSinkOperatorX::is_shuffled_operator() const { // True for the PARTITIONED, BUCKET_SHUFFLE and COLOCATE distributions.
+ return _join_distribution == TJoinDistributionType::PARTITIONED ||
+ _join_distribution == TJoinDistributionType::BUCKET_SHUFFLE ||
+ _join_distribution == TJoinDistributionType::COLOCATE;
+}
+
+bool GroupJoinBuildSinkOperatorX::is_colocated_operator() const { // True for the BUCKET_SHUFFLE and COLOCATE distributions.
+ return _join_distribution == TJoinDistributionType::BUCKET_SHUFFLE ||
+ _join_distribution == TJoinDistributionType::COLOCATE;
+}
+
+bool GroupJoinBuildSinkOperatorX::followed_by_shuffled_operator() const { // True when this operator is shuffled but not colocated, or when the _followed_by_shuffled_operator flag is set.
+ return (is_shuffled_operator() && !is_colocated_operator()) ||
_followed_by_shuffled_operator;
+}
+
+} // namespace doris
diff --git a/be/src/exec/operator/groupjoin_build_sink.h
b/be/src/exec/operator/groupjoin_build_sink.h
new file mode 100644
index 00000000000..9eb2a4579ea
--- /dev/null
+++ b/be/src/exec/operator/groupjoin_build_sink.h
@@ -0,0 +1,104 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <gen_cpp/PlanNodes_types.h>
+
+#include <memory>
+#include <vector>
+
+#include "core/block/block.h"
+#include "core/column/column_vector.h"
+#include "exec/common/groupjoin_utils.h"
+#include "exec/operator/operator.h"
+#include "exec/pipeline/dependency.h"
+#include "exec/runtime_filter/runtime_filter_producer_helper_groupjoin.h"
+
+namespace doris {
+
+class AggFnEvaluator;
+class GroupJoinBuildSinkOperatorX;
+
+class GroupJoinBuildSinkLocalState final : public
PipelineXSinkLocalState<GroupJoinSharedState> { // Per-instance state of the GroupJoin build sink; works on GroupJoinSharedState.
+public:
+ ENABLE_FACTORY_CREATOR(GroupJoinBuildSinkLocalState);
+ using Base = PipelineXSinkLocalState<GroupJoinSharedState>;
+ using Parent = GroupJoinBuildSinkOperatorX;
+
+ GroupJoinBuildSinkLocalState(DataSinkOperatorXBase* parent, RuntimeState*
state);
+ ~GroupJoinBuildSinkLocalState() override = default;
+
+ Status init(RuntimeState* state, LocalSinkStateInfo& info) override;
+ Status terminate(RuntimeState* state) override;
+ Status close(RuntimeState* state, Status exec_status) override;
+
+ Dependency* finishdependency() override { return _finish_dependency.get();
} // Exposes the dependency created in the constructor.
+
+private:
+ friend class GroupJoinBuildSinkOperatorX;
+
+ Status _append_runtime_filter_columns(Block* block);
+
+ VExprContextSPtrs _build_expr_ctxs; // Clones of the operator's build key expressions, made in init().
+ std::vector<AggFnEvaluator*> _aggregate_evaluators; // Clones of the operator's BUILD-side evaluators, made in init().
+ ColumnRawPtrs _build_key_not_nullable_columns; // Key columns of the current block with nullable wrappers stripped.
+ std::vector<ColumnPtr> _key_columns_holder; // Evaluated key columns of the current block; the raw pointers above point into these.
+ std::vector<AggregateDataPtr> _places; // Per-row aggregate state pointers for the current block.
+ ColumnUInt8::MutablePtr _null_map_column; // Null map built from nullable key columns; null when no key column is nullable.
+ std::shared_ptr<RuntimeFilterProducerHelperGroupJoin>
_runtime_filter_producer_helper; // Created in init() only when runtime filters are configured.
+ std::shared_ptr<CountedFinishDependency> _finish_dependency; // Passed to send_filter_size() by sink_impl().
+ bool _runtime_filter_size_sent = false; // Set by sink_impl() at eos once send_filter_size() returned OK.
+};
+
+class GroupJoinBuildSinkOperatorX final : public
DataSinkOperatorX<GroupJoinBuildSinkLocalState> { // Build-side sink operator of GroupJoin, configured from TPlanNode::group_join_node.
+public:
+ using Base = DataSinkOperatorX<GroupJoinBuildSinkLocalState>;
+
+ GroupJoinBuildSinkOperatorX(ObjectPool* pool, int operator_id, int dest_id,
+ const TPlanNode& tnode, const DescriptorTbl&
descs);
+
+ Status init(const TDataSink& tsink) override { // Always fails: this sink is initialised through the TPlanNode overload below.
+ return Status::InternalError("{} should not init with TDataSink",
_name);
+ }
+ Status init(const TPlanNode& tnode, RuntimeState* state) override;
+ Status prepare(RuntimeState* state) override;
+ Status sink_impl(RuntimeState* state, Block* in_block, bool eos) override;
+
+ DataDistribution required_data_distribution(RuntimeState* state) const
override;
+ bool is_shuffled_operator() const override;
+ bool is_colocated_operator() const override;
+ bool followed_by_shuffled_operator() const override;
+
+private:
+ friend class GroupJoinBuildSinkLocalState;
+
+ const TJoinDistributionType::type _join_distribution; // dist_type of the plan node, or NONE when it is not set.
+ ObjectPool* _pool = nullptr; // Passed to AggFnEvaluator::create() and evaluator clone().
+ std::vector<TExpr> _partition_exprs; // First entry of distribute_expr_lists, or empty when the field is unset.
+ VExprContextSPtrs _build_expr_ctxs; // One context per eq_join_conjunct, built from its right-hand expression.
+ std::vector<AggFnEvaluator*> _aggregate_evaluators; // Evaluators of the BUILD-side aggregate functions only.
+ std::vector<int> _aggregate_indices; // Parallel to _aggregate_evaluators: position in the plan's aggregate list.
+ std::vector<TGroupJoinAggSide::type> _aggregate_sides; // input_side of every aggregate function, in plan order.
+ const std::vector<TRuntimeFilterDesc> _runtime_filter_descs; // Copy of tnode.runtime_filters.
+ Sizes _sizes_of_aggregate_states; // Indexed by aggregate position; stays 0 for non-BUILD aggregates.
+ Sizes _aligns_of_aggregate_states; // Indexed by aggregate position; stays 1 for non-BUILD aggregates.
+ TupleId _output_tuple_id; // Set in init() from group_join_node.output_tuple_id.
+ TupleDescriptor* _output_tuple_desc = nullptr; // Resolved from _output_tuple_id in prepare().
+};
+
+} // namespace doris
diff --git a/be/src/exec/operator/groupjoin_operator_utils.cpp
b/be/src/exec/operator/groupjoin_operator_utils.cpp
new file mode 100644
index 00000000000..53253ac74d3
--- /dev/null
+++ b/be/src/exec/operator/groupjoin_operator_utils.cpp
@@ -0,0 +1,415 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "exec/operator/groupjoin_operator_utils.h"
+
+#include <gen_cpp/Opcodes_types.h>
+
+#include <algorithm>
+#include <new>
+#include <variant>
+
+#include "common/cast_set.h"
+#include "common/exception.h"
+#include "core/arena.h"
+#include "core/column/column_nullable.h"
+#include "exec/common/hash_table/hash_map_context.h"
+#include "exec/common/hash_table/hash_map_util.h"
+#include "exec/common/template_helpers.hpp"
+#include "exec/common/util.hpp"
+#include "exec/operator/groupjoin_probe_operator.h"
+#include "exec/pipeline/dependency.h"
+#include "exprs/vectorized_agg_fn.h"
+#include "exprs/vexpr.h"
+
+namespace doris::groupjoin {
+
+Status validate_group_join_node(const TPlanNode& tnode) { // Returns InternalError for any GroupJoin plan shape this implementation does not handle.
+ if (!tnode.__isset.group_join_node) {
+ return Status::InternalError("GroupJoin node is not set");
+ }
+
+ const auto& group_join_node = tnode.group_join_node;
+ if (group_join_node.join_op != TJoinOp::INNER_JOIN) { // Only inner joins are accepted.
+ return Status::InternalError("GroupJoin only supports inner join now");
+ }
+ if (group_join_node.__isset.dist_type &&
+ group_join_node.dist_type == TJoinDistributionType::BROADCAST) { // An unset dist_type passes this check.
+ return Status::InternalError("GroupJoin does not support broadcast
join now");
+ }
+ if (group_join_node.agg_output_mode !=
TGroupJoinAggOutputMode::FINAL_RESULT) {
+ return Status::InternalError(
+ "GroupJoin only supports final-result aggregate output mode
now: {}",
+ group_join_node.agg_output_mode);
+ }
+ if (group_join_node.aggregate_functions.empty()) {
+ return Status::InternalError("GroupJoin requires at least one
aggregate function");
+ }
+ for (const auto& eq_join_conjunct : group_join_node.eq_join_conjuncts) { // Null-safe equality conjuncts are rejected.
+ if (eq_join_conjunct.__isset.opcode &&
+ eq_join_conjunct.opcode == TExprOpcode::EQ_FOR_NULL) {
+ return Status::InternalError("GroupJoin does not support null-safe
equal join now");
+ }
+ }
+ return Status::OK();
+}
+
+Status do_evaluate(const Block& block, VExprContextSPtrs& exprs,
+ std::vector<ColumnPtr>& key_columns_holder) { // Executes every expression on `block` and stores the i-th result in key_columns_holder[i].
+ key_columns_holder.resize(exprs.size());
+ for (size_t i = 0; i < exprs.size(); ++i) {
+ RETURN_IF_ERROR(exprs[i]->execute(&block, key_columns_holder[i]));
+ key_columns_holder[i] =
key_columns_holder[i]->convert_to_full_column_if_const(); // Constant results are expanded to full columns.
+ }
+ return Status::OK();
+}
+
+Status extract_key_columns(size_t rows, const std::vector<ColumnPtr>&
key_columns_holder,
+ ColumnRawPtrs& key_not_nullable_columns,
+ ColumnUInt8::MutablePtr& null_map_column) { // Outputs raw pointers to the non-nullable key data and merges the null maps of nullable keys.
+ key_not_nullable_columns.resize(key_columns_holder.size());
+ null_map_column.reset(); // Stays null unless at least one key column is nullable.
+
+ for (size_t i = 0; i < key_columns_holder.size(); ++i) {
+ const auto* column = key_columns_holder[i].get();
+ if (const auto* nullable =
check_and_get_column<ColumnNullable>(*column); nullable) {
+ const auto& col_nested = nullable->get_nested_column();
+ const auto& col_nullmap = nullable->get_null_map_data();
+ if (!null_map_column) { // Created on the first nullable key, initialised to `rows` zeros.
+ null_map_column = ColumnUInt8::create();
+ null_map_column->get_data().assign(rows, uint8_t {0});
+ }
+ VectorizedUtils::update_null_map(null_map_column->get_data(),
col_nullmap); // Merges this key's null map into the shared one.
+ key_not_nullable_columns[i] = &col_nested; // Points at the nested (non-nullable) column.
+ } else {
+ key_not_nullable_columns[i] = column;
+ }
+ }
+ return Status::OK();
+}
+
+Status register_agg_state_layout(GroupJoinSharedState* shared_state,
+ const std::vector<TGroupJoinAggSide::type>&
aggregate_sides,
+ const Sizes& sizes_of_aggregate_states,
+ const Sizes& aligns_of_aggregate_states,
+ const std::vector<int>& aggregate_indices,
+ const std::vector<AggFnEvaluator*>&
aggregate_evaluators) { // Merges the caller's aggregate metadata into shared_state and computes state offsets once every slot is filled.
+ if (shared_state->aggregate_sides.empty()) { // First registration: size the shared vectors; size 0 / nullptr evaluator mean "not registered yet".
+ shared_state->aggregate_sides = aggregate_sides;
+ shared_state->sizes_of_aggregate_states.assign(aggregate_sides.size(),
0);
+
shared_state->aligns_of_aggregate_states.assign(aggregate_sides.size(), 1);
+
shared_state->offsets_of_aggregate_states.assign(aggregate_sides.size(), 0);
+ shared_state->aggregate_evaluators.assign(aggregate_sides.size(),
nullptr);
+ } else if (shared_state->aggregate_sides != aggregate_sides) { // Later registrations must describe the same aggregate sides.
+ return Status::InternalError("GroupJoin aggregate sides are
inconsistent");
+ }
+
+ if (sizes_of_aggregate_states.size() !=
shared_state->aggregate_sides.size() ||
+ aligns_of_aggregate_states.size() !=
shared_state->aggregate_sides.size()) {
+ return Status::InternalError("GroupJoin aggregate state layout size is
inconsistent");
+ }
+ if (aggregate_indices.size() != aggregate_evaluators.size()) { // The two vectors are parallel: evaluator i belongs to slot aggregate_indices[i].
+ return Status::InternalError("GroupJoin aggregate evaluator index size
is inconsistent");
+ }
+
+ for (size_t i = 0; i < sizes_of_aggregate_states.size(); ++i) {
+ if (sizes_of_aggregate_states[i] == 0) { // Size 0: the caller provides nothing for this slot.
+ continue;
+ }
+ if (shared_state->sizes_of_aggregate_states[i] != 0 &&
+ shared_state->sizes_of_aggregate_states[i] !=
sizes_of_aggregate_states[i]) {
+ return Status::InternalError("GroupJoin aggregate state size is
inconsistent");
+ }
+ if (shared_state->aligns_of_aggregate_states[i] != 1 &&
+ shared_state->aligns_of_aggregate_states[i] !=
aligns_of_aggregate_states[i]) {
+ return Status::InternalError("GroupJoin aggregate state align is
inconsistent");
+ }
+ shared_state->sizes_of_aggregate_states[i] =
sizes_of_aggregate_states[i];
+ shared_state->aligns_of_aggregate_states[i] =
aligns_of_aggregate_states[i];
+ }
+ for (size_t i = 0; i < aggregate_indices.size(); ++i) {
+ const auto agg_idx = aggregate_indices[i];
+ DCHECK_GE(agg_idx, 0);
+ DCHECK_LT(agg_idx, shared_state->aggregate_evaluators.size()); // NOTE(review): the index range is only guarded by these DCHECKs.
+ if (shared_state->aggregate_evaluators[agg_idx] != nullptr &&
+ shared_state->aggregate_evaluators[agg_idx] !=
aggregate_evaluators[i]) {
+ return Status::InternalError("GroupJoin aggregate evaluator is
inconsistent");
+ }
+ shared_state->aggregate_evaluators[agg_idx] = aggregate_evaluators[i];
+ }
+
+ bool ready = true; // The layout is computable only when every slot has a non-zero size and an evaluator.
+ for (size_t size : shared_state->sizes_of_aggregate_states) {
+ ready &= size != 0;
+ }
+ for (auto* evaluator : shared_state->aggregate_evaluators) {
+ ready &= evaluator != nullptr;
+ }
+ if (!ready || shared_state->agg_layout_ready) { // Either still incomplete, or the offsets were already computed.
+ return Status::OK();
+ }
+
+ shared_state->total_size_of_aggregate_states = 0;
+ shared_state->align_aggregate_states = 1;
+ for (size_t i = 0; i < shared_state->sizes_of_aggregate_states.size();
++i) {
+ shared_state->offsets_of_aggregate_states[i] =
shared_state->total_size_of_aggregate_states; // State i starts at the current running size.
+ const auto align = shared_state->aligns_of_aggregate_states[i];
+ if ((align & (align - 1)) != 0) { // Every alignment must be a power of two.
+ return Status::RuntimeError("Logical error: GroupJoin
align_of_data is not 2^N");
+ }
+ shared_state->align_aggregate_states =
+ std::max(shared_state->align_aggregate_states, align); // The whole buffer uses the largest alignment.
+ shared_state->total_size_of_aggregate_states +=
shared_state->sizes_of_aggregate_states[i];
+ if (i + 1 < shared_state->sizes_of_aggregate_states.size()) { // Round the running size up to the next state's alignment.
+ const auto next_align = shared_state->aligns_of_aggregate_states[i
+ 1];
+ shared_state->total_size_of_aggregate_states =
+ (shared_state->total_size_of_aggregate_states + next_align
- 1) / next_align *
+ next_align;
+ }
+ }
+ shared_state->agg_layout_ready = true;
+ return Status::OK();
+}
+
+void create_all_agg_states(GroupJoinSharedState* shared_state,
AggregateDataPtr data) { // Calls create() for every registered evaluator at its offset inside `data`.
+ for (size_t i = 0; i < shared_state->aggregate_evaluators.size(); ++i) {
+ try {
+ shared_state->aggregate_evaluators[i]->create(
+ data + shared_state->offsets_of_aggregate_states[i]);
+ } catch (...) { // Roll back: destroy the states created so far, then rethrow.
+ for (size_t j = 0; j < i; ++j) {
+ shared_state->aggregate_evaluators[j]->destroy(
+ data + shared_state->offsets_of_aggregate_states[j]);
+ }
+ throw;
+ }
+ }
+}
+
+void destroy_all_agg_states(GroupJoinSharedState* shared_state,
AggregateDataPtr data) { // Calls destroy() for every registered evaluator at its offset inside `data`.
+ if (data == nullptr) { // No states were allocated: nothing to destroy.
+ return;
+ }
+ for (size_t i = 0; i < shared_state->aggregate_evaluators.size(); ++i) {
+ shared_state->aggregate_evaluators[i]->destroy(
+ data + shared_state->offsets_of_aggregate_states[i]);
+ }
+}
+
+void ensure_entry_agg_states(GroupJoinEntry* entry, GroupJoinSharedState*
shared_state,
+ Arena& arena) { // Allocates and creates the aggregate states of `entry` on first use; later calls are no-ops.
+ DCHECK(!shared_state->aggregate_evaluators.empty())
+ << "ensure_entry_agg_states is called only when this side has
aggregate evaluators";
+ DCHECK(shared_state->agg_layout_ready)
+ << "aggregate state layout must be registered before creating
entry states";
+ if (entry->agg_states != nullptr) { // The states already exist for this entry.
+ return;
+ }
+ auto* agg_states =
arena.aligned_alloc(shared_state->total_size_of_aggregate_states,
+
shared_state->align_aggregate_states);
+ create_all_agg_states(shared_state, agg_states);
+ entry->agg_states = agg_states; // Assigned only after create_all_agg_states() returned without throwing.
+}
+
+// Inserts the build-side rows of one batch into the shared GroupJoin hash table.
+// For every row that is not skipped, the entry for its key is created on first
+// sight, its build_count is incremented and places[row] is set to the entry's
+// aggregate-state pointer. places[] is filled with nullptr first, so rows
+// skipped through null_map keep a nullptr place.
+// Aggregate states are allocated from `arena` only when aggregate_indices is
+// non-empty; otherwise places[row] is whatever entry->agg_states currently holds.
+// Throws doris::Exception(INTERNAL_ERROR) if the hash method variant is still
+// std::monostate, or if the null-key creator is ever invoked.
+Status add_build_counts_by_key(GroupJoinSharedState* shared_state, Arena&
arena,
+ ColumnRawPtrs& key_not_nullable_columns,
uint32_t num_rows,
+ const uint8_t* null_map, const
std::vector<int>& aggregate_indices,
+ AggregateDataPtr* places) {
+ std::fill(places, places + num_rows, nullptr);
+ return std::visit(
+ Overload {[&](std::monostate&) -> Status {
+ throw doris::Exception(ErrorCode::INTERNAL_ERROR,
"uninited hash table");
+ },
+ [&](auto& hash_method) -> Status {
+ using HashMethodType =
std::decay_t<decltype(hash_method)>;
+ using State = typename HashMethodType::State;
+ State state(key_not_nullable_columns);
+
hash_method.init_serialized_keys(key_not_nullable_columns, num_rows,
+ null_map);
+
+ // Called by the hash method when a key is seen for the first time: the
+ // GroupJoinEntry is placement-constructed in arena memory.
+ auto creator = [&](const auto& ctor, auto& key,
auto& origin) {
+ HashMethodType::try_presis_key_and_origin(key,
origin, arena);
+ auto* mapped = new
(arena.alloc<GroupJoinEntry>()) GroupJoinEntry();
+ ctor(key, mapped);
+ };
+ auto creator_for_null_key = [](auto&) {
+ throw doris::Exception(
+ ErrorCode::INTERNAL_ERROR,
+ "GroupJoin key columns should not
contain nullable columns");
+ };
+
+ auto result_handler = [&](uint32_t row, auto&
mapped) {
+ ++mapped->build_count;
+ // Only create aggregate states when the build
side
+ // has aggregate functions to update. Otherwise
this
+ // side only maintains row counts.
+ if (!aggregate_indices.empty()) {
+ ensure_entry_agg_states(mapped,
shared_state, arena);
+ }
+ places[row] = mapped->agg_states;
+ };
+ // Without a null map every row is inserted through the batched path;
+ // with one, rows are inserted one by one so flagged rows can be skipped.
+ if (null_map == nullptr) {
+ lazy_emplace_batch(hash_method, state, num_rows,
creator,
+ creator_for_null_key,
result_handler);
+ } else {
+ for (uint32_t row = 0; row < num_rows; ++row) {
+ // For normal inner equal join, any row with
a NULL join key can
+ // never match the probe side. Skip it
before lazy_emplace so the
+ // build hash map does not contain NULL-key
entries.
+ if (null_map[row]) {
+ continue;
+ }
+ auto& mapped =
*hash_method.lazy_emplace(state, row, creator,
+
creator_for_null_key);
+ result_handler(row, mapped);
+ }
+ }
+ return Status::OK();
+ }},
+ shared_state->data_variants->method_variant);
+}
+
+// Looks up the probe-side rows of one batch in the shared GroupJoin hash table.
+// Rows whose key is flagged in null_map or is not found are ignored and keep
+// places[row] == nullptr. For every matched row the entry's probe_count and
+// matched_probe_rows are incremented, and matched_rows grows by the entry's
+// build_count. places[row] receives the entry's aggregate-state pointer;
+// states are created on demand only when aggregate_indices is non-empty.
+// matched_rows and matched_probe_rows are reset to 0 on entry.
+// Throws doris::Exception(INTERNAL_ERROR) if the hash method variant is still
+// std::monostate.
+Status update_probe_counts(GroupJoinSharedState* shared_state, Arena& arena,
+ ColumnRawPtrs& key_not_nullable_columns, uint32_t
num_rows,
+ const uint8_t* null_map, const std::vector<int>&
aggregate_indices,
+ AggregateDataPtr* places, int64_t& matched_rows,
+ uint32_t& matched_probe_rows) {
+ std::fill(places, places + num_rows, nullptr);
+ matched_rows = 0;
+ matched_probe_rows = 0;
+ return std::visit(
+ Overload {[&](std::monostate&) -> Status {
+ throw doris::Exception(ErrorCode::INTERNAL_ERROR,
"uninited hash table");
+ },
+ [&](auto& hash_method) -> Status {
+ using HashMethodType =
std::decay_t<decltype(hash_method)>;
+ using State = typename HashMethodType::State;
+ State state(key_not_nullable_columns);
+
hash_method.init_serialized_keys(key_not_nullable_columns, num_rows,
+ null_map);
+ // Lookup only: the probe side never inserts new entries.
+ find_batch(hash_method, state, num_rows,
[&](uint32_t row, auto& result) {
+ // For normal inner equal join, any row with a
NULL join key can never
+ // match the build side.
+ if ((null_map != nullptr && null_map[row]) ||
!result.is_found()) {
+ return;
+ }
+ auto* mapped = result.get_mapped();
+ ++mapped->probe_count;
+ ++matched_probe_rows;
+ matched_rows +=
cast_set<int64_t>(mapped->build_count);
+ // Only create aggregate states when the probe
side has aggregate
+ // functions to update. Otherwise this side only
maintains row counts.
+ if (!aggregate_indices.empty()) {
+ ensure_entry_agg_states(mapped,
shared_state, arena);
+ }
+ places[row] = mapped->agg_states;
+ });
+ return Status::OK();
+ }},
+ shared_state->data_variants->method_variant);
+}
+
+// Emits up to batch_size result groups from the shared hash table.
+// The hash-table iterator is initialised once (guarded by
+// shared_state->drain_inited) and advanced across calls, so each call continues
+// where the previous one stopped. Only entries with build_count > 0 and
+// probe_count > 0 are output. Keys are appended to key_columns; for each
+// aggregate i the result is written to value_columns[i] through
+// insert_result_info_repeat_vec with a per-row repeat factor: the entry's
+// probe_count for BUILD-side aggregates and its build_count otherwise.
+// output_eos becomes true once the iterator has reached the end.
+// local_state._output_arena is cleared at the start of every call.
+// Throws doris::Exception(INTERNAL_ERROR) if the hash method variant is still
+// std::monostate.
+Status drain_groupjoin_result(GroupJoinSharedState* shared_state, size_t
batch_size,
+ GroupJoinProbeLocalState& local_state,
MutableColumns& key_columns,
+ MutableColumns& value_columns, bool& output_eos)
{
+ local_state._output_arena.clear();
+ const size_t agg_size = shared_state->aggregate_evaluators.size();
+ return std::visit(
+ Overload {[&](std::monostate&) -> Status {
+ throw doris::Exception(ErrorCode::INTERNAL_ERROR,
"uninited hash table");
+ },
+ [&](auto& hash_method) -> Status {
+ if (!shared_state->drain_inited) {
+ hash_method.init_iterator();
+ shared_state->drain_inited = true;
+ }
+ using HashMethodType =
std::decay_t<decltype(hash_method)>;
+ using KeyType = typename HashMethodType::Key;
+ std::vector<KeyType> keys(batch_size);
+ // The scratch buffers only ever grow; they are reused across calls.
+ if (local_state._values.size() < batch_size) {
+ local_state._values.resize(batch_size);
+ }
+ if (local_state._entries.size() < batch_size) {
+ local_state._entries.resize(batch_size);
+ }
+ if (local_state._repeats.size() < batch_size) {
+ local_state._repeats.resize(batch_size);
+ }
+
+ uint32_t num_rows = 0;
+ // Reference to the iterator stored in hash_method: advancing it here is what
+ // makes the next call resume from the current position.
+ auto& iter = hash_method.begin;
+ while (iter != hash_method.end && num_rows <
batch_size) {
+ auto* entry = iter.get_second();
+ // NOTE(review): add_build_counts_by_key skips NULL-key rows before
+ // lazy_emplace and increments build_count for every entry it creates, so
+ // the sentence below about build-side NULL rows looks stale - confirm
+ // against the build sink.
+ // Inner join only outputs groups that have
valid rows on both sides.
+ // Build-side NULL rows may create an entry
before being skipped, so
+ // probe_count alone is not sufficient.
+ if (entry->build_count > 0 && entry->probe_count
> 0) {
+ keys[num_rows] = iter.get_first();
+ local_state._entries[num_rows] = entry;
+ local_state._values[num_rows] =
entry->agg_states;
+ ++num_rows;
+ }
+ ++iter;
+ }
+
+ hash_method.insert_keys_into_columns(keys,
key_columns, num_rows);
+ for (size_t i = 0; i < agg_size; ++i) {
+ // Repeat factor = number of matching rows on the opposite side.
+ for (uint32_t row = 0; row < num_rows; ++row) {
+ local_state._repeats[row] =
+ shared_state->aggregate_sides[i] ==
+
TGroupJoinAggSide::BUILD
+ ?
local_state._entries[row]->probe_count
+ :
local_state._entries[row]->build_count;
+ }
+
shared_state->aggregate_evaluators[i]->insert_result_info_repeat_vec(
+ local_state._values,
+
shared_state->offsets_of_aggregate_states[i],
+ local_state._repeats,
value_columns[i].get(), num_rows,
+ local_state._output_arena);
+ }
+
+ output_eos = iter == hash_method.end;
+ return Status::OK();
+ }},
+ shared_state->data_variants->method_variant);
+}
+
+// Destroys the aggregate states owned by `entry`, if any, and resets the
+// entry's pointer so a second call becomes a no-op.
+void destroy_entry_agg_states(GroupJoinSharedState* shared_state, GroupJoinEntry* entry) {
+    const bool owns_states = entry != nullptr && entry->agg_states != nullptr;
+    if (owns_states) {
+        destroy_all_agg_states(shared_state, entry->agg_states);
+        entry->agg_states = nullptr;
+    }
+}
+
+// Walks every entry of the shared hash table and destroys its aggregate states.
+// Safe to call with a null shared state, with no data variants, or with an
+// uninitialised (std::monostate) hash method: all three cases do nothing.
+void destroy_agg_states(GroupJoinSharedState* shared_state) {
+    const bool has_table = shared_state != nullptr && shared_state->data_variants != nullptr;
+    if (!has_table) {
+        return;
+    }
+    auto destroy_entries = [&](auto& hash_method) -> void {
+        auto& table = *hash_method.hash_table;
+        for (auto it = table.begin(); it != table.end(); ++it) {
+            destroy_entry_agg_states(shared_state, it.get_second());
+        }
+    };
+    std::visit(Overload {[&](std::monostate&) -> void {}, destroy_entries},
+               shared_state->data_variants->method_variant);
+}
+
+} // namespace doris::groupjoin
diff --git a/be/src/exec/operator/groupjoin_operator_utils.h
b/be/src/exec/operator/groupjoin_operator_utils.h
new file mode 100644
index 00000000000..4463ceb3bad
--- /dev/null
+++ b/be/src/exec/operator/groupjoin_operator_utils.h
@@ -0,0 +1,82 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <gen_cpp/PlanNodes_types.h>
+
+#include <vector>
+
+#include "common/status.h"
+#include "core/block/block.h"
+#include "core/column/column_vector.h"
+#include "exec/common/groupjoin_utils.h"
+#include "exprs/vexpr_fwd.h"
+
+namespace doris {
+
+class Arena;
+class AggFnEvaluator;
+class GroupJoinProbeLocalState;
+struct GroupJoinSharedState;
+// Free helpers shared by the GroupJoin build sink and probe operator.
+namespace groupjoin {
+
+// Validates `tnode` for GroupJoin; callers propagate a non-OK Status as an
+// init error.
+Status validate_group_join_node(const TPlanNode& tnode);
+
+// Evaluates `exprs` against `block` and stores the results in
+// key_columns_holder (presumably one column per expression - see the definition).
+Status do_evaluate(const Block& block, VExprContextSPtrs& exprs,
+ std::vector<ColumnPtr>& key_columns_holder);
+
+// Hash map batch APIs take ColumnRawPtrs as key input. This helper converts
evaluated
+// key expression columns to raw column pointers and builds the combined null
map.
+// GroupJoin currently supports only normal equal join keys. Therefore
nullable keys are
+// split into nested columns and the combined null map is used to skip rows
with NULL keys:
+// for a normal inner equal join, any row containing a NULL join key can never
match.
+// Null-safe equal needs a different path that keeps nullable keys encoded in
the hash key.
+Status extract_key_columns(size_t rows, const std::vector<ColumnPtr>&
key_columns_holder,
+ ColumnRawPtrs& key_not_nullable_columns,
+ ColumnUInt8::MutablePtr& null_map_column);
+
+// Registers one side's aggregate metadata in shared_state. Once the layout is
+// ready it computes the per-aggregate offsets, the total state size and the
+// maximum alignment, then marks shared_state->agg_layout_ready. Returns a
+// RuntimeError if an alignment is not a power of two.
+Status register_agg_state_layout(GroupJoinSharedState* shared_state,
+ const std::vector<TGroupJoinAggSide::type>&
aggregate_sides,
+ const Sizes& sizes_of_aggregate_states,
+ const Sizes& aligns_of_aggregate_states,
+ const std::vector<int>& aggregate_indices,
+ const std::vector<AggFnEvaluator*>&
aggregate_evaluators);
+
+// Inserts build rows into the hash table, bumps each entry's build_count and
+// fills places[row] with the entry's aggregate-state pointer (nullptr for rows
+// skipped via null_map).
+Status add_build_counts_by_key(GroupJoinSharedState* shared_state, Arena&
arena,
+ ColumnRawPtrs& key_not_nullable_columns,
uint32_t num_rows,
+ const uint8_t* null_map, const
std::vector<int>& aggregate_indices,
+ AggregateDataPtr* places);
+
+// Looks up probe rows in the hash table, bumps each matched entry's
+// probe_count and fills places[row] (nullptr for NULL-key or unmatched rows).
+// matched_rows receives the sum of build_count over matched probe rows;
+// matched_probe_rows receives the number of matched probe rows.
+Status update_probe_counts(GroupJoinSharedState* shared_state, Arena& arena,
+ ColumnRawPtrs& key_not_nullable_columns, uint32_t
num_rows,
+ const uint8_t* null_map, const std::vector<int>&
aggregate_indices,
+ AggregateDataPtr* places, int64_t& matched_rows,
+ uint32_t& matched_probe_rows);
+
+// Constructs all aggregate states inside the block at `data`; on exception the
+// states already constructed are destroyed before rethrowing.
+void create_all_agg_states(GroupJoinSharedState* shared_state,
AggregateDataPtr data);
+
+// Destroys all aggregate states inside the block at `data`; no-op for nullptr.
+void destroy_all_agg_states(GroupJoinSharedState* shared_state,
AggregateDataPtr data);
+
+// Writes up to batch_size result groups into key_columns / value_columns and
+// sets output_eos once the hash table has been fully iterated.
+Status drain_groupjoin_result(GroupJoinSharedState* shared_state, size_t
batch_size,
+ GroupJoinProbeLocalState& local_state,
MutableColumns& key_columns,
+ MutableColumns& value_columns, bool& output_eos);
+
+// Destroys the aggregate states of every hash-table entry; tolerates a null
+// shared_state.
+void destroy_agg_states(GroupJoinSharedState* shared_state);
+
+} // namespace groupjoin
+} // namespace doris
diff --git a/be/src/exec/operator/groupjoin_probe_operator.cpp
b/be/src/exec/operator/groupjoin_probe_operator.cpp
new file mode 100644
index 00000000000..9a80da4bcb6
--- /dev/null
+++ b/be/src/exec/operator/groupjoin_probe_operator.cpp
@@ -0,0 +1,251 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "exec/operator/groupjoin_probe_operator.h"
+
+#include <algorithm>
+
+#include "common/cast_set.h"
+#include "core/column/column_nullable.h"
+#include "core/data_type/data_type_nullable.h"
+#include "exec/common/util.hpp"
+#include "exec/operator/groupjoin_operator_utils.h"
+#include "exprs/vectorized_agg_fn.h"
+#include "exprs/vexpr.h"
+#include "runtime/descriptors.h"
+#include "runtime/runtime_state.h"
+
+namespace doris {
+
+// Sets up the per-instance probe state: clones the operator's probe key
+// expressions and aggregate evaluators, then registers the probe side's
+// aggregate layout in the shared state.
+Status GroupJoinProbeLocalState::init(RuntimeState* state, LocalStateInfo& info) {
+    RETURN_IF_ERROR(Base::init(state, info));
+    auto& parent_op = _parent->cast<GroupJoinProbeOperatorX>();
+
+    // Each instance evaluates keys through its own expression contexts.
+    const size_t key_count = parent_op._probe_expr_ctxs.size();
+    _probe_expr_ctxs.resize(key_count);
+    for (size_t key_idx = 0; key_idx < key_count; ++key_idx) {
+        RETURN_IF_ERROR(
+                parent_op._probe_expr_ctxs[key_idx]->clone(state, _probe_expr_ctxs[key_idx]));
+    }
+
+    const auto& parent_evaluators = parent_op._aggregate_evaluators;
+    _aggregate_evaluators.reserve(parent_evaluators.size());
+    for (size_t agg_idx = 0; agg_idx < parent_evaluators.size(); ++agg_idx) {
+        _aggregate_evaluators.push_back(parent_evaluators[agg_idx]->clone(state, parent_op._pool));
+    }
+
+    // The layout is registered with the operator-level evaluators, not the clones.
+    return groupjoin::register_agg_state_layout(
+            _shared_state, parent_op._aggregate_sides, parent_op._sizes_of_aggregate_states,
+            parent_op._aligns_of_aggregate_states, parent_op._aggregate_indices,
+            parent_op._aggregate_evaluators);
+}
+
+// Builds the probe operator from the plan node.
+// The join distribution defaults to NONE when the plan does not set dist_type.
+// The partition expressions are taken from the first distribute_expr_lists
+// entry when one exists, otherwise they are empty.
+GroupJoinProbeOperatorX::GroupJoinProbeOperatorX(ObjectPool* pool, const TPlanNode& tnode,
+                                                 int operator_id, const DescriptorTbl& descs)
+        : Base(pool, tnode, operator_id, descs),
+          _join_distribution(tnode.group_join_node.__isset.dist_type
+                                     ? tnode.group_join_node.dist_type
+                                     : TJoinDistributionType::NONE),
+          // __isset only says the field was sent; the list itself may still be empty, and
+          // indexing [0] on an empty vector is undefined behaviour.
+          _partition_exprs(tnode.__isset.distribute_expr_lists &&
+                                           !tnode.distribute_expr_lists.empty()
+                                   ? tnode.distribute_expr_lists[0]
+                                   : std::vector<TExpr> {}) {
+    _op_name = "GROUP_JOIN_PROBE_OPERATOR";
+}
+
+// Parses the GroupJoin plan node for the probe side.
+// Creates one probe expression context from the left side of every equi-join
+// conjunct, records the input side of every aggregate function, and creates an
+// evaluator only for aggregates whose input side is PROBE. _aggregate_indices[k]
+// is the position of the k-th probe evaluator in the plan's aggregate list.
+// Returns the first non-OK Status produced by validation or creation.
+Status GroupJoinProbeOperatorX::init(const TPlanNode& tnode, RuntimeState* state) {
+    RETURN_IF_ERROR(Base::init(tnode, state));
+    RETURN_IF_ERROR(groupjoin::validate_group_join_node(tnode));
+    const auto& group_join_node = tnode.group_join_node;
+    for (const auto& eq_join_conjunct : group_join_node.eq_join_conjuncts) {
+        VExprContextSPtr ctx;
+        RETURN_IF_ERROR(VExpr::create_expr_tree(eq_join_conjunct.left, ctx));
+        _probe_expr_ctxs.push_back(ctx);
+    }
+
+    // Size/alignment slots exist for every aggregate; only the probe-side ones are
+    // filled in later by prepare().
+    const size_t agg_count = group_join_node.aggregate_functions.size();
+    _aggregate_sides.resize(agg_count);
+    _sizes_of_aggregate_states.assign(agg_count, 0);
+    _aligns_of_aggregate_states.assign(agg_count, 1);
+    _output_tuple_id = group_join_node.output_tuple_id;
+
+    TSortInfo dummy;
+    // size_t index: avoids comparing a signed int against the unsigned container size.
+    for (size_t i = 0; i < agg_count; ++i) {
+        const auto& aggregate_function = group_join_node.aggregate_functions[i];
+        _aggregate_sides[i] = aggregate_function.input_side;
+        if (aggregate_function.input_side != TGroupJoinAggSide::PROBE) {
+            continue;
+        }
+        AggFnEvaluator* evaluator = nullptr;
+        RETURN_IF_ERROR(AggFnEvaluator::create(_pool, aggregate_function.aggregate_function, dummy,
+                                               false, false, &evaluator));
+        _aggregate_evaluators.push_back(evaluator);
+        _aggregate_indices.push_back(cast_set<int>(i));
+    }
+    return Status::OK();
+}
+
+// Prepares and opens the probe key expressions and the probe-side aggregate
+// evaluators. Also records which output key columns must be wrapped as nullable
+// (output slot nullable, key expression not) and stores the state size and
+// alignment of every probe-side aggregate at its plan index.
+Status GroupJoinProbeOperatorX::prepare(RuntimeState* state) {
+    RETURN_IF_ERROR(Base::prepare(state));
+    RETURN_IF_ERROR(VExpr::prepare(_probe_expr_ctxs, state, _child->row_desc()));
+    RETURN_IF_ERROR(VExpr::open(_probe_expr_ctxs, state));
+    _output_tuple_desc = state->desc_tbl().get_tuple_descriptor(_output_tuple_id);
+    DCHECK(_output_tuple_desc != nullptr);
+
+    const size_t num_keys = _probe_expr_ctxs.size();
+    for (size_t key_idx = 0; key_idx < num_keys; ++key_idx) {
+        const auto output_is_nullable = _output_tuple_desc->slots()[key_idx]->is_nullable();
+        const auto input_is_nullable = _probe_expr_ctxs[key_idx]->root()->is_nullable();
+        if (output_is_nullable == input_is_nullable) {
+            continue;
+        }
+        // The only supported mismatch is a non-nullable key feeding a nullable slot.
+        DCHECK(output_is_nullable);
+        _make_nullable_keys.emplace_back(key_idx);
+    }
+
+    for (size_t eval_idx = 0; eval_idx < _aggregate_evaluators.size(); ++eval_idx) {
+        auto* evaluator = _aggregate_evaluators[eval_idx];
+        const auto plan_idx = _aggregate_indices[eval_idx];
+        // Aggregate output slots follow the key slots in the output tuple.
+        SlotDescriptor* result_slot = _output_tuple_desc->slots()[num_keys + plan_idx];
+        RETURN_IF_ERROR(evaluator->prepare(state, _child->row_desc(), result_slot, result_slot));
+        evaluator->set_version(state->be_exec_version());
+        _sizes_of_aggregate_states[plan_idx] = evaluator->function()->size_of_data();
+        _aligns_of_aggregate_states[plan_idx] = evaluator->function()->align_of_data();
+    }
+
+    // Open only after every evaluator has been prepared.
+    for (size_t eval_idx = 0; eval_idx < _aggregate_evaluators.size(); ++eval_idx) {
+        RETURN_IF_ERROR(_aggregate_evaluators[eval_idx]->open(state));
+    }
+    return Status::OK();
+}
+
+// Consumes one probe-side block.
+// For a non-empty block: evaluates the probe key expressions, extracts the raw
+// key columns and the combined null map, updates the probe counters of the
+// matching hash-table entries, adds the number of matched join rows to
+// shared_state->total_match_count, and feeds the block to every probe-side
+// aggregate evaluator. When `eos` is true, probe_eos is set on the shared state.
+Status GroupJoinProbeOperatorX::push(RuntimeState* state, Block* input_block,
bool eos) const {
+ auto& local_state = get_local_state(state);
+ if (input_block->rows() > 0) {
+ const auto rows = cast_set<uint32_t>(input_block->rows());
+ RETURN_IF_ERROR(groupjoin::do_evaluate(*input_block,
local_state._probe_expr_ctxs,
+
local_state._key_columns_holder));
+ RETURN_IF_ERROR(groupjoin::extract_key_columns(
+ input_block->rows(), local_state._key_columns_holder,
+ local_state._probe_key_not_nullable_columns,
local_state._null_map_column));
+ // A null _null_map_column is passed on as "no null map".
+ const uint8_t* null_map = local_state._null_map_column
+ ?
local_state._null_map_column->get_data().data()
+ : nullptr;
+ local_state._places.resize(rows);
+ int64_t matched_rows = 0;
+ uint32_t matched_probe_rows = 0;
+ RETURN_IF_ERROR(groupjoin::update_probe_counts(
+ local_state._shared_state, *local_state._shared_state->arena,
+ local_state._probe_key_not_nullable_columns, rows, null_map,
_aggregate_indices,
+ local_state._places.data(), matched_rows, matched_probe_rows));
+ local_state._shared_state->total_match_count += matched_rows;
+ for (size_t i = 0; i < local_state._aggregate_evaluators.size(); ++i) {
+ const auto offset =
+
local_state._shared_state->offsets_of_aggregate_states[_aggregate_indices[i]];
+ // Probe rows can have nullptr places not only because of NULL
join keys, but also
+ // because an inner join probe key may not exist in the build hash
table.
+ // Every row matched: all places are non-null, so the plain batch add is used;
+ // otherwise the "selected" variant is used for the rows with nullptr places.
+ if (matched_probe_rows == rows) {
+
RETURN_IF_ERROR(local_state._aggregate_evaluators[i]->execute_batch_add(
+ input_block, offset, local_state._places.data(),
+ *local_state._shared_state->arena));
+ } else {
+
RETURN_IF_ERROR(local_state._aggregate_evaluators[i]->execute_batch_add_selected(
+ input_block, offset, local_state._places.data(),
+ *local_state._shared_state->arena));
+ }
+ }
+ }
+ if (eos) {
+ local_state._shared_state->probe_eos = true;
+ }
+ return Status::OK();
+}
+
+// Produces the next batch of GroupJoin results.
+// Once result_emitted is set on the shared state, the block is cleared and *eos
+// is set to true. Otherwise key and value columns are created from row_desc(),
+// filled by drain_groupjoin_result (at most state->batch_size() groups) and
+// moved into output_block. Keys listed in _make_nullable_keys are produced as
+// non-nullable columns and wrapped as nullable afterwards. *eos and
+// result_emitted are set from the drain's end-of-stream flag.
+Status GroupJoinProbeOperatorX::pull(RuntimeState* state, Block* output_block,
bool* eos) const {
+ auto& local_state = get_local_state(state);
+ auto* shared_state = local_state._shared_state;
+ if (shared_state->result_emitted) {
+ output_block->clear_column_data();
+ *eos = true;
+ return Status::OK();
+ }
+
+ auto columns_with_schema =
VectorizedUtils::create_columns_with_type_and_name(row_desc());
+ const size_t key_size = local_state._probe_expr_ctxs.size();
+ const size_t agg_size = shared_state->aggregate_evaluators.size();
+ DCHECK_EQ(columns_with_schema.size(), key_size + agg_size);
+
+ MutableColumns key_columns;
+ key_columns.reserve(key_size);
+ for (size_t i = 0; i < key_size; ++i) {
+ // Keys that need the nullable wrapper are first materialised with the nested type.
+ const auto output_key_need_nullable =
+ std::ranges::find(_make_nullable_keys, i) !=
_make_nullable_keys.end();
+ key_columns.emplace_back((output_key_need_nullable
+ ?
remove_nullable(columns_with_schema[i].type)
+ : columns_with_schema[i].type)
+ ->create_column());
+ }
+
+ MutableColumns value_columns;
+ value_columns.reserve(agg_size);
+ for (size_t i = 0; i < agg_size; ++i) {
+ DCHECK(shared_state->aggregate_evaluators[i] != nullptr);
+ value_columns.emplace_back(columns_with_schema[key_size +
i].type->create_column());
+ }
+
+ bool output_eos = false;
+ RETURN_IF_ERROR(groupjoin::drain_groupjoin_result(shared_state,
state->batch_size(),
+ local_state,
key_columns, value_columns,
+ output_eos));
+
+ for (size_t i = 0; i < key_size; ++i) {
+ columns_with_schema[i].column = std::move(key_columns[i]);
+ }
+ for (size_t i = 0; i < agg_size; ++i) {
+ columns_with_schema[key_size + i].column = std::move(value_columns[i]);
+ }
+ *output_block = Block(std::move(columns_with_schema));
+ // NOTE(review): for an empty block this wrapping is skipped, so a key listed in
+ // _make_nullable_keys keeps its non-nullable column next to the type taken from
+ // row_desc() - confirm that an empty block of this shape is acceptable downstream.
+ if (output_block->rows() != 0) {
+ for (auto cid : _make_nullable_keys) {
+ output_block->get_by_position(cid).column =
+ make_nullable(output_block->get_by_position(cid).column);
+ output_block->get_by_position(cid).type =
+ make_nullable(output_block->get_by_position(cid).type);
+ }
+ }
+ shared_state->result_emitted = output_eos;
+ *eos = output_eos;
+ return Status::OK();
+}
+
+// Probe input is wanted until push() has recorded the probe side's end of stream.
+bool GroupJoinProbeOperatorX::need_more_input_data(RuntimeState* state) const {
+    const bool probe_finished = get_local_state(state)._shared_state->probe_eos;
+    return !probe_finished;
+}
+
+// Mirrors the distribution rule of hash join's non-broadcast path. GroupJoin
+// currently supports only inner equi join, so the broadcast and null-aware
+// special cases are not needed here.
+DataDistribution GroupJoinProbeOperatorX::required_data_distribution(RuntimeState* state) const {
+    const bool bucket_based = _join_distribution == TJoinDistributionType::BUCKET_SHUFFLE ||
+                              _join_distribution == TJoinDistributionType::COLOCATE;
+    if (bucket_based) {
+        return DataDistribution(ExchangeType::BUCKET_HASH_SHUFFLE, _partition_exprs);
+    }
+    return DataDistribution(ExchangeType::HASH_SHUFFLE, _partition_exprs);
+}
+
+// True for the PARTITIONED, BUCKET_SHUFFLE and COLOCATE join distributions.
+bool GroupJoinProbeOperatorX::is_shuffled_operator() const {
+    switch (_join_distribution) {
+    case TJoinDistributionType::PARTITIONED:
+    case TJoinDistributionType::BUCKET_SHUFFLE:
+    case TJoinDistributionType::COLOCATE:
+        return true;
+    default:
+        return false;
+    }
+}
+
+// True for the bucket-based join distributions (BUCKET_SHUFFLE and COLOCATE).
+bool GroupJoinProbeOperatorX::is_colocated_operator() const {
+    const auto distribution = _join_distribution;
+    if (distribution == TJoinDistributionType::BUCKET_SHUFFLE) {
+        return true;
+    }
+    return distribution == TJoinDistributionType::COLOCATE;
+}
+
+// True when this operator is shuffled without being colocated, or when the
+// _followed_by_shuffled_operator flag is already set.
+bool GroupJoinProbeOperatorX::followed_by_shuffled_operator() const {
+    const bool shuffled_not_colocated = is_shuffled_operator() && !is_colocated_operator();
+    return shuffled_not_colocated || _followed_by_shuffled_operator;
+}
+
+} // namespace doris
diff --git a/be/src/exec/operator/groupjoin_probe_operator.h
b/be/src/exec/operator/groupjoin_probe_operator.h
new file mode 100644
index 00000000000..78a1c0156f7
--- /dev/null
+++ b/be/src/exec/operator/groupjoin_probe_operator.h
@@ -0,0 +1,109 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <gen_cpp/PlanNodes_types.h>
+
+#include <memory>
+#include <vector>
+
+#include "core/arena.h"
+#include "core/column/column_vector.h"
+#include "exec/common/groupjoin_utils.h"
+#include "exec/operator/operator.h"
+#include "exec/pipeline/dependency.h"
+
+namespace doris {
+
+class AggFnEvaluator;
+class GroupJoinProbeLocalState;
+class GroupJoinProbeOperatorX;
+// Forward declaration so GroupJoinProbeLocalState can name this function in a
+// friend declaration; it reads and writes the local state's private buffers.
+namespace groupjoin {
+Status drain_groupjoin_result(GroupJoinSharedState* shared_state, size_t
batch_size,
+ GroupJoinProbeLocalState& local_state,
MutableColumns& key_columns,
+ MutableColumns& value_columns, bool& output_eos);
+}
+
+// Per-instance state of the GroupJoin probe operator. Holds the cloned key
+// expressions and aggregate evaluators plus the scratch buffers used while
+// probing (push) and while draining results (pull).
+class GroupJoinProbeLocalState final : public
PipelineXLocalState<GroupJoinSharedState> {
+public:
+ ENABLE_FACTORY_CREATOR(GroupJoinProbeLocalState);
+ using Base = PipelineXLocalState<GroupJoinSharedState>;
+ using Parent = GroupJoinProbeOperatorX;
+
+ GroupJoinProbeLocalState(RuntimeState* state, OperatorXBase* parent) :
Base(state, parent) {}
+ ~GroupJoinProbeLocalState() override = default;
+
+ // Clones the parent's probe expressions and evaluators and registers the
+ // probe side's aggregate layout in the shared state.
+ Status init(RuntimeState* state, LocalStateInfo& info) override;
+
+private:
+ friend class GroupJoinProbeOperatorX;
+ friend Status groupjoin::drain_groupjoin_result(GroupJoinSharedState*,
size_t,
+ GroupJoinProbeLocalState&,
MutableColumns&,
+ MutableColumns&, bool&);
+ template <typename LocalStateType>
+ friend class StatefulOperatorX;
+
+ // Presumably consumed by the StatefulOperatorX friend; not referenced by the
+ // GroupJoin code itself - TODO confirm.
+ std::unique_ptr<Block> _child_block = Block::create_unique();
+ bool _child_eos = false;
+ // Per-instance clones of the operator's probe key expressions.
+ VExprContextSPtrs _probe_expr_ctxs;
+ // Per-instance clones of the operator's probe-side aggregate evaluators.
+ std::vector<AggFnEvaluator*> _aggregate_evaluators;
+ // Raw key columns and their owners for the block currently being pushed.
+ ColumnRawPtrs _probe_key_not_nullable_columns;
+ std::vector<ColumnPtr> _key_columns_holder;
+ // Per probe row: aggregate-state pointer of the matched entry, or nullptr.
+ std::vector<AggregateDataPtr> _places;
+ // Scratch buffers filled by drain_groupjoin_result for one output batch:
+ // state pointers, entries and repeat factors, one slot per output row.
+ std::vector<AggregateDataPtr> _values;
+ std::vector<GroupJoinEntry*> _entries;
+ std::vector<uint64_t> _repeats;
+ // Arena handed to the evaluators while draining; cleared on every drain call.
+ Arena _output_arena;
+ // Combined null map of the probe keys; may be null, in which case push()
+ // passes nullptr as the null map.
+ ColumnUInt8::MutablePtr _null_map_column;
+};
+
+// Probe-side operator of GroupJoin. push() probes the shared hash table with
+// incoming blocks and updates counters and probe-side aggregate states; pull()
+// emits the grouped join result in batches.
+class GroupJoinProbeOperatorX final : public
StatefulOperatorX<GroupJoinProbeLocalState> {
+public:
+ using Base = StatefulOperatorX<GroupJoinProbeLocalState>;
+
+ GroupJoinProbeOperatorX(ObjectPool* pool, const TPlanNode& tnode, int
operator_id,
+ const DescriptorTbl& descs);
+
+ // Parses key expressions and probe-side aggregate functions from the plan node.
+ Status init(const TPlanNode& tnode, RuntimeState* state) override;
+ // Prepares/opens expressions and evaluators and records aggregate state sizes.
+ Status prepare(RuntimeState* state) override;
+ // Consumes one probe block; marks probe_eos on the shared state when eos is true.
+ Status push(RuntimeState* state, Block* input_block, bool eos) const
override;
+ // Emits the next batch of result groups; *eos is set when nothing is left.
+ Status pull(RuntimeState* state, Block* output_block, bool* eos) const
override;
+ // True until the probe side has reached end of stream.
+ bool need_more_input_data(RuntimeState* state) const override;
+
+ DataDistribution required_data_distribution(RuntimeState* state) const
override;
+ bool is_shuffled_operator() const override;
+ bool is_colocated_operator() const override;
+ bool followed_by_shuffled_operator() const override;
+
+private:
+ friend class GroupJoinProbeLocalState;
+
+ // Join distribution from the plan node; NONE when the plan does not set it.
+ const TJoinDistributionType::type _join_distribution;
+ // Expressions passed to the required DataDistribution.
+ std::vector<TExpr> _partition_exprs;
+ // One context per equi-join conjunct (its left-hand expression).
+ VExprContextSPtrs _probe_expr_ctxs;
+ // Evaluators of the aggregates whose input side is PROBE.
+ std::vector<AggFnEvaluator*> _aggregate_evaluators;
+ // _aggregate_indices[k] is the plan index of _aggregate_evaluators[k].
+ std::vector<int> _aggregate_indices;
+ // Input side of every aggregate in the plan, indexed by plan index.
+ std::vector<TGroupJoinAggSide::type> _aggregate_sides;
+ // State size / alignment per plan index; only probe-side slots are filled here.
+ Sizes _sizes_of_aggregate_states;
+ Sizes _aligns_of_aggregate_states;
+ // Output key positions whose column must be wrapped as nullable in pull().
+ std::vector<size_t> _make_nullable_keys;
+ TupleId _output_tuple_id;
+ TupleDescriptor* _output_tuple_desc = nullptr;
+};
+
+} // namespace doris
diff --git a/be/src/exec/operator/operator.cpp
b/be/src/exec/operator/operator.cpp
index 60f199b2b49..9358f7e70a7 100644
--- a/be/src/exec/operator/operator.cpp
+++ b/be/src/exec/operator/operator.cpp
@@ -40,6 +40,8 @@
#include "exec/operator/file_scan_operator.h"
#include "exec/operator/group_commit_block_sink_operator.h"
#include "exec/operator/group_commit_scan_operator.h"
+#include "exec/operator/groupjoin_build_sink.h"
+#include "exec/operator/groupjoin_probe_operator.h"
#include "exec/operator/hashjoin_build_sink.h"
#include "exec/operator/hashjoin_probe_operator.h"
#include "exec/operator/hive_table_sink_operator.h"
@@ -848,6 +850,7 @@ DECLARE_OPERATOR(LocalExchangeSinkLocalState)
DECLARE_OPERATOR(AggSinkLocalState)
DECLARE_OPERATOR(BucketedAggSinkLocalState)
DECLARE_OPERATOR(PartitionedAggSinkLocalState)
+DECLARE_OPERATOR(GroupJoinBuildSinkLocalState)
DECLARE_OPERATOR(ExchangeSinkLocalState)
DECLARE_OPERATOR(NestedLoopJoinBuildSinkLocalState)
DECLARE_OPERATOR(UnionSinkLocalState)
@@ -898,6 +901,7 @@ DECLARE_OPERATOR(PartitionedHashJoinProbeLocalState)
DECLARE_OPERATOR(CacheSourceLocalState)
DECLARE_OPERATOR(RecCTESourceLocalState)
DECLARE_OPERATOR(RecCTEScanLocalState)
+DECLARE_OPERATOR(GroupJoinProbeLocalState)
#ifdef BE_TEST
DECLARE_OPERATOR(MockLocalState)
@@ -916,6 +920,7 @@ template class StatefulOperatorX<StreamingAggLocalState>;
template class StatefulOperatorX<DistinctStreamingAggLocalState>;
template class StatefulOperatorX<NestedLoopJoinProbeLocalState>;
template class StatefulOperatorX<TableFunctionLocalState>;
+template class StatefulOperatorX<GroupJoinProbeLocalState>;
template class PipelineXSinkLocalState<HashJoinSharedState>;
template class PipelineXSinkLocalState<PartitionedHashJoinSharedState>;
@@ -925,6 +930,7 @@ template class
PipelineXSinkLocalState<NestedLoopJoinSharedState>;
template class PipelineXSinkLocalState<AnalyticSharedState>;
template class PipelineXSinkLocalState<AggSharedState>;
template class PipelineXSinkLocalState<BucketedAggSharedState>;
+template class PipelineXSinkLocalState<GroupJoinSharedState>;
template class PipelineXSinkLocalState<PartitionedAggSharedState>;
template class PipelineXSinkLocalState<FakeSharedState>;
template class PipelineXSinkLocalState<UnionSharedState>;
@@ -944,6 +950,7 @@ template class
PipelineXLocalState<NestedLoopJoinSharedState>;
template class PipelineXLocalState<AnalyticSharedState>;
template class PipelineXLocalState<AggSharedState>;
template class PipelineXLocalState<BucketedAggSharedState>;
+template class PipelineXLocalState<GroupJoinSharedState>;
template class PipelineXLocalState<PartitionedAggSharedState>;
template class PipelineXLocalState<FakeSharedState>;
template class PipelineXLocalState<UnionSharedState>;
diff --git a/be/src/exec/pipeline/dependency.cpp
b/be/src/exec/pipeline/dependency.cpp
index 97fd6d037d4..b15980d3c19 100644
--- a/be/src/exec/pipeline/dependency.cpp
+++ b/be/src/exec/pipeline/dependency.cpp
@@ -21,6 +21,7 @@
#include <mutex>
#include "common/logging.h"
+#include "exec/operator/groupjoin_operator_utils.h"
#include "exec/operator/multi_cast_data_streamer.h"
#include "exec/pipeline/pipeline_fragment_context.h"
#include "exec/pipeline/pipeline_task.h"
@@ -364,6 +365,10 @@ void
BucketedAggSharedState::_destroy_agg_status(AggregateDataPtr data) {
LocalExchangeSharedState::~LocalExchangeSharedState() = default;
+// Destroys the aggregate states referenced by the hash-table entries. This runs
+// in the destructor body, i.e. before the members (hash table, arena) are
+// themselves destroyed.
+GroupJoinSharedState::~GroupJoinSharedState() {
+ groupjoin::destroy_agg_states(this);
+}
+
Status SetSharedState::update_build_not_ignore_null(const VExprContextSPtrs&
ctxs) {
if (ctxs.size() > build_not_ignore_null.size()) {
return Status::InternalError("build_not_ignore_null not initialized");
diff --git a/be/src/exec/pipeline/dependency.h
b/be/src/exec/pipeline/dependency.h
index b08dd186710..e65d8a37866 100644
--- a/be/src/exec/pipeline/dependency.h
+++ b/be/src/exec/pipeline/dependency.h
@@ -23,6 +23,7 @@
#endif
#include <concurrentqueue.h>
+#include <gen_cpp/PlanNodes_types.h>
#include <gen_cpp/internal_service.pb.h>
#include <sqltypes.h>
@@ -39,6 +40,7 @@
#include "core/block/block.h"
#include "core/types.h"
#include "exec/common/agg_utils.h"
+#include "exec/common/groupjoin_utils.h"
#include "exec/common/join_utils.h"
#include "exec/common/set_utils.h"
#include "exec/operator/data_queue.h"
@@ -756,6 +758,27 @@ struct HashJoinSharedState : public JoinSharedState {
std::vector<uint32_t> asof_build_row_to_bucket;
};
+// State shared between the GroupJoin build sink and probe operator: the hash
+// table, the arena backing its entries and aggregate states, the aggregate
+// state layout and the progress flags of the probe/drain phases.
+struct GroupJoinSharedState : public BasicSharedState {
+ ENABLE_FACTORY_CREATOR(GroupJoinSharedState)
+ GroupJoinSharedState() { data_variants =
std::make_unique<GroupJoinDataVariants>(); }
+ // Destroys the aggregate states of all hash-table entries.
+ ~GroupJoinSharedState() override;
+
+ // Set by the probe operator's push() when it receives eos.
+ bool probe_eos = false;
+ // Set by pull() once the drain has reached the end of the hash table.
+ bool result_emitted = false;
+ // Holds the hash method variant (std::monostate until initialised).
+ GroupJoinDataVariantsUPtr data_variants = nullptr;
+ // Arena used for hash-table entries and aggregate states.
+ std::shared_ptr<Arena> arena = std::make_shared<Arena>();
+ // Running sum of build_count over all matched probe rows.
+ int64_t total_match_count = 0;
+ // True once the drain iterator over the hash table has been initialised.
+ bool drain_inited = false;
+ // Per-aggregate metadata; the drain and state helpers index all of these
+ // with the same aggregate position.
+ std::vector<TGroupJoinAggSide::type> aggregate_sides;
+ std::vector<AggFnEvaluator*> aggregate_evaluators;
+ Sizes offsets_of_aggregate_states;
+ Sizes sizes_of_aggregate_states;
+ Sizes aligns_of_aggregate_states;
+ // Size and alignment of one entry's combined aggregate-state block.
+ size_t total_size_of_aggregate_states = 0;
+ size_t align_aggregate_states = 1;
+ // True once offsets/total size/alignment have been computed.
+ bool agg_layout_ready = false;
+};
+
struct PartitionedHashJoinSharedState
: public HashJoinSharedState,
public std::enable_shared_from_this<PartitionedHashJoinSharedState> {
diff --git a/be/src/exec/pipeline/pipeline_fragment_context.cpp
b/be/src/exec/pipeline/pipeline_fragment_context.cpp
index 615d67e5c9f..e7c14770f09 100644
--- a/be/src/exec/pipeline/pipeline_fragment_context.cpp
+++ b/be/src/exec/pipeline/pipeline_fragment_context.cpp
@@ -67,6 +67,8 @@
#include "exec/operator/file_scan_operator.h"
#include "exec/operator/group_commit_block_sink_operator.h"
#include "exec/operator/group_commit_scan_operator.h"
+#include "exec/operator/groupjoin_build_sink.h"
+#include "exec/operator/groupjoin_probe_operator.h"
#include "exec/operator/hashjoin_build_sink.h"
#include "exec/operator/hashjoin_probe_operator.h"
#include "exec/operator/hive_table_sink_operator.h"
@@ -1517,6 +1519,26 @@ Status
PipelineFragmentContext::_create_operator(ObjectPool* pool, const TPlanNo
}
break;
}
+ // GroupJoin: the probe operator joins the current pipeline and a new upstream pipeline is
+ // created whose sink is the GroupJoin build sink.
+ case TPlanNodeType::GROUP_JOIN_NODE: {
+ op = std::make_shared<GroupJoinProbeOperatorX>(pool, tnode,
next_operator_id(), descs);
+ RETURN_IF_ERROR(cur_pipe->add_operator(op, _parallel_instances));
+
+ // Make sure the current pipeline has a DAG entry before recording its new dependency.
+ const auto downstream_pipeline_id = cur_pipe->id();
+ if (!_dag.contains(downstream_pipeline_id)) {
+ _dag.insert({downstream_pipeline_id, {}});
+ }
+ PipelinePtr build_side_pipe = add_pipeline(cur_pipe);
+ _dag[downstream_pipeline_id].push_back(build_side_pipe->id());
+
+ // The build sink is created from the same plan node and is given the probe operator's id.
+ sink_ops.push_back(std::make_shared<GroupJoinBuildSinkOperatorX>(
+ pool, next_sink_operator_id(), op->operator_id(), tnode,
descs));
+ RETURN_IF_ERROR(build_side_pipe->set_sink(sink_ops.back()));
+ RETURN_IF_ERROR(build_side_pipe->sink()->init(tnode,
_runtime_state.get()));
+
+ // Both pipelines are registered under this node id: the current pipeline first, then the
+ // build-side pipeline.
+ _pipeline_parent_map.push(op->node_id(), cur_pipe);
+ _pipeline_parent_map.push(op->node_id(), build_side_pipe);
+ break;
+ }
case TPlanNodeType::HASH_JOIN_NODE: {
const auto is_broadcast_join =
tnode.hash_join_node.__isset.is_broadcast_join &&
tnode.hash_join_node.is_broadcast_join;
diff --git
a/be/src/exec/runtime_filter/runtime_filter_producer_helper_groupjoin.h
b/be/src/exec/runtime_filter/runtime_filter_producer_helper_groupjoin.h
new file mode 100644
index 00000000000..2080d2220a4
--- /dev/null
+++ b/be/src/exec/runtime_filter/runtime_filter_producer_helper_groupjoin.h
@@ -0,0 +1,94 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <vector>
+
+#include "common/status.h"
+#include "core/block/block.h"
+#include "exec/runtime_filter/runtime_filter_producer_helper.h"
+#include "exprs/vexpr.h"
+#include "exprs/vexpr_context.h"
+#include "runtime/runtime_state.h"
+
+namespace doris {
+
+// Runtime-filter producer helper for the GroupJoin build sink. GroupJoin does not keep a full
+// build block, so only the runtime-filter source columns of each appended block are cached; they
+// are replayed into the producers once the runtime-filter size is ready.
+class RuntimeFilterProducerHelperGroupJoin final : public RuntimeFilterProducerHelper {
+public:
+    RuntimeFilterProducerHelperGroupJoin() : RuntimeFilterProducerHelper(true, false) {}
+
+    ~RuntimeFilterProducerHelperGroupJoin() override = default;
+
+    // Evaluates every filter expression on `block` and caches the resulting columns.
+    // Returns OK without doing anything when runtime-filter processing is skipped or the block
+    // has no rows; otherwise returns the first expression-evaluation error, if any.
+    Status append_block(Block* block) {
+        if (_skip_runtime_filters_process || block->rows() == 0) {
+            return Status::OK();
+        }
+
+        std::vector<ColumnPtr> evaluated_columns;
+        evaluated_columns.reserve(_filter_expr_contexts.size());
+        for (auto& expr_ctx : _filter_expr_contexts) {
+            ColumnPtr evaluated;
+            RETURN_IF_ERROR(expr_ctx->execute(block, evaluated));
+            // Constant columns are expanded so that only full columns are cached.
+            evaluated_columns.emplace_back(evaluated->convert_to_full_column_if_const());
+        }
+
+        _build_rows += block->rows();
+        _cached_filter_columns.emplace_back(std::move(evaluated_columns));
+        return Status::OK();
+    }
+
+    // Total number of rows seen by append_block() for blocks that were actually cached.
+    uint64_t build_rows() const { return _build_rows; }
+
+    // Initializes the filters with the accumulated row count, replays all cached columns into
+    // the producers, marks every producer READY and publishes.
+    Status build_and_publish(RuntimeState* state) {
+        if (_skip_runtime_filters_process) {
+            return Status::OK();
+        }
+
+        RETURN_IF_ERROR(_init_filters(state, _build_rows));
+        for (const auto& block_columns : _cached_filter_columns) {
+            RETURN_IF_ERROR(_insert_columns(block_columns));
+        }
+
+        for (const auto& producer : _producers) {
+            producer->set_wrapper_state_and_ready_to_publish(RuntimeFilterWrapper::State::READY);
+        }
+        return _publish(state);
+    }
+
+private:
+    // Feeds the i-th cached column into the i-th producer.
+    Status _insert_columns(const std::vector<ColumnPtr>& filter_columns) {
+        DCHECK_EQ(filter_columns.size(), _producers.size());
+        SCOPED_TIMER(_runtime_filter_compute_timer.get());
+        size_t idx = 0;
+        for (const auto& producer : _producers) {
+            RETURN_IF_ERROR(producer->insert(filter_columns[idx], 0));
+            ++idx;
+        }
+        return Status::OK();
+    }
+
+    // One entry per cached block; each entry holds one column per filter expression.
+    std::vector<std::vector<ColumnPtr>> _cached_filter_columns;
+    uint64_t _build_rows = 0;
+};
+
+} // namespace doris
diff --git a/be/src/exprs/aggregate/aggregate_function.h
b/be/src/exprs/aggregate/aggregate_function.h
index 6a0c364b7cb..5dffe765d6a 100644
--- a/be/src/exprs/aggregate/aggregate_function.h
+++ b/be/src/exprs/aggregate/aggregate_function.h
@@ -25,6 +25,7 @@
#include "common/exception.h"
#include "common/status.h"
+#include "core/arena.h"
#include "core/assert_cast.h"
#include "core/column/column_complex.h"
#include "core/column/column_fixed_length_object.h"
@@ -194,6 +195,14 @@ public:
const size_t offset, IColumn& to,
const size_t num_rows) const = 0;
+ // Appends one result row to `to` for the state at `place`, scaled by `repeat`. The generic
+ // helper implementation in this header merges the state `repeat` times into a temporary state
+ // allocated from `arena` and inserts that result.
+ virtual void insert_result_into_repeat(ConstAggregateDataPtr place,
uint64_t repeat,
+ IColumn& to, Arena& arena) const =
0;
+
+ // Batch form of insert_result_into_repeat: for each of the first `num_rows` entries, the state
+ // at `places[i] + offset` is emitted with multiplicity `repeats[i]`.
+ virtual void insert_result_into_repeat_vec(const
std::vector<AggregateDataPtr>& places,
+ const size_t offset,
+ const std::vector<uint64_t>&
repeats, IColumn& to,
+ const size_t num_rows, Arena&
arena) const = 0;
+
/** Contains a loop with calls to "add" function. You can collect
arguments into array "places"
* and do a single call to "add_batch" for devirtualization and inlining.
* This function distributes inputs row to their corresponding
aggregation states,
@@ -441,6 +450,33 @@ public:
}
}
+    // Generic fallback: builds a temporary state, merges `place` into it `repeat` times and
+    // inserts the temporary state's result into `to`. The temporary state is destroyed on both
+    // the normal and the exceptional path.
+    void insert_result_into_repeat(ConstAggregateDataPtr place, uint64_t repeat, IColumn& to,
+                                   Arena& arena) const override {
+        const auto& self = *assert_cast<const Derived*>(this);
+        auto* scratch_state = arena.aligned_alloc(self.size_of_data(), self.align_of_data());
+        self.create(scratch_state);
+        try {
+            for (uint64_t remaining = repeat; remaining > 0; --remaining) {
+                self.merge(scratch_state, place, arena);
+            }
+            self.insert_result_into(scratch_state, to);
+        } catch (...) {
+            self.destroy(scratch_state);
+            throw;
+        }
+        self.destroy(scratch_state);
+    }
+
+    // Generic batch fallback: calls the derived class's single-row insert_result_into_repeat
+    // for each of the first `num_rows` places, in order.
+    void insert_result_into_repeat_vec(const std::vector<AggregateDataPtr>& places,
+                                       const size_t offset, const std::vector<uint64_t>& repeats,
+                                       IColumn& to, const size_t num_rows,
+                                       Arena& arena) const override {
+        const auto& self = *assert_cast<const Derived*>(this);
+        for (size_t row = 0; row < num_rows; ++row) {
+            self.insert_result_into_repeat(places[row] + offset, repeats[row], to, arena);
+        }
+    }
+
void serialize_vec(const std::vector<AggregateDataPtr>& places, size_t
offset,
BufferWritable& buf, const size_t num_rows) const
override {
const Derived* derived = assert_cast<const Derived*>(this);
diff --git a/be/src/exprs/aggregate/aggregate_function_count.h
b/be/src/exprs/aggregate/aggregate_function_count.h
index 1a5b25a18e9..b406de27ff1 100644
--- a/be/src/exprs/aggregate/aggregate_function_count.h
+++ b/be/src/exprs/aggregate/aggregate_function_count.h
@@ -95,6 +95,23 @@ public:
assert_cast<ColumnInt64&>(to).get_data().push_back(data(place).count);
}
+    // Appends `count * repeat` to the Int64 result column; the arena is not needed.
+    void insert_result_into_repeat(ConstAggregateDataPtr place, uint64_t repeat, IColumn& to,
+                                   Arena&) const override {
+        auto& counts = assert_cast<ColumnInt64&>(to).get_data();
+        counts.push_back(data(place).count * repeat);
+    }
+
+    // Batch form: appends `count * repeats[row]` for each of the first `num_rows` places.
+    void insert_result_into_repeat_vec(const std::vector<AggregateDataPtr>& places,
+                                       const size_t offset, const std::vector<uint64_t>& repeats,
+                                       IColumn& to, const size_t num_rows, Arena&) const override {
+        auto& counts = assert_cast<ColumnInt64&>(to).get_data();
+        counts.reserve(counts.size() + num_rows);
+        for (size_t row = 0; row < num_rows; ++row) {
+            counts.push_back(data(places[row] + offset).count * repeats[row]);
+        }
+    }
+
void serialize_to_column(const std::vector<AggregateDataPtr>& places,
size_t offset,
MutableColumnPtr& dst, const size_t num_rows)
const override {
auto& col = assert_cast<ColumnFixedLengthObject&>(*dst);
@@ -249,6 +266,40 @@ public:
}
}
+    // Appends `count * repeat`. If `to` is nullable, a not-null flag (0) is pushed to its null
+    // map and the value goes into the nested Int64 column.
+    void insert_result_into_repeat(ConstAggregateDataPtr place, uint64_t repeat, IColumn& to,
+                                   Arena&) const override {
+        if (!is_column_nullable(to)) {
+            assert_cast<ColumnInt64&>(to).get_data().push_back(data(place).count * repeat);
+            return;
+        }
+        auto& nullable_to = assert_cast<ColumnNullable&>(to);
+        nullable_to.get_null_map_data().push_back(0);
+        auto& nested_counts = assert_cast<ColumnInt64&>(nullable_to.get_nested_column()).get_data();
+        nested_counts.push_back(data(place).count * repeat);
+    }
+
+    // Batch form: appends `count * repeats[row]` for each of the first `num_rows` places. For a
+    // nullable `to`, `num_rows` not-null flags are appended to the null map first and the values
+    // go into the nested Int64 column.
+    void insert_result_into_repeat_vec(const std::vector<AggregateDataPtr>& places,
+                                       const size_t offset, const std::vector<uint64_t>& repeats,
+                                       IColumn& to, const size_t num_rows, Arena&) const override {
+        // Shared by both branches: reserve once, then append the scaled counts in row order.
+        auto append_scaled_counts = [&](auto& dst) {
+            dst.reserve(dst.size() + num_rows);
+            for (size_t row = 0; row < num_rows; ++row) {
+                dst.push_back(data(places[row] + offset).count * repeats[row]);
+            }
+        };
+
+        if (is_column_nullable(to)) {
+            auto& nullable_to = assert_cast<ColumnNullable&>(to);
+            nullable_to.get_null_map_column().insert_many_vals(0, num_rows);
+            append_scaled_counts(
+                    assert_cast<ColumnInt64&>(nullable_to.get_nested_column()).get_data());
+        } else {
+            append_scaled_counts(assert_cast<ColumnInt64&>(to).get_data());
+        }
+    }
+
void serialize_to_column(const std::vector<AggregateDataPtr>& places,
size_t offset,
MutableColumnPtr& dst, const size_t num_rows)
const override {
auto& col = assert_cast<ColumnFixedLengthObject&>(*dst);
diff --git a/be/src/exprs/aggregate/aggregate_function_min_max.h
b/be/src/exprs/aggregate/aggregate_function_min_max.h
index 1baf0cf1fd1..5e315578252 100644
--- a/be/src/exprs/aggregate/aggregate_function_min_max.h
+++ b/be/src/exprs/aggregate/aggregate_function_min_max.h
@@ -810,6 +810,19 @@ public:
this->data(place).insert_result_into(to);
}
+    // The repeat count is ignored: this override emits the stored value unchanged.
+    void insert_result_into_repeat(ConstAggregateDataPtr place, uint64_t /*repeat*/, IColumn& to,
+                                   Arena& /*arena*/) const override {
+        const auto& state = this->data(place);
+        state.insert_result_into(to);
+    }
+
+    // Batch form: emits the stored value of each of the first `num_rows` places; the repeat
+    // counts are ignored.
+    void insert_result_into_repeat_vec(const std::vector<AggregateDataPtr>& places,
+                                       const size_t offset,
+                                       const std::vector<uint64_t>& /*repeats*/, IColumn& to,
+                                       const size_t num_rows, Arena& /*arena*/) const override {
+        for (size_t row = 0; row < num_rows; ++row) {
+            this->data(places[row] + offset).insert_result_into(to);
+        }
+    }
+
void serialize_to_column(const std::vector<AggregateDataPtr>& places,
size_t offset,
MutableColumnPtr& dst, const size_t num_rows)
const override {
if constexpr (Data::IsFixedLength) {
diff --git a/be/src/exprs/aggregate/aggregate_function_null_v2.h
b/be/src/exprs/aggregate/aggregate_function_null_v2.h
index ef508db89fe..85d6b5038ba 100644
--- a/be/src/exprs/aggregate/aggregate_function_null_v2.h
+++ b/be/src/exprs/aggregate/aggregate_function_null_v2.h
@@ -380,6 +380,69 @@ public:
nested_function->insert_result_into(nested_place(place), to);
}
}
+
+    // Nullable wrapper around the nested function's insert_result_into_repeat.
+    // With a nullable result: if the state's flag is set, the nested result is written to the
+    // nested column and a 0 is pushed to the null map; otherwise a default row is inserted.
+    // With a non-nullable result the call is forwarded directly.
+    void insert_result_into_repeat(ConstAggregateDataPtr place, uint64_t repeat, IColumn& to,
+                                   Arena& arena) const override {
+        if constexpr (!result_is_nullable) {
+            nested_function->insert_result_into_repeat(nested_place(place), repeat, to, arena);
+        } else {
+            auto& nullable_to = assert_cast<ColumnNullable&>(to);
+            auto& value_column = nullable_to.get_nested_column();
+            auto& null_flags = nullable_to.get_null_map_data();
+            if (!get_flag(place)) {
+                nullable_to.insert_default();
+                return;
+            }
+            nested_function->insert_result_into_repeat(nested_place(place), repeat, value_column,
+                                                       arena);
+            null_flags.push_back(0);
+        }
+    }
+
+ // Batch form of insert_result_into_repeat for the nullable wrapper. Rows whose flag is set are
+ // delegated to the nested function in contiguous runs; rows whose flag is not set get a
+ // default row. With a non-nullable result everything is forwarded in one call.
+ void insert_result_into_repeat_vec(const std::vector<AggregateDataPtr>&
places,
+ const size_t offset, const
std::vector<uint64_t>& repeats,
+ IColumn& to, const size_t num_rows,
+ Arena& arena) const override {
+ if constexpr (result_is_nullable) {
+ auto& to_concrete = assert_cast<ColumnNullable&>(to);
+ auto& nested_column = to_concrete.get_nested_column();
+
+ // Pending run of flagged rows (places and their repeat counts) not yet written out.
+ std::vector<AggregateDataPtr> nested_places;
+ std::vector<uint64_t> nested_repeats;
+ nested_places.reserve(num_rows);
+ nested_repeats.reserve(num_rows);
+
+ // Nullable result columns are stored as a nested value column plus a null map.
+ // The nested aggregate function only writes value rows, so the wrapper has to
+ // flush each contiguous non-NULL run before inserting a NULL row. This preserves
+ // row order and keeps the nested column size aligned with the null map size.
+ auto flush_nested = [&]() {
+ if (nested_places.empty()) {
+ return;
+ }
+ // The buffered places are the original pointers, so the nested state is addressed
+ // with `offset + prefix_size`.
+ nested_function->insert_result_into_repeat_vec(nested_places,
offset + prefix_size,
+ nested_repeats,
nested_column,
+
nested_places.size(), arena);
+ to_concrete.get_null_map_column().insert_many_vals(0,
nested_places.size());
+ nested_places.clear();
+ nested_repeats.clear();
+ };
+
+ for (size_t i = 0; i != num_rows; ++i) {
+ if (get_flag(places[i] + offset)) {
+ nested_places.push_back(places[i]);
+ nested_repeats.push_back(repeats[i]);
+ } else {
+ flush_nested();
+ to_concrete.insert_default();
+ }
+ }
+ // Write out the trailing run, if any.
+ flush_nested();
+ } else {
+ nested_function->insert_result_into_repeat_vec(places, offset +
prefix_size, repeats,
+ to, num_rows,
arena);
+ }
+ }
};
template <typename NestFuction, bool result_is_nullable>
diff --git a/be/src/exprs/aggregate/aggregate_function_sum.h
b/be/src/exprs/aggregate/aggregate_function_sum.h
index c42c77f7d13..320de754e3f 100644
--- a/be/src/exprs/aggregate/aggregate_function_sum.h
+++ b/be/src/exprs/aggregate/aggregate_function_sum.h
@@ -81,6 +81,8 @@ class AggregateFunctionSum<T, TResult, Data> final
: public IAggregateFunctionDataHelper<Data, AggregateFunctionSum<T,
TResult, Data>>,
UnaryExpression,
NullableAggregateFunction {
+ using Base = IAggregateFunctionDataHelper<Data, AggregateFunctionSum<T,
TResult, Data>>;
+
public:
using ResultDataType = typename PrimitiveTypeTraits<TResult>::DataType;
using ColVecType = typename PrimitiveTypeTraits<T>::ColumnType;
@@ -132,6 +134,44 @@ public:
column.get_data().push_back(this->data(place).get());
}
+    // Appends `sum * repeat` to the result column. DecimalV2 results use the generic helper
+    // implementation in the base class; DecimalV3 results scale the underlying native value;
+    // every other result type is multiplied directly.
+    void insert_result_into_repeat(ConstAggregateDataPtr place, uint64_t repeat, IColumn& to,
+                                   Arena& arena) const override {
+        if constexpr (!is_decimalv2(TResult)) {
+            auto& result_column = assert_cast<ColVecResult&>(to);
+            auto scaled = this->data(place).get();
+            using Scaled = decltype(scaled);
+            if constexpr (is_decimalv3(TResult)) {
+                scaled.value *= typename Scaled::NativeType(repeat);
+            } else {
+                scaled *= static_cast<Scaled>(repeat);
+            }
+            result_column.get_data().push_back(scaled);
+        } else {
+            Base::insert_result_into_repeat(place, repeat, to, arena);
+        }
+    }
+
+    // Batch form: appends `sum * repeats[row]` for each of the first `num_rows` places.
+    // DecimalV2 results use the generic helper implementation in the base class.
+    void insert_result_into_repeat_vec(const std::vector<AggregateDataPtr>& places,
+                                       const size_t offset, const std::vector<uint64_t>& repeats,
+                                       IColumn& to, const size_t num_rows,
+                                       Arena& arena) const override {
+        if constexpr (!is_decimalv2(TResult)) {
+            auto& result_column = assert_cast<ColVecResult&>(to);
+            auto& results = result_column.get_data();
+            results.reserve(results.size() + num_rows);
+            for (size_t row = 0; row < num_rows; ++row) {
+                auto scaled = this->data(places[row] + offset).get();
+                using Scaled = decltype(scaled);
+                if constexpr (is_decimalv3(TResult)) {
+                    scaled.value *= typename Scaled::NativeType(repeats[row]);
+                } else {
+                    scaled *= static_cast<Scaled>(repeats[row]);
+                }
+                results.push_back(scaled);
+            }
+        } else {
+            Base::insert_result_into_repeat_vec(places, offset, repeats, to, num_rows, arena);
+        }
+    }
+
void serialize_to_column(const std::vector<AggregateDataPtr>& places,
size_t offset,
MutableColumnPtr& dst, const size_t num_rows)
const override {
auto& col = assert_cast<ColumnFixedLengthObject&>(*dst);
diff --git a/be/src/exprs/vectorized_agg_fn.cpp
b/be/src/exprs/vectorized_agg_fn.cpp
index d2cff1e139f..8aa5f62ef60 100644
--- a/be/src/exprs/vectorized_agg_fn.cpp
+++ b/be/src/exprs/vectorized_agg_fn.cpp
@@ -321,6 +321,14 @@ void AggFnEvaluator::insert_result_info_vec(const
std::vector<AggregateDataPtr>&
_function->insert_result_into_vec(places, offset, *column, num_rows);
}
+// Thin forwarder to the wrapped aggregate function's insert_result_into_repeat_vec; `column` is
+// dereferenced unconditionally.
+void AggFnEvaluator::insert_result_info_repeat_vec(const std::vector<AggregateDataPtr>& places,
+                                                   size_t offset,
+                                                   const std::vector<uint64_t>& repeats,
+                                                   IColumn* column, const size_t num_rows,
+                                                   Arena& arena) {
+    IColumn& target = *column;
+    _function->insert_result_into_repeat_vec(places, offset, repeats, target, num_rows, arena);
+}
+
void AggFnEvaluator::reset(AggregateDataPtr place) {
_function->reset(place);
}
diff --git a/be/src/exprs/vectorized_agg_fn.h b/be/src/exprs/vectorized_agg_fn.h
index 7b69e7ee060..c0147feb841 100644
--- a/be/src/exprs/vectorized_agg_fn.h
+++ b/be/src/exprs/vectorized_agg_fn.h
@@ -87,6 +87,10 @@ public:
void insert_result_info_vec(const std::vector<AggregateDataPtr>& place,
size_t offset,
IColumn* column, const size_t num_rows);
+ // Repeat-aware counterpart of insert_result_info_vec: forwards to the aggregate function's
+ // insert_result_into_repeat_vec with the per-row `repeats`. `column` must not be null.
+ void insert_result_info_repeat_vec(const std::vector<AggregateDataPtr>&
places, size_t offset,
+ const std::vector<uint64_t>& repeats,
IColumn* column,
+ const size_t num_rows, Arena& arena);
+
void reset(AggregateDataPtr place);
DataTypePtr& data_type() { return _data_type; }
diff --git a/be/test/exec/operator/groupjoin_operator_utils_test.cpp
b/be/test/exec/operator/groupjoin_operator_utils_test.cpp
new file mode 100644
index 00000000000..754af7d43a6
--- /dev/null
+++ b/be/test/exec/operator/groupjoin_operator_utils_test.cpp
@@ -0,0 +1,154 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "exec/operator/groupjoin_operator_utils.h"
+
+#include <gtest/gtest.h>
+
+#include <memory>
+#include <variant>
+#include <vector>
+
+#include "core/assert_cast.h"
+#include "core/column/column_nullable.h"
+#include "core/data_type/data_type_number.h"
+#include "exec/common/hash_table/hash_map_util.h"
+#include "exec/common/template_helpers.hpp"
+#include "exec/pipeline/dependency.h"
+#include "testutil/column_helper.h"
+
+namespace doris {
+namespace {
+
+// Evaluates `stmt` once and fails the current test fatally, printing the status, when it is
+// not OK.
+#define ASSERT_OK(stmt)                                         \
+    do {                                                        \
+        const auto& assert_ok_status = (stmt);                  \
+        ASSERT_TRUE(assert_ok_status.ok()) << assert_ok_status; \
+    } while (false)
+
+// Initializes the shared state's hash method for a single Int32 key column.
+Status init_int32_groupjoin_state(GroupJoinSharedState* shared_state) {
+    auto* variants = shared_state->data_variants.get();
+    return init_hash_method<GroupJoinDataVariants>(variants, {std::make_shared<DataTypeInt32>()},
+                                                   true);
+}
+
+// Counts the mapped entries in the shared state's hash table; an uninitialized (monostate)
+// variant counts as zero.
+size_t count_hash_table_entries(GroupJoinSharedState* shared_state) {
+    size_t entries = 0;
+    std::visit(Overload {[](std::monostate&) {},
+                         [&](auto& hash_method) {
+                             hash_method.hash_table->for_each_mapped(
+                                     [&](auto&) { ++entries; });
+                         }},
+               shared_state->data_variants->method_variant);
+    return entries;
+}
+
+// Two nullable key columns: the extracted raw columns must be the nested columns, and the
+// combined null map must flag every row where either key is NULL.
+TEST(GroupJoinOperatorUtilsTest, ExtractKeyColumnsBuildsCombinedNullMap) {
+    auto first_key = ColumnHelper::create_nullable_column<DataTypeInt32>({1, 2, 3, 4}, {0, 1, 0, 0});
+    auto second_key =
+            ColumnHelper::create_nullable_column<DataTypeInt32>({10, 20, 30, 40}, {0, 0, 1, 0});
+    std::vector<ColumnPtr> key_columns_holder {first_key, second_key};
+
+    ColumnRawPtrs raw_key_columns;
+    ColumnUInt8::MutablePtr combined_null_map;
+    ASSERT_OK(groupjoin::extract_key_columns(4, key_columns_holder, raw_key_columns,
+                                             combined_null_map));
+
+    ASSERT_TRUE(combined_null_map.get() != nullptr);
+    ASSERT_EQ(raw_key_columns.size(), 2);
+    EXPECT_EQ(raw_key_columns[0],
+              &assert_cast<const ColumnNullable&>(*first_key).get_nested_column());
+    EXPECT_EQ(raw_key_columns[1],
+              &assert_cast<const ColumnNullable&>(*second_key).get_nested_column());
+
+    // Row 1 is NULL in the first key, row 2 is NULL in the second key.
+    const std::vector<int> expected_nulls {0, 1, 1, 0};
+    const auto& null_map = combined_null_map->get_data();
+    ASSERT_EQ(null_map.size(), 4);
+    for (size_t row = 0; row < expected_nulls.size(); ++row) {
+        EXPECT_EQ(null_map[row], expected_nulls[row]);
+    }
+}
+
+// Build keys {1, 2, 1, 3} without a null map give three distinct hash-table entries. Probing
+// with {1, 2, 3, 4} matches three probe rows and four rows in total (key 1 occurs twice on the
+// build side).
+TEST(GroupJoinOperatorUtilsTest, AddBuildCountsWithoutNullMapUsesBatchPath) {
+    GroupJoinSharedState shared_state;
+    ASSERT_OK(init_int32_groupjoin_state(&shared_state));
+    std::vector<int> aggregate_indices;
+
+    // Build side.
+    auto build_key = ColumnHelper::create_column<DataTypeInt32>({1, 2, 1, 3});
+    ColumnRawPtrs build_key_columns {build_key.get()};
+    std::vector<AggregateDataPtr> build_places(build_key->size());
+    const auto build_rows = static_cast<uint32_t>(build_key->size());
+    ASSERT_OK(groupjoin::add_build_counts_by_key(&shared_state, *shared_state.arena,
+                                                 build_key_columns, build_rows, nullptr,
+                                                 aggregate_indices, build_places.data()));
+
+    EXPECT_EQ(count_hash_table_entries(&shared_state), 3);
+
+    // Probe side.
+    auto probe_key = ColumnHelper::create_column<DataTypeInt32>({1, 2, 3, 4});
+    ColumnRawPtrs probe_key_columns {probe_key.get()};
+    std::vector<AggregateDataPtr> probe_places(probe_key->size());
+    const auto probe_rows = static_cast<uint32_t>(probe_key->size());
+    int64_t matched_rows = 0;
+    uint32_t matched_probe_rows = 0;
+    ASSERT_OK(groupjoin::update_probe_counts(&shared_state, *shared_state.arena,
+                                             probe_key_columns, probe_rows, nullptr,
+                                             aggregate_indices, probe_places.data(), matched_rows,
+                                             matched_probe_rows));
+
+    EXPECT_EQ(matched_probe_rows, 3);
+    EXPECT_EQ(matched_rows, 4);
+}
+
+// A nullable build key with NULLs at rows 1 and 4: only the keys 1 and 2 reach the hash table,
+// and a probe for 0 (the nested value of the NULL rows) finds nothing.
+TEST(GroupJoinOperatorUtilsTest, AddBuildCountsSkipsNullKeyRowsBeforeEmplace) {
+    GroupJoinSharedState shared_state;
+    ASSERT_OK(init_int32_groupjoin_state(&shared_state));
+    std::vector<int> aggregate_indices;
+
+    auto build_key = ColumnHelper::create_nullable_column<DataTypeInt32>({1, 0, 1, 2, 0, 2},
+                                                                         {0, 1, 0, 0, 1, 0});
+    std::vector<ColumnPtr> key_columns_holder {build_key};
+    ColumnRawPtrs build_raw_key_columns;
+    ColumnUInt8::MutablePtr build_null_map;
+    ASSERT_OK(groupjoin::extract_key_columns(build_key->size(), key_columns_holder,
+                                             build_raw_key_columns, build_null_map));
+    ASSERT_TRUE(build_null_map.get() != nullptr);
+
+    std::vector<AggregateDataPtr> build_places(build_key->size());
+    const auto build_rows = static_cast<uint32_t>(build_key->size());
+    ASSERT_OK(groupjoin::add_build_counts_by_key(&shared_state, *shared_state.arena,
+                                                 build_raw_key_columns, build_rows,
+                                                 build_null_map->get_data().data(),
+                                                 aggregate_indices, build_places.data()));
+
+    EXPECT_EQ(count_hash_table_entries(&shared_state), 2);
+    // This side has no aggregate function in this case, so all places stay nullptr.
+    // The important contract here is that NULL-key rows are skipped before emplace.
+    EXPECT_EQ(build_places[1], nullptr);
+    EXPECT_EQ(build_places[4], nullptr);
+
+    auto probe_key = ColumnHelper::create_column<DataTypeInt32>({0, 1, 2});
+    ColumnRawPtrs probe_key_columns {probe_key.get()};
+    std::vector<AggregateDataPtr> probe_places(probe_key->size());
+    const auto probe_rows = static_cast<uint32_t>(probe_key->size());
+    int64_t matched_rows = 0;
+    uint32_t matched_probe_rows = 0;
+    ASSERT_OK(groupjoin::update_probe_counts(&shared_state, *shared_state.arena,
+                                             probe_key_columns, probe_rows, nullptr,
+                                             aggregate_indices, probe_places.data(), matched_rows,
+                                             matched_probe_rows));
+
+    EXPECT_EQ(matched_probe_rows, 2);
+    EXPECT_EQ(matched_rows, 4);
+}
+
+} // namespace
+} // namespace doris
diff --git a/be/test/exprs/aggregate/aggregate_function_repeat_test.cpp
b/be/test/exprs/aggregate/aggregate_function_repeat_test.cpp
new file mode 100644
index 00000000000..ba561420fc4
--- /dev/null
+++ b/be/test/exprs/aggregate/aggregate_function_repeat_test.cpp
@@ -0,0 +1,438 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest.h>
+
+#include <memory>
+#include <vector>
+
+#include "common/exception.h"
+#include "core/assert_cast.h"
+#include "core/column/column_nullable.h"
+#include "core/column/column_vector.h"
+#include "core/data_type/data_type_nullable.h"
+#include "core/data_type/data_type_number.h"
+#include "exprs/aggregate/aggregate_function.h"
+#include "exprs/aggregate/aggregate_function_simple_factory.h"
+#include "testutil/column_helper.h"
+
+namespace doris {
+namespace {
+
+// Looks up an aggregate function through the simple factory and throws an INTERNAL_ERROR
+// Exception when the factory returns nullptr.
+AggregateFunctionPtr get_agg_function(const std::string& name, const DataTypes& argument_types,
+                                      bool result_is_nullable, AggregateFunctionAttr attr = {}) {
+    auto& factory = AggregateFunctionSimpleFactory::instance();
+    auto function = factory.get(name, argument_types, nullptr, result_is_nullable, -1, attr);
+    if (function != nullptr) {
+        return function;
+    }
+    throw Exception(ErrorCode::INTERNAL_ERROR, "failed to create aggregate function {}", name);
+}
+
+// Allocates storage for one aggregate state from `arena`, sized and aligned as `function`
+// requests, constructs the state in it and returns its address. The state is not destroyed
+// automatically; the tests pair this with destroy_states().
+AggregateDataPtr create_state(const AggregateFunctionPtr& function, Arena& arena) {
+    auto* place = arena.aligned_alloc(function->size_of_data(), function->align_of_data());
+    function->create(place);
+    return place;
+}
+
+// Destroys every aggregate state in `places`; each entry is used directly as the state address.
+void destroy_states(const AggregateFunctionPtr& function,
+                    const std::vector<AggregateDataPtr>& places) {
+    for (size_t idx = 0; idx < places.size(); ++idx) {
+        function->destroy(places[idx]);
+    }
+}
+
+// Overload for states that live `offset` bytes past each pointer in `places`: destroys the
+// state at `place + offset` for every entry.
+void destroy_states(const AggregateFunctionPtr& function,
+                    const std::vector<AggregateDataPtr>& places, size_t offset) {
+    for (auto* place : places) {
+        function->destroy(place + offset);
+    }
+}
+
+// Zero-argument count: places[0] receives 2 rows and places[1] receives 5 rows. Both the
+// single-state and the vectorized repeat output are expected to emit count * repeat
+// (2 * 3 = 6 and 5 * 4 = 20).
+TEST(AggregateFunctionRepeatTest, CountRepeatOutput) {
+    Arena arena;
+    auto function = get_agg_function("count", {}, false);
+    std::vector<AggregateDataPtr> places {create_state(function, arena),
+                                          create_state(function, arena)};
+
+    function->add(places[0], nullptr, 0, arena);
+    function->add(places[0], nullptr, 0, arena);
+    for (int i = 0; i < 5; ++i) {
+        function->add(places[1], nullptr, 0, arena);
+    }
+
+    ColumnInt64 single_result;
+    function->insert_result_into_repeat(places[0], 3, single_result, arena);
+    ASSERT_EQ(single_result.size(), 1);
+    EXPECT_EQ(single_result.get_element(0), 6);
+
+    ColumnInt64 result;
+    std::vector<uint64_t> repeats {3, 4};
+    function->insert_result_into_repeat_vec(places, 0, repeats, result, places.size(), arena);
+
+    ASSERT_EQ(result.size(), 2);
+    EXPECT_EQ(result.get_element(0), 6);
+    EXPECT_EQ(result.get_element(1), 20);
+    destroy_states(function, places);
+}
+
+// A repeat factor of 0 must still append exactly one output row, and that row must be 0 even
+// though the state itself counted 3 rows.
+TEST(AggregateFunctionRepeatTest, CountRepeatZeroOutput) {
+    Arena arena;
+    auto function = get_agg_function("count", {}, false);
+    std::vector<AggregateDataPtr> places {create_state(function, arena)};
+
+    function->add(places[0], nullptr, 0, arena);
+    function->add(places[0], nullptr, 0, arena);
+    function->add(places[0], nullptr, 0, arena);
+
+    ColumnInt64 single_result;
+    function->insert_result_into_repeat(places[0], 0, single_result, arena);
+    ASSERT_EQ(single_result.size(), 1);
+    EXPECT_EQ(single_result.get_element(0), 0);
+
+    ColumnInt64 result;
+    std::vector<uint64_t> repeats {0};
+    function->insert_result_into_repeat_vec(places, 0, repeats, result, places.size(), arena);
+    ASSERT_EQ(result.size(), 1);
+    EXPECT_EQ(result.get_element(0), 0);
+
+    destroy_states(function, places);
+}
+
+// count over a nullable Int32 argument. The second initializer list passed to
+// create_nullable_column is presumably the null map (1 == null); the expected values agree
+// with that reading: places[0] sees rows 0..2 (one null) -> 2 * 5 = 10, places[1] sees rows
+// 3..4 (one null) -> 1 * 7 = 7.
+TEST(AggregateFunctionRepeatTest, CountNullableRepeatOutput) {
+    Arena arena;
+    auto nullable_type = std::make_shared<DataTypeNullable>(std::make_shared<DataTypeInt32>());
+    auto function = get_agg_function("count", {nullable_type}, false);
+    std::vector<AggregateDataPtr> places {create_state(function, arena),
+                                          create_state(function, arena)};
+
+    auto input =
+            ColumnHelper::create_nullable_column<DataTypeInt32>({1, 2, 3, 4, 5}, {0, 1, 0, 0, 1});
+    const IColumn* columns[] = {input.get()};
+    function->add(places[0], columns, 0, arena);
+    function->add(places[0], columns, 1, arena);
+    function->add(places[0], columns, 2, arena);
+    function->add(places[1], columns, 3, arena);
+    function->add(places[1], columns, 4, arena);
+
+    ColumnInt64 result;
+    std::vector<uint64_t> repeats {5, 7};
+    function->insert_result_into_repeat_vec(places, 0, repeats, result, places.size(), arena);
+
+    ASSERT_EQ(result.size(), 2);
+    EXPECT_EQ(result.get_element(0), 10);
+    EXPECT_EQ(result.get_element(1), 7);
+    destroy_states(function, places);
+}
+
+// Places the sum state `offset` bytes past the start of each allocation and passes the same
+// offset to insert_result_into_repeat_vec, checking that the vectorized output reads the state
+// at place + offset. Expected: (2 + 3) * 4 = 20 and (5 + 7) * 6 = 72.
+TEST(AggregateFunctionRepeatTest, RepeatOutputHonorsStateOffset) {
+    Arena arena;
+    auto function = get_agg_function("sum", {std::make_shared<DataTypeInt32>()}, true);
+    // Using the alignment as the offset keeps place + offset aligned for the state.
+    const size_t offset = function->align_of_data();
+    std::vector<AggregateDataPtr> places {
+            arena.aligned_alloc(offset + function->size_of_data(), function->align_of_data()),
+            arena.aligned_alloc(offset + function->size_of_data(), function->align_of_data())};
+    for (auto* place : places) {
+        function->create(place + offset);
+    }
+
+    auto input = ColumnHelper::create_column<DataTypeInt32>({2, 3, 5, 7});
+    const IColumn* columns[] = {input.get()};
+    function->add(places[0] + offset, columns, 0, arena);
+    function->add(places[0] + offset, columns, 1, arena);
+    function->add(places[1] + offset, columns, 2, arena);
+    function->add(places[1] + offset, columns, 3, arena);
+
+    ColumnInt64 result;
+    std::vector<uint64_t> repeats {4, 6};
+    function->insert_result_into_repeat_vec(places, offset, repeats, result, places.size(), arena);
+
+    ASSERT_EQ(result.size(), 2);
+    EXPECT_EQ(result.get_element(0), 20);
+    EXPECT_EQ(result.get_element(1), 72);
+    destroy_states(function, places, offset);
+}
+
+// sum over Int32: the vectorized repeat output is expected to be sum * repeat,
+// i.e. (1 + 2) * 3 = 9 and (4 + 6) * 2 = 20.
+TEST(AggregateFunctionRepeatTest, SumRepeatOutput) {
+    Arena arena;
+    auto function = get_agg_function("sum", {std::make_shared<DataTypeInt32>()}, true);
+    std::vector<AggregateDataPtr> places {create_state(function, arena),
+                                          create_state(function, arena)};
+
+    auto input = ColumnHelper::create_column<DataTypeInt32>({1, 2, 4, 6});
+    const IColumn* columns[] = {input.get()};
+    function->add(places[0], columns, 0, arena);
+    function->add(places[0], columns, 1, arena);
+    function->add(places[1], columns, 2, arena);
+    function->add(places[1], columns, 3, arena);
+
+    ColumnInt64 result;
+    std::vector<uint64_t> repeats {3, 2};
+    function->insert_result_into_repeat_vec(places, 0, repeats, result, places.size(), arena);
+
+    ASSERT_EQ(result.size(), 2);
+    EXPECT_EQ(result.get_element(0), 9);
+    EXPECT_EQ(result.get_element(1), 20);
+    destroy_states(function, places);
+}
+
+// min/max over Int32: the expected outputs are the plain min/max of each group, i.e. the
+// repeat factors (10, 20) must not change the result.
+TEST(AggregateFunctionRepeatTest, MinMaxRepeatOutput) {
+    Arena arena;
+    auto min_function = get_agg_function("min", {std::make_shared<DataTypeInt32>()}, true);
+    auto max_function = get_agg_function("max", {std::make_shared<DataTypeInt32>()}, true);
+
+    auto input = ColumnHelper::create_column<DataTypeInt32>({4, 2, 1, 7});
+    const IColumn* columns[] = {input.get()};
+    std::vector<uint64_t> repeats {10, 20};
+
+    // Group 0 = {4, 2}, group 1 = {1, 7}.
+    std::vector<AggregateDataPtr> min_places {create_state(min_function, arena),
+                                              create_state(min_function, arena)};
+    min_function->add(min_places[0], columns, 0, arena);
+    min_function->add(min_places[0], columns, 1, arena);
+    min_function->add(min_places[1], columns, 2, arena);
+    min_function->add(min_places[1], columns, 3, arena);
+    ColumnInt32 min_result;
+    min_function->insert_result_into_repeat_vec(min_places, 0, repeats, min_result,
+                                                min_places.size(), arena);
+
+    ASSERT_EQ(min_result.size(), 2);
+    EXPECT_EQ(min_result.get_element(0), 2);
+    EXPECT_EQ(min_result.get_element(1), 1);
+    destroy_states(min_function, min_places);
+
+    std::vector<AggregateDataPtr> max_places {create_state(max_function, arena),
+                                              create_state(max_function, arena)};
+    max_function->add(max_places[0], columns, 0, arena);
+    max_function->add(max_places[0], columns, 1, arena);
+    max_function->add(max_places[1], columns, 2, arena);
+    max_function->add(max_places[1], columns, 3, arena);
+    ColumnInt32 max_result;
+    max_function->insert_result_into_repeat_vec(max_places, 0, repeats, max_result,
+                                                max_places.size(), arena);
+
+    ASSERT_EQ(max_result.size(), 2);
+    EXPECT_EQ(max_result.get_element(0), 4);
+    EXPECT_EQ(max_result.get_element(1), 7);
+    destroy_states(max_function, max_places);
+}
+
+// sum over nullable Int32 with enable_aggregate_function_null_v2 set. One state per input row;
+// the middle row is presumably null (flag 1 in the second initializer list), so its output is
+// expected to be NULL while the others are value * repeat (5 * 4 = 20, 3 * 6 = 18). Both the
+// vectorized and the single-state output paths are checked against the same expectations.
+TEST(AggregateFunctionRepeatTest, NullableV2SumRepeatOutput) {
+    Arena arena;
+    AggregateFunctionAttr attr;
+    attr.enable_aggregate_function_null_v2 = true;
+    auto nullable_type = std::make_shared<DataTypeNullable>(std::make_shared<DataTypeInt32>());
+    auto function = get_agg_function("sum", {nullable_type}, true, attr);
+    std::vector<AggregateDataPtr> places {create_state(function, arena),
+                                          create_state(function, arena),
+                                          create_state(function, arena)};
+
+    auto input = ColumnHelper::create_nullable_column<DataTypeInt32>({5, 0, 3}, {0, 1, 0});
+    const IColumn* columns[] = {input.get()};
+    function->add(places[0], columns, 0, arena);
+    function->add(places[1], columns, 1, arena);
+    function->add(places[2], columns, 2, arena);
+
+    auto nested_result = ColumnInt64::create();
+    auto null_map = ColumnUInt8::create();
+    auto result = ColumnNullable::create(std::move(nested_result), std::move(null_map));
+    std::vector<uint64_t> repeats {4, 7, 6};
+    function->insert_result_into_repeat_vec(places, 0, repeats, *result, places.size(), arena);
+
+    ASSERT_EQ(result->size(), 3);
+    EXPECT_FALSE(result->is_null_at(0));
+    EXPECT_TRUE(result->is_null_at(1));
+    EXPECT_FALSE(result->is_null_at(2));
+    EXPECT_EQ(assert_cast<const ColumnInt64&>(result->get_nested_column()).get_element(0), 20);
+    EXPECT_EQ(assert_cast<const ColumnInt64&>(result->get_nested_column()).get_element(2), 18);
+
+    // Same states and repeat factors, one insert_result_into_repeat call per state.
+    auto single_nested_result = ColumnInt64::create();
+    auto single_null_map = ColumnUInt8::create();
+    auto single_result =
+            ColumnNullable::create(std::move(single_nested_result), std::move(single_null_map));
+    function->insert_result_into_repeat(places[0], 4, *single_result, arena);
+    function->insert_result_into_repeat(places[1], 7, *single_result, arena);
+    function->insert_result_into_repeat(places[2], 6, *single_result, arena);
+    ASSERT_EQ(single_result->size(), 3);
+    EXPECT_FALSE(single_result->is_null_at(0));
+    EXPECT_TRUE(single_result->is_null_at(1));
+    EXPECT_FALSE(single_result->is_null_at(2));
+    EXPECT_EQ(assert_cast<const ColumnInt64&>(single_result->get_nested_column()).get_element(0),
+              20);
+    EXPECT_EQ(assert_cast<const ColumnInt64&>(single_result->get_nested_column()).get_element(2),
+              18);
+
+    destroy_states(function, places);
+}
+
+// The destination nullable column already holds one default row before the vectorized repeat
+// output runs. The three results must be appended after it (rows 1..3), row 0 must stay null,
+// and the nested column and null map must end up the same length.
+TEST(AggregateFunctionRepeatTest, NullableV2RepeatOutputAppendsAfterExistingRows) {
+    Arena arena;
+    AggregateFunctionAttr attr;
+    attr.enable_aggregate_function_null_v2 = true;
+    auto nullable_type = std::make_shared<DataTypeNullable>(std::make_shared<DataTypeInt32>());
+    auto function = get_agg_function("sum", {nullable_type}, true, attr);
+    std::vector<AggregateDataPtr> places {create_state(function, arena),
+                                          create_state(function, arena),
+                                          create_state(function, arena)};
+
+    auto input = ColumnHelper::create_nullable_column<DataTypeInt32>({5, 0, 3}, {0, 1, 0});
+    const IColumn* columns[] = {input.get()};
+    function->add(places[0], columns, 0, arena);
+    function->add(places[1], columns, 1, arena);
+    function->add(places[2], columns, 2, arena);
+
+    auto nested_result = ColumnInt64::create();
+    auto null_map = ColumnUInt8::create();
+    auto result = ColumnNullable::create(std::move(nested_result), std::move(null_map));
+    // Pre-existing row that the repeat output must not overwrite.
+    result->insert_default();
+
+    std::vector<uint64_t> repeats {4, 7, 6};
+    function->insert_result_into_repeat_vec(places, 0, repeats, *result, places.size(), arena);
+
+    ASSERT_EQ(result->size(), 4);
+    EXPECT_TRUE(result->is_null_at(0));
+    EXPECT_FALSE(result->is_null_at(1));
+    EXPECT_TRUE(result->is_null_at(2));
+    EXPECT_FALSE(result->is_null_at(3));
+    EXPECT_EQ(result->get_nested_column().size(), result->get_null_map_column().size());
+    EXPECT_EQ(assert_cast<const ColumnInt64&>(result->get_nested_column()).get_element(1), 20);
+    EXPECT_EQ(assert_cast<const ColumnInt64&>(result->get_nested_column()).get_element(3), 18);
+
+    destroy_states(function, places);
+}
+
+// Null-v2 sum where every state received a non-null value: no output row may be null and each
+// value is expected to be input * repeat (2 * 3 = 6, 4 * 5 = 20, 6 * 7 = 42).
+TEST(AggregateFunctionRepeatTest, NullableV2SumRepeatOutputAllValidRun) {
+    Arena arena;
+    AggregateFunctionAttr attr;
+    attr.enable_aggregate_function_null_v2 = true;
+    auto nullable_type = std::make_shared<DataTypeNullable>(std::make_shared<DataTypeInt32>());
+    auto function = get_agg_function("sum", {nullable_type}, true, attr);
+    std::vector<AggregateDataPtr> places {create_state(function, arena),
+                                          create_state(function, arena),
+                                          create_state(function, arena)};
+
+    auto input = ColumnHelper::create_nullable_column<DataTypeInt32>({2, 4, 6}, {0, 0, 0});
+    const IColumn* columns[] = {input.get()};
+    function->add(places[0], columns, 0, arena);
+    function->add(places[1], columns, 1, arena);
+    function->add(places[2], columns, 2, arena);
+
+    auto nested_result = ColumnInt64::create();
+    auto null_map = ColumnUInt8::create();
+    auto result = ColumnNullable::create(std::move(nested_result), std::move(null_map));
+    std::vector<uint64_t> repeats {3, 5, 7};
+    function->insert_result_into_repeat_vec(places, 0, repeats, *result, places.size(), arena);
+
+    ASSERT_EQ(result->size(), 3);
+    EXPECT_FALSE(result->is_null_at(0));
+    EXPECT_FALSE(result->is_null_at(1));
+    EXPECT_FALSE(result->is_null_at(2));
+    const auto& nested_column = assert_cast<const ColumnInt64&>(result->get_nested_column());
+    EXPECT_EQ(nested_column.get_element(0), 6);
+    EXPECT_EQ(nested_column.get_element(1), 20);
+    EXPECT_EQ(nested_column.get_element(2), 42);
+
+    destroy_states(function, places);
+}
+
+// Null-v2 sum where both input rows carry null flag 1 (presumably null): every output row is
+// expected to be NULL, and the nested column must still grow in step with the null map.
+TEST(AggregateFunctionRepeatTest, NullableV2SumRepeatOutputAllNull) {
+    Arena arena;
+    AggregateFunctionAttr attr;
+    attr.enable_aggregate_function_null_v2 = true;
+    auto nullable_type = std::make_shared<DataTypeNullable>(std::make_shared<DataTypeInt32>());
+    auto function = get_agg_function("sum", {nullable_type}, true, attr);
+    std::vector<AggregateDataPtr> places {create_state(function, arena),
+                                          create_state(function, arena)};
+
+    auto input = ColumnHelper::create_nullable_column<DataTypeInt32>({0, 0}, {1, 1});
+    const IColumn* columns[] = {input.get()};
+    function->add(places[0], columns, 0, arena);
+    function->add(places[1], columns, 1, arena);
+
+    auto nested_result = ColumnInt64::create();
+    auto null_map = ColumnUInt8::create();
+    auto result = ColumnNullable::create(std::move(nested_result), std::move(null_map));
+    std::vector<uint64_t> repeats {3, 5};
+    function->insert_result_into_repeat_vec(places, 0, repeats, *result, places.size(), arena);
+
+    ASSERT_EQ(result->size(), 2);
+    EXPECT_TRUE(result->is_null_at(0));
+    EXPECT_TRUE(result->is_null_at(1));
+    EXPECT_EQ(result->get_nested_column().size(), result->get_null_map_column().size());
+
+    destroy_states(function, places);
+}
+
+// Null-v2 min/max over nullable Int32. For each function, state 0 gets one non-null value (5),
+// state 1 gets only a row flagged 1 (presumably null) and state 2 gets {9, 1}. Expected: the
+// middle output is NULL and the others are the plain min/max, unaffected by the repeat
+// factors (3, 7, 11).
+TEST(AggregateFunctionRepeatTest, NullableV2MinMaxRepeatOutput) {
+    Arena arena;
+    AggregateFunctionAttr attr;
+    attr.enable_aggregate_function_null_v2 = true;
+    auto nullable_type = std::make_shared<DataTypeNullable>(std::make_shared<DataTypeInt32>());
+    auto min_function = get_agg_function("min", {nullable_type}, true, attr);
+    auto max_function = get_agg_function("max", {nullable_type}, true, attr);
+
+    auto input =
+            ColumnHelper::create_nullable_column<DataTypeInt32>({5, 0, 9, 1, 0}, {0, 1, 0, 0, 1});
+    const IColumn* columns[] = {input.get()};
+    std::vector<uint64_t> repeats {3, 7, 11};
+
+    std::vector<AggregateDataPtr> min_places {create_state(min_function, arena),
+                                              create_state(min_function, arena),
+                                              create_state(min_function, arena)};
+    min_function->add(min_places[0], columns, 0, arena);
+    min_function->add(min_places[1], columns, 1, arena);
+    min_function->add(min_places[2], columns, 2, arena);
+    min_function->add(min_places[2], columns, 3, arena);
+
+    auto min_nested_result = ColumnInt32::create();
+    auto min_null_map = ColumnUInt8::create();
+    auto min_result = ColumnNullable::create(std::move(min_nested_result), std::move(min_null_map));
+    min_function->insert_result_into_repeat_vec(min_places, 0, repeats, *min_result,
+                                                min_places.size(), arena);
+
+    ASSERT_EQ(min_result->size(), 3);
+    EXPECT_FALSE(min_result->is_null_at(0));
+    EXPECT_TRUE(min_result->is_null_at(1));
+    EXPECT_FALSE(min_result->is_null_at(2));
+    EXPECT_EQ(assert_cast<const ColumnInt32&>(min_result->get_nested_column()).get_element(0), 5);
+    EXPECT_EQ(assert_cast<const ColumnInt32&>(min_result->get_nested_column()).get_element(2), 1);
+    destroy_states(min_function, min_places);
+
+    std::vector<AggregateDataPtr> max_places {create_state(max_function, arena),
+                                              create_state(max_function, arena),
+                                              create_state(max_function, arena)};
+    max_function->add(max_places[0], columns, 0, arena);
+    // The max case feeds row 4 to the middle state, where the min case above fed row 1;
+    // both rows carry null flag 1.
+    max_function->add(max_places[1], columns, 4, arena);
+    max_function->add(max_places[2], columns, 2, arena);
+    max_function->add(max_places[2], columns, 3, arena);
+
+    auto max_nested_result = ColumnInt32::create();
+    auto max_null_map = ColumnUInt8::create();
+    auto max_result = ColumnNullable::create(std::move(max_nested_result), std::move(max_null_map));
+    max_function->insert_result_into_repeat_vec(max_places, 0, repeats, *max_result,
+                                                max_places.size(), arena);
+
+    ASSERT_EQ(max_result->size(), 3);
+    EXPECT_FALSE(max_result->is_null_at(0));
+    EXPECT_TRUE(max_result->is_null_at(1));
+    EXPECT_FALSE(max_result->is_null_at(2));
+    EXPECT_EQ(assert_cast<const ColumnInt32&>(max_result->get_nested_column()).get_element(0), 5);
+    EXPECT_EQ(assert_cast<const ColumnInt32&>(max_result->get_nested_column()).get_element(2), 9);
+    destroy_states(max_function, max_places);
+}
+
+} // namespace
+} // namespace doris
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]