This is an automated email from the ASF dual-hosted git repository. yangzhg pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-doris.git
The following commit(s) were added to refs/heads/master by this push: new f8d086d [feature](rpc) (experimental)Support implement UDF through GRPC protocol. (#7519) f8d086d is described below commit f8d086d87fa3291d0efce2dee46f999ce9087c94 Author: Zhengguo Yang <yangz...@gmail.com> AuthorDate: Tue Feb 8 09:25:09 2022 +0800 [feature](rpc) (experimental)Support implement UDF through GRPC protocol. (#7519) Support implement UDF through GRPC protocol. This brings several benefits: 1. The udf implementation language is not limited to c++, users can use any familiar language to implement udf 2. UDF is decoupled from Doris, udf will not cause doris coredump, udf computing resources are separated from doris, and doris services are not affected But RPC's UDF has a fixed overhead, so its performance is much slower than C++ UDF, especially when the amount of data is large. Create function like ``` CREATE FUNCTION rpc_add(INT, INT) RETURNS INT PROPERTIES ( "SYMBOL"="add_int", "OBJECT_FILE"="127.0.0.1:9999", "TYPE"="RPC" ); ``` Function service need to implement `check_fn` and `fn_call` methods Note: THIS IS AN EXPERIMENTAL FEATURE, THE INTERFACE AND DATA STRUCTURE MAY BE CHANGED IN FUTURE !!! --- be/src/common/config.h | 6 +- be/src/common/status.h | 2 +- be/src/exec/tablet_sink.cpp | 127 +++-- be/src/exprs/CMakeLists.txt | 1 + be/src/exprs/expr.cpp | 3 + be/src/exprs/expr_context.h | 1 + be/src/exprs/rpc_fn_call.cpp | 327 +++++++++++++ be/src/exprs/rpc_fn_call.h | 63 +++ be/src/exprs/runtime_filter_rpc.cpp | 7 +- be/src/gen_cpp/CMakeLists.txt | 2 +- be/src/http/action/check_rpc_channel_action.cpp | 4 +- be/src/http/action/reset_rpc_channel_action.cpp | 12 +- be/src/plugin/plugin_loader.cpp | 6 +- be/src/runtime/data_stream_sender.cpp | 4 +- be/src/runtime/exec_env.h | 21 +- be/src/runtime/exec_env_init.cpp | 8 +- be/src/runtime/runtime_filter_mgr.cpp | 4 +- be/src/service/internal_service.cpp | 12 +- be/src/udf/udf.cpp | 15 + be/src/udf/udf_internal.h | 4 + be/src/util/CMakeLists.txt | 2 +- .../{brpc_stub_cache.cpp => brpc_client_cache.cpp} | 20 +- be/src/util/brpc_client_cache.h | 150 ++++++ be/src/util/brpc_stub_cache.h | 159 ------- be/src/util/doris_metrics.h | 16 +- be/src/vec/CMakeLists.txt | 1 + be/src/vec/columns/column_decimal.h | 9 +- be/src/vec/exprs/vectorized_fn_call.cpp | 10 +- be/src/vec/functions/function_rpc.cpp | 527 +++++++++++++++++++++ be/src/vec/functions/function_rpc.h | 68 +++ be/src/vec/sink/vdata_stream_sender.cpp | 8 +- be/src/vec/sink/vdata_stream_sender.h | 24 +- be/test/exec/tablet_sink_test.cpp | 8 +- be/test/http/stream_load_test.cpp | 11 +- be/test/util/CMakeLists.txt | 2 +- ...b_cache_test.cpp => brpc_client_cache_test.cpp} | 24 +- be/test/vec/runtime/vdata_stream_test.cpp | 17 +- .../apache/doris/analysis/CreateFunctionStmt.java | 159 ++++++- .../org/apache/doris/catalog/ScalarFunction.java | 5 +- .../main/java/org/apache/doris/common/Status.java | 2 +- .../main/java/org/apache/doris/qe/Coordinator.java | 2 +- .../doris/load/sync/canal/CanalSyncDataTest.java | 13 +- .../apache/doris/utframe/MockedBackendFactory.java | 12 +- gensrc/proto/function_service.proto | 63 +++ gensrc/proto/internal_service.proto | 10 - gensrc/proto/status.proto | 27 -- gensrc/proto/types.proto | 151 ++++++ gensrc/thrift/Types.thrift | 5 +- run-be-ut.sh | 1 + 49 files changed, 1765 insertions(+), 370 deletions(-) diff --git a/be/src/common/config.h b/be/src/common/config.h index f4b38ba..2730b62 100644 --- a/be/src/common/config.h +++ b/be/src/common/config.h @@ -653,7 +653,7 @@ CONF_mInt32(default_remote_storage_s3_max_conn, "50"); CONF_mInt32(default_remote_storage_s3_request_timeout_ms, "3000"); CONF_mInt32(default_remote_storage_s3_conn_timeout_ms, "1000"); // Set to true to disable the minidump feature. -CONF_Bool(disable_minidump , "false"); +CONF_Bool(disable_minidump, "false"); // The dir to save minidump file. // Make sure that the user who run Doris has permission to create and visit this dir, @@ -688,7 +688,11 @@ CONF_mInt32(load_task_high_priority_threshold_second, "120"); // Increase this config may avoid rpc timeout. CONF_mInt32(min_load_rpc_timeout_ms, "20000"); +// use which protocol to access function service, candicate is baidu_std/h2:grpc +CONF_String(function_service_protocol, "h2:grpc"); +// use which load balancer to select server to connect +CONF_String(rpc_load_balancer, "rr"); } // namespace config diff --git a/be/src/common/status.h b/be/src/common/status.h index 89bd6fe..23bf764 100644 --- a/be/src/common/status.h +++ b/be/src/common/status.h @@ -10,7 +10,7 @@ #include "common/compiler_util.h" #include "common/logging.h" #include "gen_cpp/Status_types.h" // for TStatus -#include "gen_cpp/status.pb.h" // for PStatus +#include "gen_cpp/types.pb.h" // for PStatus #include "util/slice.h" // for Slice namespace doris { diff --git a/be/src/exec/tablet_sink.cpp b/be/src/exec/tablet_sink.cpp index 7f3a62c..2fe3492 100644 --- a/be/src/exec/tablet_sink.cpp +++ b/be/src/exec/tablet_sink.cpp @@ -18,6 +18,7 @@ #include "exec/tablet_sink.h" #include <fmt/format.h> + #include <sstream> #include <string> @@ -31,7 +32,7 @@ #include "runtime/tuple_row.h" #include "service/backend_options.h" #include "service/brpc.h" -#include "util/brpc_stub_cache.h" +#include "util/brpc_client_cache.h" #include "util/debug/sanitizer_scopes.h" #include "util/monotime.h" #include "util/proto_util.h" @@ -85,7 +86,8 @@ Status NodeChannel::init(RuntimeState* state) { _batch_size = state->batch_size(); _cur_batch.reset(new RowBatch(*_row_desc, _batch_size, _parent->_mem_tracker.get())); - _stub = state->exec_env()->brpc_stub_cache()->get_stub(_node_info.host, _node_info.brpc_port); + _stub = state->exec_env()->brpc_internal_client_cache()->get_client(_node_info.host, + _node_info.brpc_port); if (_stub == nullptr) { LOG(WARNING) << "Get rpc stub failed, host=" << _node_info.host << ", port=" << _node_info.brpc_port; @@ -156,9 +158,10 @@ void NodeChannel::_cancel_with_msg(const std::string& msg) { Status NodeChannel::open_wait() { _open_closure->join(); if (_open_closure->cntl.Failed()) { - if (!ExecEnv::GetInstance()->brpc_stub_cache()->available(_stub, _node_info.host, - _node_info.brpc_port)) { - ExecEnv::GetInstance()->brpc_stub_cache()->erase(_open_closure->cntl.remote_side()); + if (!ExecEnv::GetInstance()->brpc_internal_client_cache()->available( + _stub, _node_info.host, _node_info.brpc_port)) { + ExecEnv::GetInstance()->brpc_internal_client_cache()->erase( + _open_closure->cntl.remote_side()); } std::stringstream ss; ss << "failed to open tablet writer, error=" << berror(_open_closure->cntl.ErrorCode()) @@ -193,7 +196,7 @@ Status NodeChannel::open_wait() { bool is_last_rpc) { Status status(result.status()); if (status.ok()) { - // if has error tablet, handle them first + // if has error tablet, handle them first for (auto& error : result.tablet_errors()) { _index_channel->mark_as_failed(this, error.msg(), error.tablet_id()); } @@ -313,8 +316,9 @@ Status NodeChannel::add_row(BlockRow& block_row, int64_t tablet_id) { } DCHECK_NE(row_no, RowBatch::INVALID_ROW_INDEX); - _cur_batch->get_row(row_no)->set_tuple(0, - block_row.first->deep_copy_tuple(*_tuple_desc, _cur_batch->tuple_data_pool(), block_row.second, 0, true)); + _cur_batch->get_row(row_no)->set_tuple( + 0, block_row.first->deep_copy_tuple(*_tuple_desc, _cur_batch->tuple_data_pool(), + block_row.second, 0, true)); _cur_batch->commit_last_row(); _cur_add_batch_request.add_tablet_ids(tablet_id); return Status::OK(); @@ -338,7 +342,8 @@ Status NodeChannel::mark_close() { _pending_batches.emplace(std::move(_cur_batch), _cur_add_batch_request); _pending_batches_num++; DCHECK(_pending_batches.back().second.eos()); - LOG(INFO) << channel_info() << " mark closed, left pending batch size: " << _pending_batches.size(); + LOG(INFO) << channel_info() + << " mark closed, left pending batch size: " << _pending_batches.size(); } _eos_is_produced = true; @@ -377,7 +382,7 @@ Status NodeChannel::close_wait(RuntimeState* state) { std::make_move_iterator(_tablet_commit_infos.begin()), std::make_move_iterator(_tablet_commit_infos.end())); - _index_channel->set_error_tablet_in_state(state); + _index_channel->set_error_tablet_in_state(state); return Status::OK(); } @@ -455,7 +460,7 @@ void NodeChannel::try_send_batch() { size_t uncompressed_bytes = 0, compressed_bytes = 0; Status st = row_batch->serialize(request.mutable_row_batch(), &uncompressed_bytes, &compressed_bytes, _tuple_data_buffer_ptr); if (!st.ok()) { - cancel(fmt::format("{}, err: {}", channel_info(), st.get_error_msg())); + cancel(fmt::format("{}, err: {}", channel_info(), st.get_error_msg())); return; } if (compressed_bytes >= double(config::brpc_max_body_size) * 0.95f) { @@ -541,8 +546,8 @@ Status IndexChannel::init(RuntimeState* state, const std::vector<TTabletWithPart NodeChannel* channel = nullptr; auto it = _node_channels.find(node_id); if (it == _node_channels.end()) { - channel = _parent->_pool->add( - new NodeChannel(_parent, this, node_id, _schema_hash)); + channel = + _parent->_pool->add(new NodeChannel(_parent, this, node_id, _schema_hash)); _node_channels.emplace(node_id, channel); } else { channel = it->second; @@ -586,41 +591,44 @@ void IndexChannel::add_row(BlockRow& block_row, int64_t tablet_id) { } } -void IndexChannel::mark_as_failed(const NodeChannel* ch, const std::string& err, int64_t tablet_id) { +void IndexChannel::mark_as_failed(const NodeChannel* ch, const std::string& err, + int64_t tablet_id) { const auto& it = _tablets_by_channel.find(ch->node_id()); if (it == _tablets_by_channel.end()) { return; } { - std::lock_guard<SpinLock> l(_fail_lock); + std::lock_guard<SpinLock> l(_fail_lock); if (tablet_id == -1) { for (const auto the_tablet_id : it->second) { _failed_channels[the_tablet_id].insert(ch->node_id()); _failed_channels_msgs.emplace(the_tablet_id, err + ", host: " + ch->host()); if (_failed_channels[the_tablet_id].size() >= ((_parent->_num_replicas + 1) / 2)) { - _intolerable_failure_status = Status::InternalError(_failed_channels_msgs[the_tablet_id]); + _intolerable_failure_status = + Status::InternalError(_failed_channels_msgs[the_tablet_id]); } } } else { _failed_channels[tablet_id].insert(ch->node_id()); _failed_channels_msgs.emplace(tablet_id, err + ", host: " + ch->host()); if (_failed_channels[tablet_id].size() >= ((_parent->_num_replicas + 1) / 2)) { - _intolerable_failure_status = Status::InternalError(_failed_channels_msgs[tablet_id]); + _intolerable_failure_status = + Status::InternalError(_failed_channels_msgs[tablet_id]); } } } } Status IndexChannel::check_intolerable_failure() { - std::lock_guard<SpinLock> l(_fail_lock); + std::lock_guard<SpinLock> l(_fail_lock); return _intolerable_failure_status; } void IndexChannel::set_error_tablet_in_state(RuntimeState* state) { std::vector<TErrorTabletInfo>& error_tablet_infos = state->error_tablet_infos(); - std::lock_guard<SpinLock> l(_fail_lock); + std::lock_guard<SpinLock> l(_fail_lock); for (const auto& it : _failed_channels_msgs) { TErrorTabletInfo error_info; error_info.__set_tabletId(it.first); @@ -684,7 +692,8 @@ Status OlapTableSink::prepare(RuntimeState* state) { _sender_id = state->per_fragment_instance_idx(); _num_senders = state->num_per_fragment_instances(); - _is_high_priority = (state->query_options().query_timeout <= config::load_task_high_priority_threshold_second); + _is_high_priority = (state->query_options().query_timeout <= + config::load_task_high_priority_threshold_second); // profile must add to state's object pool _profile = state->obj_pool()->add(new RuntimeProfile("OlapTableSink")); @@ -810,7 +819,10 @@ Status OlapTableSink::open(RuntimeState* state) { // The open() phase is mainly to generate DeltaWriter instances on the nodes corresponding to each node channel. // This phase will not fail due to a single tablet. // Therefore, if the open() phase fails, all tablets corresponding to the node need to be marked as failed. - index_channel->mark_as_failed(ch, fmt::format("{}, open failed, err: {}", ch->channel_info(), st.get_error_msg()), -1); + index_channel->mark_as_failed(ch, + fmt::format("{}, open failed, err: {}", + ch->channel_info(), st.get_error_msg()), + -1); } }); @@ -851,7 +863,8 @@ Status OlapTableSink::send(RuntimeState* state, RowBatch* input_batch) { SCOPED_RAW_TIMER(&_validate_data_ns); _filter_bitmap.Reset(batch->num_rows()); bool stop_processing = false; - RETURN_IF_ERROR(_validate_data(state, batch, &_filter_bitmap, &filtered_rows, &stop_processing)); + RETURN_IF_ERROR( + _validate_data(state, batch, &_filter_bitmap, &filtered_rows, &stop_processing)); _number_filtered_rows += filtered_rows; if (stop_processing) { // should be returned after updating "_number_filtered_rows", to make sure that load job can be cancelled @@ -870,12 +883,15 @@ Status OlapTableSink::send(RuntimeState* state, RowBatch* input_batch) { const OlapTablePartition* partition = nullptr; uint32_t dist_hash = 0; if (!_partition->find_tablet(tuple, &partition, &dist_hash)) { - RETURN_IF_ERROR(state->append_error_msg_to_file([]() -> std::string { return ""; }, + RETURN_IF_ERROR(state->append_error_msg_to_file( + []() -> std::string { return ""; }, [&]() -> std::string { - fmt::memory_buffer buf; - fmt::format_to(buf, "no partition for this tuple. tuple={}", Tuple::to_string(tuple, *_output_tuple_desc)); - return buf.data(); - }, &stop_processing)); + fmt::memory_buffer buf; + fmt::format_to(buf, "no partition for this tuple. tuple={}", + Tuple::to_string(tuple, *_output_tuple_desc)); + return buf.data(); + }, + &stop_processing)); _number_filtered_rows++; if (stop_processing) { return Status::EndOfFile("Encountered unqualified data, stop processing"); @@ -892,7 +908,7 @@ Status OlapTableSink::send(RuntimeState* state, RowBatch* input_batch) { } // check intolerable failure - for (auto index_channel : _channels) { + for (auto index_channel : _channels) { RETURN_IF_ERROR(index_channel->check_intolerable_failure()); } return Status::OK(); @@ -953,7 +969,6 @@ Status OlapTableSink::close(RuntimeState* state, Status close_status) { status = index_st; } } // end for index channels - } // TODO need to be improved LOG(INFO) << "total mem_exceeded_block_ns=" << mem_exceeded_block_ns @@ -1031,7 +1046,8 @@ Status OlapTableSink::_convert_batch(RuntimeState* state, RowBatch* input_batch, // Only when the expr return value is null, we will check the error message. std::string expr_error = _output_expr_ctxs[j]->get_error_msg(); if (!expr_error.empty()) { - RETURN_IF_ERROR(state->append_error_msg_to_file([&]() -> std::string { return slot_desc->col_name(); }, + RETURN_IF_ERROR(state->append_error_msg_to_file( + [&]() -> std::string { return slot_desc->col_name(); }, [&]() -> std::string { return expr_error; }, &stop_processing)); _number_filtered_rows++; ignore_this_row = true; @@ -1040,12 +1056,15 @@ Status OlapTableSink::_convert_batch(RuntimeState* state, RowBatch* input_batch, break; } if (!slot_desc->is_nullable()) { - RETURN_IF_ERROR(state->append_error_msg_to_file([]() -> std::string { return ""; }, + RETURN_IF_ERROR(state->append_error_msg_to_file( + []() -> std::string { return ""; }, [&]() -> std::string { - fmt::memory_buffer buf; - fmt::format_to(buf, "null value for not null column, column={}", slot_desc->col_name()); - return buf.data(); - }, &stop_processing)); + fmt::memory_buffer buf; + fmt::format_to(buf, "null value for not null column, column={}", + slot_desc->col_name()); + return buf.data(); + }, + &stop_processing)); _number_filtered_rows++; ignore_this_row = true; break; @@ -1073,8 +1092,8 @@ Status OlapTableSink::_convert_batch(RuntimeState* state, RowBatch* input_batch, return Status::OK(); } -Status OlapTableSink::_validate_data(RuntimeState* state, RowBatch* batch, Bitmap* filter_bitmap, int* filtered_rows, - bool* stop_processing) { +Status OlapTableSink::_validate_data(RuntimeState* state, RowBatch* batch, Bitmap* filter_bitmap, + int* filtered_rows, bool* stop_processing) { for (int row_no = 0; row_no < batch->num_rows(); ++row_no) { Tuple* tuple = batch->get_row(row_no)->get_tuple(0); bool row_valid = true; @@ -1083,8 +1102,9 @@ Status OlapTableSink::_validate_data(RuntimeState* state, RowBatch* batch, Bitma SlotDescriptor* desc = _output_tuple_desc->slots()[i]; if (desc->is_nullable() && tuple->is_null(desc->null_indicator_offset())) { if (desc->type().type == TYPE_OBJECT) { - fmt::format_to(error_msg, "null is not allowed for bitmap column, column_name: {}; ", - desc->col_name()); + fmt::format_to(error_msg, + "null is not allowed for bitmap column, column_name: {}; ", + desc->col_name()); row_valid = false; } continue; @@ -1096,9 +1116,11 @@ Status OlapTableSink::_validate_data(RuntimeState* state, RowBatch* batch, Bitma // Fixed length string StringValue* str_val = (StringValue*)slot; if (str_val->len > desc->type().len) { - fmt::format_to(error_msg, "{}", "the length of input is too long than schema. "); + fmt::format_to(error_msg, "{}", + "the length of input is too long than schema. "); fmt::format_to(error_msg, "column_name: {}; ", desc->col_name()); - fmt::format_to(error_msg, "input str: [{}] ", std::string(str_val->ptr, str_val->len)); + fmt::format_to(error_msg, "input str: [{}] ", + std::string(str_val->ptr, str_val->len)); fmt::format_to(error_msg, "schema length: {}; ", desc->type().len); fmt::format_to(error_msg, "actual length: {}; ", str_val->len); row_valid = false; @@ -1118,9 +1140,11 @@ Status OlapTableSink::_validate_data(RuntimeState* state, RowBatch* batch, Bitma case TYPE_STRING: { StringValue* str_val = (StringValue*)slot; if (str_val->len > OLAP_STRING_MAX_LENGTH) { - fmt::format_to(error_msg, "{}", "the length of input is too long than schema. "); + fmt::format_to(error_msg, "{}", + "the length of input is too long than schema. "); fmt::format_to(error_msg, "column_name: {}; ", desc->col_name()); - fmt::format_to(error_msg, "first 128 bytes of input str: [{}] ", std::string(str_val->ptr, 128)); + fmt::format_to(error_msg, "first 128 bytes of input str: [{}] ", + std::string(str_val->ptr, 128)); fmt::format_to(error_msg, "schema length: {}; ", OLAP_STRING_MAX_LENGTH); fmt::format_to(error_msg, "actual length: {}; ", str_val->len); row_valid = false; @@ -1134,15 +1158,19 @@ Status OlapTableSink::_validate_data(RuntimeState* state, RowBatch* batch, Bitma int code = dec_val.round(&dec_val, desc->type().scale, HALF_UP); reinterpret_cast<PackedInt128*>(slot)->value = dec_val.value(); if (code != E_DEC_OK) { - fmt::format_to(error_msg, "round one decimal failed.value={}; ", dec_val.to_string()); + fmt::format_to(error_msg, "round one decimal failed.value={}; ", + dec_val.to_string()); row_valid = false; continue; } } if (dec_val > _max_decimalv2_val[i] || dec_val < _min_decimalv2_val[i]) { - fmt::format_to(error_msg, "decimal value is not valid for definition, column={}", desc->col_name()); + fmt::format_to(error_msg, + "decimal value is not valid for definition, column={}", + desc->col_name()); fmt::format_to(error_msg, ", value={}", dec_val.to_string()); - fmt::format_to(error_msg, ", precision={}, scale={}; ", desc->type().precision, desc->type().scale); + fmt::format_to(error_msg, ", precision={}, scale={}; ", desc->type().precision, + desc->type().scale); row_valid = false; continue; } @@ -1151,7 +1179,9 @@ Status OlapTableSink::_validate_data(RuntimeState* state, RowBatch* batch, Bitma case TYPE_HLL: { Slice* hll_val = (Slice*)slot; if (!HyperLogLog::is_valid(*hll_val)) { - fmt::format_to(error_msg, "Content of HLL type column is invalid. column name: {}; ", desc->col_name()); + fmt::format_to(error_msg, + "Content of HLL type column is invalid. column name: {}; ", + desc->col_name()); row_valid = false; continue; } @@ -1165,7 +1195,8 @@ Status OlapTableSink::_validate_data(RuntimeState* state, RowBatch* batch, Bitma if (!row_valid) { (*filtered_rows)++; filter_bitmap->Set(row_no, true); - RETURN_IF_ERROR(state->append_error_msg_to_file([]() -> std::string { return ""; }, + RETURN_IF_ERROR(state->append_error_msg_to_file( + []() -> std::string { return ""; }, [&]() -> std::string { return error_msg.data(); }, stop_processing)); } } diff --git a/be/src/exprs/CMakeLists.txt b/be/src/exprs/CMakeLists.txt index 5d69aa7..3b4b86c 100644 --- a/be/src/exprs/CMakeLists.txt +++ b/be/src/exprs/CMakeLists.txt @@ -54,6 +54,7 @@ add_library(Exprs math_functions.cpp null_literal.cpp scalar_fn_call.cpp + rpc_fn_call.cpp slot_ref.cpp string_functions.cpp array_functions.cpp diff --git a/be/src/exprs/expr.cpp b/be/src/exprs/expr.cpp index 97352da..1c29d5d 100644 --- a/be/src/exprs/expr.cpp +++ b/be/src/exprs/expr.cpp @@ -38,6 +38,7 @@ #include "exprs/is_null_predicate.h" #include "exprs/literal.h" #include "exprs/null_literal.h" +#include "exprs/rpc_fn_call.h" #include "exprs/scalar_fn_call.h" #include "exprs/slot_ref.h" #include "exprs/tuple_is_null_predicate.h" @@ -357,6 +358,8 @@ Status Expr::create_expr(ObjectPool* pool, const TExprNode& texpr_node, Expr** e *expr = pool->add(new IfNullExpr(texpr_node)); } else if (texpr_node.fn.name.function_name == "coalesce") { *expr = pool->add(new CoalesceExpr(texpr_node)); + } else if (texpr_node.fn.binary_type == TFunctionBinaryType::RPC) { + *expr = pool->add(new RPCFnCall(texpr_node)); } else { *expr = pool->add(new ScalarFnCall(texpr_node)); } diff --git a/be/src/exprs/expr_context.h b/be/src/exprs/expr_context.h index 45896a2..f176240 100644 --- a/be/src/exprs/expr_context.h +++ b/be/src/exprs/expr_context.h @@ -153,6 +153,7 @@ public: private: friend class Expr; friend class ScalarFnCall; + friend class RPCFnCall; friend class InPredicate; friend class RuntimePredicateWrapper; friend class BloomFilterPredicate; diff --git a/be/src/exprs/rpc_fn_call.cpp b/be/src/exprs/rpc_fn_call.cpp new file mode 100644 index 0000000..92b67e5 --- /dev/null +++ b/be/src/exprs/rpc_fn_call.cpp @@ -0,0 +1,327 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "exprs/rpc_fn_call.h" + +#include "exprs/anyval_util.h" +#include "exprs/expr_context.h" +#include "fmt/format.h" +#include "gen_cpp/function_service.pb.h" +#include "runtime/runtime_state.h" +#include "runtime/user_function_cache.h" +#include "service/brpc.h" +#include "util/brpc_client_cache.h" + +namespace doris { + +RPCFnCall::RPCFnCall(const TExprNode& node) : Expr(node), _fn_context_index(-1) { + DCHECK_EQ(_fn.binary_type, TFunctionBinaryType::RPC); +} + +Status RPCFnCall::prepare(RuntimeState* state, const RowDescriptor& desc, ExprContext* context) { + RETURN_IF_ERROR(Expr::prepare(state, desc, context)); + DCHECK(!_fn.scalar_fn.symbol.empty()); + + FunctionContext::TypeDesc return_type = AnyValUtil::column_type_to_type_desc(_type); + std::vector<FunctionContext::TypeDesc> arg_types; + bool char_arg = false; + for (int i = 0; i < _children.size(); ++i) { + arg_types.push_back(AnyValUtil::column_type_to_type_desc(_children[i]->type())); + char_arg = char_arg || (_children[i]->type().type == TYPE_CHAR); + } + _fn_context_index = context->register_func(state, return_type, arg_types, 0); + + // _fn.scalar_fn.symbol + _rpc_function_symbol = _fn.scalar_fn.symbol; + + _client = state->exec_env()->brpc_function_client_cache()->get_client(_fn.hdfs_location); + + if (_client == nullptr) { + return Status::InternalError( + fmt::format("rpc env init error: {}/{}", _fn.hdfs_location, _rpc_function_symbol)); + } + return Status::OK(); +} + +Status RPCFnCall::open(RuntimeState* state, ExprContext* ctx, + FunctionContext::FunctionStateScope scope) { + RETURN_IF_ERROR(Expr::open(state, ctx, scope)); + return Status::OK(); +} + +void RPCFnCall::close(RuntimeState* state, ExprContext* context, + FunctionContext::FunctionStateScope scope) { + Expr::close(state, context, scope); +} + +Status RPCFnCall::_eval_children(ExprContext* context, TupleRow* row, + PFunctionCallResponse* response) { + PFunctionCallRequest request; + request.set_function_name(_rpc_function_symbol); + for (int i = 0; i < _children.size(); ++i) { + PValues* arg = request.add_args(); + void* src_slot = context->get_value(_children[i], row); + PGenericType* ptype = arg->mutable_type(); + if (src_slot == nullptr) { + arg->set_has_null(true); + arg->add_null_map(true); + } else { + arg->set_has_null(false); + } + switch (_children[i]->type().type) { + case TYPE_BOOLEAN: { + ptype->set_id(PGenericType::BOOLEAN); + arg->add_bool_value(*(bool*)src_slot); + break; + } + case TYPE_TINYINT: { + ptype->set_id(PGenericType::INT8); + arg->add_int32_value(*(int8_t*)src_slot); + break; + } + case TYPE_SMALLINT: { + ptype->set_id(PGenericType::INT16); + arg->add_int32_value(*(int16_t*)src_slot); + break; + } + case TYPE_INT: { + ptype->set_id(PGenericType::INT32); + arg->add_int32_value(*(int*)src_slot); + break; + } + case TYPE_BIGINT: { + ptype->set_id(PGenericType::INT64); + arg->add_int64_value(*(int64_t*)src_slot); + break; + } + case TYPE_LARGEINT: { + ptype->set_id(PGenericType::INT128); + char buffer[sizeof(__int128)]; + memcpy(buffer, src_slot, sizeof(__int128)); + arg->add_bytes_value(buffer, sizeof(__int128)); + break; + } + case TYPE_DOUBLE: { + ptype->set_id(PGenericType::DOUBLE); + arg->add_double_value(*(double*)src_slot); + break; + } + case TYPE_FLOAT: { + ptype->set_id(PGenericType::FLOAT); + arg->add_float_value(*(float*)src_slot); + break; + } + case TYPE_VARCHAR: + case TYPE_STRING: + case TYPE_CHAR: { + ptype->set_id(PGenericType::STRING); + StringValue value = *reinterpret_cast<StringValue*>(src_slot); + arg->add_string_value(value.ptr, value.len); + break; + } + case TYPE_HLL: { + ptype->set_id(PGenericType::HLL); + StringValue value = *reinterpret_cast<StringValue*>(src_slot); + arg->add_string_value(value.ptr, value.len); + break; + } + case TYPE_OBJECT: { + ptype->set_id(PGenericType::BITMAP); + StringValue value = *reinterpret_cast<StringValue*>(src_slot); + arg->add_string_value(value.ptr, value.len); + break; + } + case TYPE_DECIMALV2: { + ptype->set_id(PGenericType::DECIMAL128); + ptype->mutable_decimal_type()->set_precision(_children[i]->type().precision); + ptype->mutable_decimal_type()->set_scale(_children[i]->type().scale); + char buffer[sizeof(__int128)]; + memcpy(buffer, src_slot, sizeof(__int128)); + arg->add_bytes_value(buffer, sizeof(__int128)); + break; + } + case TYPE_DATE: { + ptype->set_id(PGenericType::DATE); + const auto* time_val = (const DateTimeValue*)(src_slot); + PDateTime* date_time = arg->add_datetime_value(); + date_time->set_day(time_val->day()); + date_time->set_month(time_val->month()); + date_time->set_year(time_val->year()); + break; + } + case TYPE_DATETIME: { + ptype->set_id(PGenericType::DATETIME); + const auto* time_val = (const DateTimeValue*)(src_slot); + PDateTime* date_time = arg->add_datetime_value(); + date_time->set_day(time_val->day()); + date_time->set_month(time_val->month()); + date_time->set_year(time_val->year()); + date_time->set_hour(time_val->hour()); + date_time->set_minute(time_val->minute()); + date_time->set_second(time_val->second()); + date_time->set_microsecond(time_val->microsecond()); + break; + } + case TYPE_TIME: { + ptype->set_id(PGenericType::DATETIME); + const auto* time_val = (const DateTimeValue*)(src_slot); + PDateTime* date_time = arg->add_datetime_value(); + date_time->set_hour(time_val->hour()); + date_time->set_minute(time_val->minute()); + date_time->set_second(time_val->second()); + date_time->set_microsecond(time_val->microsecond()); + break; + } + default: { + FunctionContext* fn_ctx = context->fn_context(_fn_context_index); + fn_ctx->set_error( + fmt::format("data time not supported: {}", _children[i]->type().type).c_str()); + break; + } + } + } + + brpc::Controller cntl; + _client->fn_call(&cntl, &request, response, nullptr); + if (cntl.Failed()) { + FunctionContext* fn_ctx = context->fn_context(_fn_context_index); + fn_ctx->set_error(cntl.ErrorText().c_str()); + return Status::InternalError(fmt::format("call rpc function {} failed: {}", + _rpc_function_symbol, cntl.ErrorText()) + .c_str()); + } + if (response->status().status_code() != 0) { + FunctionContext* fn_ctx = context->fn_context(_fn_context_index); + fn_ctx->set_error(response->status().DebugString().c_str()); + return Status::InternalError(fmt::format("call rpc function {} failed: {}", + _rpc_function_symbol, + response->status().DebugString())); + } + return Status::OK(); +} + +template <typename T> +T RPCFnCall::interpret_eval(ExprContext* context, TupleRow* row) { + PFunctionCallResponse response; + Status st = _eval_children(context, row, &response); + WARN_IF_ERROR(st, "call rpc udf error"); + if (!st.ok() || (response.result().has_null() && response.result().null_map(0))) { + return T::null(); + } + T res_val; + // TODO(yangzhg) deal with udtf and udaf + const PValues& result = response.result(); + if constexpr (std::is_same_v<T, TinyIntVal>) { + DCHECK(result.type().id() == PGenericType::INT8); + res_val.val = static_cast<int8_t>(result.int32_value(0)); + } else if constexpr (std::is_same_v<T, SmallIntVal>) { + DCHECK(result.type().id() == PGenericType::INT16); + res_val.val = static_cast<int16_t>(result.int32_value(0)); + } else if constexpr (std::is_same_v<T, IntVal>) { + DCHECK(result.type().id() == PGenericType::INT32); + res_val.val = result.int32_value(0); + } else if constexpr (std::is_same_v<T, BigIntVal>) { + DCHECK(result.type().id() == PGenericType::INT64); + res_val.val = result.int64_value(0); + } else if constexpr (std::is_same_v<T, FloatVal>) { + DCHECK(result.type().id() == PGenericType::FLOAT); + res_val.val = result.float_value(0); + } else if constexpr (std::is_same_v<T, DoubleVal>) { + DCHECK(result.type().id() == PGenericType::DOUBLE); + res_val.val = result.double_value(0); + } else if constexpr (std::is_same_v<T, StringVal>) { + DCHECK(result.type().id() == PGenericType::STRING); + FunctionContext* fn_ctx = context->fn_context(_fn_context_index); + StringVal val(fn_ctx, result.string_value(0).size()); + res_val = val.copy_from(fn_ctx, + reinterpret_cast<const uint8_t*>(result.string_value(0).c_str()), + result.string_value(0).size()); + } else if constexpr (std::is_same_v<T, LargeIntVal>) { + DCHECK(result.type().id() == PGenericType::INT128); + memcpy(&(res_val.val), result.bytes_value(0).data(), sizeof(__int128_t)); + } else if constexpr (std::is_same_v<T, DateTimeVal>) { + DCHECK(result.type().id() == PGenericType::DATE || + result.type().id() == PGenericType::DATETIME); + DateTimeValue value; + value.set_time(result.datetime_value(0).year(), result.datetime_value(0).month(), + result.datetime_value(0).day(), result.datetime_value(0).hour(), + result.datetime_value(0).minute(), result.datetime_value(0).second(), + result.datetime_value(0).microsecond()); + if (result.type().id() == PGenericType::DATE) { + value.set_type(TimeType::TIME_DATE); + } else if (result.type().id() == PGenericType::DATETIME) { + if (result.datetime_value(0).has_year()) { + value.set_type(TimeType::TIME_DATETIME); + } else + value.set_type(TimeType::TIME_TIME); + } + value.to_datetime_val(&res_val); + } else if constexpr (std::is_same_v<T, DecimalV2Val>) { + DCHECK(result.type().id() == PGenericType::DECIMAL128); + memcpy(&(res_val.val), result.bytes_value(0).data(), sizeof(__int128_t)); + } + return res_val; +} // namespace doris + +doris_udf::IntVal RPCFnCall::get_int_val(ExprContext* context, TupleRow* row) { + return interpret_eval<IntVal>(context, row); +} + +doris_udf::BooleanVal RPCFnCall::get_boolean_val(ExprContext* context, TupleRow* row) { + return interpret_eval<BooleanVal>(context, row); +} + +doris_udf::TinyIntVal RPCFnCall::get_tiny_int_val(ExprContext* context, TupleRow* row) { + return interpret_eval<TinyIntVal>(context, row); +} + +doris_udf::SmallIntVal RPCFnCall::get_small_int_val(ExprContext* context, TupleRow* row) { + return interpret_eval<SmallIntVal>(context, row); +} + +doris_udf::BigIntVal RPCFnCall::get_big_int_val(ExprContext* context, TupleRow* row) { + return interpret_eval<BigIntVal>(context, row); +} + +doris_udf::FloatVal RPCFnCall::get_float_val(ExprContext* context, TupleRow* row) { + return interpret_eval<FloatVal>(context, row); +} + +doris_udf::DoubleVal RPCFnCall::get_double_val(ExprContext* context, TupleRow* row) { + return interpret_eval<DoubleVal>(context, row); +} + +doris_udf::StringVal RPCFnCall::get_string_val(ExprContext* context, TupleRow* row) { + return interpret_eval<StringVal>(context, row); +} + +doris_udf::LargeIntVal RPCFnCall::get_large_int_val(ExprContext* context, TupleRow* row) { + return interpret_eval<LargeIntVal>(context, row); +} + +doris_udf::DateTimeVal RPCFnCall::get_datetime_val(ExprContext* context, TupleRow* row) { + return interpret_eval<DateTimeVal>(context, row); +} + +doris_udf::DecimalV2Val RPCFnCall::get_decimalv2_val(ExprContext* context, TupleRow* row) { + return interpret_eval<DecimalV2Val>(context, row); +} +doris_udf::CollectionVal RPCFnCall::get_array_val(ExprContext* context, TupleRow* row) { + return interpret_eval<CollectionVal>(context, row); +} + +} // namespace doris diff --git a/be/src/exprs/rpc_fn_call.h b/be/src/exprs/rpc_fn_call.h new file mode 100644 index 0000000..c04e2ec --- /dev/null +++ b/be/src/exprs/rpc_fn_call.h @@ -0,0 +1,63 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "common/object_pool.h" +#include "exprs/expr.h" +#include "udf/udf.h" + +namespace doris { +class TExprNode; +class PFunctionService_Stub; +class PFunctionCallResponse; + +class RPCFnCall : public Expr { +public: + RPCFnCall(const TExprNode& node); + + virtual Status prepare(RuntimeState* state, const RowDescriptor& desc, + ExprContext* context) override; + virtual Status open(RuntimeState* state, ExprContext* context, + FunctionContext::FunctionStateScope scope) override; + virtual void close(RuntimeState* state, ExprContext* context, + FunctionContext::FunctionStateScope scope) override; + virtual Expr* clone(ObjectPool* pool) const override { return pool->add(new RPCFnCall(*this)); } + + virtual doris_udf::BooleanVal get_boolean_val(ExprContext* context, TupleRow*) override; + virtual doris_udf::TinyIntVal get_tiny_int_val(ExprContext* context, TupleRow*) override; + virtual doris_udf::SmallIntVal get_small_int_val(ExprContext* context, TupleRow*) override; + virtual doris_udf::IntVal get_int_val(ExprContext* context, TupleRow*) override; + virtual doris_udf::BigIntVal get_big_int_val(ExprContext* context, TupleRow*) override; + virtual doris_udf::LargeIntVal get_large_int_val(ExprContext* context, TupleRow*) override; + virtual doris_udf::FloatVal get_float_val(ExprContext* context, TupleRow*) override; + virtual doris_udf::DoubleVal get_double_val(ExprContext* context, TupleRow*) override; + virtual doris_udf::StringVal get_string_val(ExprContext* context, TupleRow*) override; + virtual doris_udf::DateTimeVal get_datetime_val(ExprContext* context, TupleRow*) override; + virtual doris_udf::DecimalV2Val get_decimalv2_val(ExprContext* context, TupleRow*) override; + virtual doris_udf::CollectionVal get_array_val(ExprContext* context, TupleRow*) override; + +private: + Status _eval_children(ExprContext* context, TupleRow* row, PFunctionCallResponse* response); + template <typename RETURN_TYPE> + RETURN_TYPE interpret_eval(ExprContext* context, TupleRow* row); + + std::shared_ptr<PFunctionService_Stub> _client = nullptr; + int _fn_context_index; + std::string _rpc_function_symbol; +}; +}; // namespace doris \ No newline at end of file diff --git a/be/src/exprs/runtime_filter_rpc.cpp b/be/src/exprs/runtime_filter_rpc.cpp index c20779d..764dcf9 100644 --- a/be/src/exprs/runtime_filter_rpc.cpp +++ b/be/src/exprs/runtime_filter_rpc.cpp @@ -25,7 +25,7 @@ #include "gen_cpp/PlanNodes_types.h" #include "gen_cpp/internal_service.pb.h" #include "service/brpc.h" -#include "util/brpc_stub_cache.h" +#include "util/brpc_client_cache.h" namespace doris { @@ -40,7 +40,7 @@ Status IRuntimeFilter::push_to_remote(RuntimeState* state, const TNetworkAddress DCHECK(is_producer()); DCHECK(_rpc_context == nullptr); std::shared_ptr<PBackendService_Stub> stub( - state->exec_env()->brpc_stub_cache()->get_stub(*addr)); + state->exec_env()->brpc_internal_client_cache()->get_client(*addr)); if (!stub) { std::string msg = fmt::format("Get rpc stub failed, host={}, port=", addr->hostname, addr->port); @@ -94,7 +94,8 @@ Status IRuntimeFilter::join_rpc() { if (_rpc_context->cntl.Failed()) { LOG(WARNING) << "runtimefilter rpc err:" << _rpc_context->cntl.ErrorText(); // reset stub cache - ExecEnv::GetInstance()->brpc_stub_cache()->erase(_rpc_context->cntl.remote_side()); + ExecEnv::GetInstance()->brpc_internal_client_cache()->erase( + _rpc_context->cntl.remote_side()); } } return Status::OK(); diff --git a/be/src/gen_cpp/CMakeLists.txt b/be/src/gen_cpp/CMakeLists.txt index cc6d52b..22aa8c9 100644 --- a/be/src/gen_cpp/CMakeLists.txt +++ b/be/src/gen_cpp/CMakeLists.txt @@ -84,8 +84,8 @@ set(SRC_FILES ${GEN_CPP_DIR}/data.pb.cc ${GEN_CPP_DIR}/descriptors.pb.cc ${GEN_CPP_DIR}/internal_service.pb.cc + ${GEN_CPP_DIR}/function_service.pb.cc ${GEN_CPP_DIR}/types.pb.cc - ${GEN_CPP_DIR}/status.pb.cc ${GEN_CPP_DIR}/segment_v2.pb.cc #$${GEN_CPP_DIR}/opcode/functions.cc #$${GEN_CPP_DIR}/opcode/vector-functions.cc diff --git a/be/src/http/action/check_rpc_channel_action.cpp b/be/src/http/action/check_rpc_channel_action.cpp index a26031f..6a688e8 100644 --- a/be/src/http/action/check_rpc_channel_action.cpp +++ b/be/src/http/action/check_rpc_channel_action.cpp @@ -24,7 +24,7 @@ #include "http/http_request.h" #include "runtime/exec_env.h" #include "service/brpc.h" -#include "util/brpc_stub_cache.h" +#include "util/brpc_client_cache.h" #include "util/md5.h" namespace doris { @@ -71,7 +71,7 @@ void CheckRPCChannelAction::handle(HttpRequest* req) { digest.digest(); request.set_md5(digest.hex()); std::shared_ptr<PBackendService_Stub> stub( - _exec_env->brpc_stub_cache()->get_stub(req_ip, port)); + _exec_env->brpc_internal_client_cache()->get_client(req_ip, port)); if (!stub) { HttpChannel::send_reply( req, HttpStatus::INTERNAL_SERVER_ERROR, diff --git a/be/src/http/action/reset_rpc_channel_action.cpp b/be/src/http/action/reset_rpc_channel_action.cpp index 38e4a7e..242bfe7 100644 --- a/be/src/http/action/reset_rpc_channel_action.cpp +++ b/be/src/http/action/reset_rpc_channel_action.cpp @@ -22,7 +22,7 @@ #include "http/http_channel.h" #include "http/http_request.h" #include "runtime/exec_env.h" -#include "util/brpc_stub_cache.h" +#include "util/brpc_client_cache.h" #include "util/string_util.h" namespace doris { @@ -30,11 +30,11 @@ ResetRPCChannelAction::ResetRPCChannelAction(ExecEnv* exec_env) : _exec_env(exec void ResetRPCChannelAction::handle(HttpRequest* req) { std::string endpoints = req->param("endpoints"); if (iequal(endpoints, "all")) { - int size = _exec_env->brpc_stub_cache()->size(); + int size = _exec_env->brpc_internal_client_cache()->size(); if (size > 0) { std::vector<std::string> endpoints; - _exec_env->brpc_stub_cache()->get_all(&endpoints); - _exec_env->brpc_stub_cache()->clear(); + _exec_env->brpc_internal_client_cache()->get_all(&endpoints); + _exec_env->brpc_internal_client_cache()->clear(); HttpChannel::send_reply(req, HttpStatus::OK, fmt::format("reseted: {0}", join(endpoints, ","))); return; @@ -45,14 +45,14 @@ void ResetRPCChannelAction::handle(HttpRequest* req) { } else { std::vector<std::string> reseted; for (const std::string& endpoint : split(endpoints, ",")) { - if (!_exec_env->brpc_stub_cache()->exist(endpoint)) { + if (!_exec_env->brpc_internal_client_cache()->exist(endpoint)) { std::string err = fmt::format("{0}: not found.", endpoint); LOG(WARNING) << err; HttpChannel::send_reply(req, HttpStatus::INTERNAL_SERVER_ERROR, err); return; } - if (_exec_env->brpc_stub_cache()->erase(endpoint)) { + if (_exec_env->brpc_internal_client_cache()->erase(endpoint)) { reseted.push_back(endpoint); } else { std::string err = fmt::format("{0}: reset failed.", endpoint); diff --git a/be/src/plugin/plugin_loader.cpp b/be/src/plugin/plugin_loader.cpp index 1e2876d..a0d0674 100644 --- a/be/src/plugin/plugin_loader.cpp +++ b/be/src/plugin/plugin_loader.cpp @@ -58,13 +58,13 @@ Status DynamicPluginLoader::install() { // no, need download zip install PluginZip zip(_source); - RETURN_IF_ERROR(zip.extract(_install_path, _name)); + RETURN_NOT_OK_STATUS_WITH_WARN(zip.extract(_install_path, _name), "plugin install failed"); } // open plugin - RETURN_IF_ERROR(open_plugin()); + RETURN_NOT_OK_STATUS_WITH_WARN(open_plugin(), "plugin install failed"); - RETURN_IF_ERROR(open_valid()); + RETURN_NOT_OK_STATUS_WITH_WARN(open_valid(), "plugin install failed"); // plugin init // todo: what should be send? diff --git a/be/src/runtime/data_stream_sender.cpp b/be/src/runtime/data_stream_sender.cpp index 99a08b0..681f5fc 100644 --- a/be/src/runtime/data_stream_sender.cpp +++ b/be/src/runtime/data_stream_sender.cpp @@ -42,7 +42,7 @@ #include "runtime/tuple_row.h" #include "service/backend_options.h" #include "service/brpc.h" -#include "util/brpc_stub_cache.h" +#include "util/brpc_client_cache.h" #include "util/debug_util.h" #include "util/defer_op.h" #include "util/network_util.h" @@ -112,7 +112,7 @@ Status DataStreamSender::Channel::init(RuntimeState* state) { // so the empty channel not need call function close_internal() _need_close = (_fragment_instance_id.hi != -1 && _fragment_instance_id.lo != -1); if (_need_close) { - _brpc_stub = state->exec_env()->brpc_stub_cache()->get_stub(_brpc_dest_addr); + _brpc_stub = state->exec_env()->brpc_internal_client_cache()->get_client(_brpc_dest_addr); if (!_brpc_stub) { std::string msg = fmt::format("Get rpc stub failed, dest_addr={}:{}", _brpc_dest_addr.hostname, _brpc_dest_addr.port); diff --git a/be/src/runtime/exec_env.h b/be/src/runtime/exec_env.h index 0b51e47..8c8a9fb 100644 --- a/be/src/runtime/exec_env.h +++ b/be/src/runtime/exec_env.h @@ -28,7 +28,10 @@ class VDataStreamMgr; } class BfdParser; class BrokerMgr; -class BrpcStubCache; + +template <class T> +class BrpcClientCache; + class BufferPool; class CgroupsMgr; class DataStreamMgr; @@ -61,8 +64,12 @@ class BackendServiceClient; class FrontendServiceClient; class TPaloBrokerServiceClient; class TExtDataSourceServiceClient; +class PBackendService_Stub; +class PFunctionService_Stub; + template <class T> class ClientCache; + class HeartbeatFlags; // Execution environment for queries/plan fragments. @@ -126,7 +133,12 @@ public: TmpFileMgr* tmp_file_mgr() { return _tmp_file_mgr; } BfdParser* bfd_parser() const { return _bfd_parser; } BrokerMgr* broker_mgr() const { return _broker_mgr; } - BrpcStubCache* brpc_stub_cache() const { return _brpc_stub_cache; } + BrpcClientCache<PBackendService_Stub>* brpc_internal_client_cache() const { + return _internal_client_cache; + } + BrpcClientCache<PFunctionService_Stub>* brpc_function_client_cache() const { + return _function_client_cache; + } ReservationTracker* buffer_reservation() { return _buffer_reservation; } BufferPool* buffer_pool() { return _buffer_pool; } LoadChannelMgr* load_channel_mgr() { return _load_channel_mgr; } @@ -180,7 +192,7 @@ private: // Scanner threads for common queries will use this thread pool, // and the priority of each scan task is set according to the size of the query. - // _limited_scan_thread_pool is also the thread pool used for scanner. + // _limited_scan_thread_pool is also the thread pool used for scanner. // The difference is that it is no longer a priority queue, but according to the concurrency // set by the user to control the number of threads that can be used by a query. @@ -203,7 +215,8 @@ private: BrokerMgr* _broker_mgr = nullptr; LoadChannelMgr* _load_channel_mgr = nullptr; LoadStreamMgr* _load_stream_mgr = nullptr; - BrpcStubCache* _brpc_stub_cache = nullptr; + BrpcClientCache<PBackendService_Stub>* _internal_client_cache = nullptr; + BrpcClientCache<PFunctionService_Stub>* _function_client_cache = nullptr; ReservationTracker* _buffer_reservation = nullptr; BufferPool* _buffer_pool = nullptr; diff --git a/be/src/runtime/exec_env_init.cpp b/be/src/runtime/exec_env_init.cpp index 35630d0..128f52e 100644 --- a/be/src/runtime/exec_env_init.cpp +++ b/be/src/runtime/exec_env_init.cpp @@ -54,7 +54,7 @@ #include "runtime/thread_resource_mgr.h" #include "runtime/tmp_file_mgr.h" #include "util/bfd_parser.h" -#include "util/brpc_stub_cache.h" +#include "util/brpc_client_cache.h" #include "util/debug_util.h" #include "util/doris_metrics.h" #include "util/mem_info.h" @@ -125,7 +125,8 @@ Status ExecEnv::_init(const std::vector<StorePath>& store_paths) { _broker_mgr = new BrokerMgr(this); _load_channel_mgr = new LoadChannelMgr(); _load_stream_mgr = new LoadStreamMgr(); - _brpc_stub_cache = new BrpcStubCache(); + _internal_client_cache = new BrpcClientCache<PBackendService_Stub>(); + _function_client_cache = new BrpcClientCache<PFunctionService_Stub>(); _stream_load_executor = new StreamLoadExecutor(this); _routine_load_task_executor = new RoutineLoadTaskExecutor(this); _small_file_mgr = new SmallFileMgr(this, config::small_file_dir); @@ -285,7 +286,8 @@ void ExecEnv::_destroy() { return; } _deregister_metrics(); - SAFE_DELETE(_brpc_stub_cache); + SAFE_DELETE(_internal_client_cache); + SAFE_DELETE(_function_client_cache); SAFE_DELETE(_load_stream_mgr); SAFE_DELETE(_load_channel_mgr); SAFE_DELETE(_broker_mgr); diff --git a/be/src/runtime/runtime_filter_mgr.cpp b/be/src/runtime/runtime_filter_mgr.cpp index e7b3a0c..b5302ae 100644 --- a/be/src/runtime/runtime_filter_mgr.cpp +++ b/be/src/runtime/runtime_filter_mgr.cpp @@ -28,7 +28,7 @@ #include "runtime/runtime_filter_mgr.h" #include "runtime/runtime_state.h" #include "service/brpc.h" -#include "util/brpc_stub_cache.h" +#include "util/brpc_client_cache.h" #include "util/time.h" namespace doris { @@ -251,7 +251,7 @@ Status RuntimeFilterMergeControllerEntity::merge(const PMergeFilterRequest* requ request_fragment_id->set_lo(targets[i].target_fragment_instance_id.lo); std::shared_ptr<PBackendService_Stub> stub( - ExecEnv::GetInstance()->brpc_stub_cache()->get_stub( + ExecEnv::GetInstance()->brpc_internal_client_cache()->get_client( targets[i].target_fragment_instance_addr)); VLOG_NOTICE << "send filter " << rpc_contexts[i]->request.filter_id() << " to:" << targets[i].target_fragment_instance_addr.hostname << ":" diff --git a/be/src/service/internal_service.cpp b/be/src/service/internal_service.cpp index a948db8..7cf7b28 100644 --- a/be/src/service/internal_service.cpp +++ b/be/src/service/internal_service.cpp @@ -30,7 +30,7 @@ #include "runtime/routine_load/routine_load_task_executor.h" #include "runtime/runtime_state.h" #include "service/brpc.h" -#include "util/brpc_stub_cache.h" +#include "util/brpc_client_cache.h" #include "util/md5.h" #include "util/proto_util.h" #include "util/string_util.h" @@ -476,21 +476,21 @@ void PInternalServiceImpl<T>::reset_rpc_channel(google::protobuf::RpcController* brpc::ClosureGuard closure_guard(done); response->mutable_status()->set_status_code(0); if (request->all()) { - int size = ExecEnv::GetInstance()->brpc_stub_cache()->size(); + int size = ExecEnv::GetInstance()->brpc_internal_client_cache()->size(); if (size > 0) { std::vector<std::string> endpoints; - ExecEnv::GetInstance()->brpc_stub_cache()->get_all(&endpoints); - ExecEnv::GetInstance()->brpc_stub_cache()->clear(); + ExecEnv::GetInstance()->brpc_internal_client_cache()->get_all(&endpoints); + ExecEnv::GetInstance()->brpc_internal_client_cache()->clear(); *response->mutable_channels() = {endpoints.begin(), endpoints.end()}; } } else { for (const std::string& endpoint : request->endpoints()) { - if (!ExecEnv::GetInstance()->brpc_stub_cache()->exist(endpoint)) { + if (!ExecEnv::GetInstance()->brpc_internal_client_cache()->exist(endpoint)) { response->mutable_status()->add_error_msgs(endpoint + ": not found."); continue; } - if (ExecEnv::GetInstance()->brpc_stub_cache()->erase(endpoint)) { + if (ExecEnv::GetInstance()->brpc_internal_client_cache()->erase(endpoint)) { response->add_channels(endpoint); } else { response->mutable_status()->add_error_msgs(endpoint + ": reset failed."); diff --git a/be/src/udf/udf.cpp b/be/src/udf/udf.cpp index ce88fd6..b9ec504 100644 --- a/be/src/udf/udf.cpp +++ b/be/src/udf/udf.cpp @@ -23,6 +23,7 @@ #include <sstream> #include "common/logging.h" +#include "gen_cpp/types.pb.h" #include "olap/hll.h" #include "runtime/decimalv2_value.h" @@ -196,6 +197,20 @@ FunctionContext* FunctionContextImpl::clone(MemPool* pool) { return new_context; } +// TODO: to be implemented +void FunctionContextImpl::serialize(PFunctionContext* pcontext) const { + // pcontext->set_string_result(_string_result); + // pcontext->set_num_updates(_num_updates); + // pcontext->set_num_removes(_num_removes); + // pcontext->set_num_warnings(_num_warnings); + // pcontext->set_error_msg(_error_msg); + // PUniqueId* query_id = pcontext->mutable_query_id(); + // query_id->set_hi(_context->query_id().hi); + // query_id->set_lo(_context->query_id().lo); +} + +void FunctionContextImpl::derialize(const PFunctionContext& pcontext) {} + } // namespace doris namespace doris_udf { diff --git a/be/src/udf/udf_internal.h b/be/src/udf/udf_internal.h index 085002d..36cf8ad 100644 --- a/be/src/udf/udf_internal.h +++ b/be/src/udf/udf_internal.h @@ -33,6 +33,7 @@ class FreePool; class MemPool; class RuntimeState; class ColumnPtrWrapper; +class PFunctionContext; // This class actually implements the interface of FunctionContext. This is split to // hide the details from the external header. @@ -107,6 +108,9 @@ public: const doris_udf::FunctionContext::TypeDesc& get_return_type() const { return _return_type; } + void serialize(PFunctionContext* pcontext) const; + void derialize(const PFunctionContext& pcontext); + private: friend class doris_udf::FunctionContext; friend class ExprContext; diff --git a/be/src/util/CMakeLists.txt b/be/src/util/CMakeLists.txt index a5f2448..0582c57 100644 --- a/be/src/util/CMakeLists.txt +++ b/be/src/util/CMakeLists.txt @@ -100,7 +100,7 @@ set(UTIL_FILES timezone_utils.cpp easy_json.cc mustache/mustache.cc - brpc_stub_cache.cpp + brpc_client_cache.cpp zlib.cpp pprof_utils.cpp s3_uri.cpp diff --git a/be/src/util/brpc_stub_cache.cpp b/be/src/util/brpc_client_cache.cpp similarity index 64% rename from be/src/util/brpc_stub_cache.cpp rename to be/src/util/brpc_client_cache.cpp index b62f34a..df89585 100644 --- a/be/src/util/brpc_stub_cache.cpp +++ b/be/src/util/brpc_client_cache.cpp @@ -15,17 +15,31 @@ // specific language governing permissions and limitations // under the License. -#include "util/brpc_stub_cache.h" +#include "util/brpc_client_cache.h" namespace doris { DEFINE_GAUGE_METRIC_PROTOTYPE_2ARG(brpc_endpoint_stub_count, MetricUnit::NOUNIT); -BrpcStubCache::BrpcStubCache() { +DEFINE_GAUGE_METRIC_PROTOTYPE_2ARG(brpc_function_endpoint_stub_count, MetricUnit::NOUNIT); + +template <> +BrpcClientCache<PBackendService_Stub>::BrpcClientCache() { REGISTER_HOOK_METRIC(brpc_endpoint_stub_count, [this]() { return _stub_map.size(); }); } -BrpcStubCache::~BrpcStubCache() { +template <> +BrpcClientCache<PBackendService_Stub>::~BrpcClientCache() { DEREGISTER_HOOK_METRIC(brpc_endpoint_stub_count); } + +template <> +BrpcClientCache<PFunctionService_Stub>::BrpcClientCache() { + REGISTER_HOOK_METRIC(brpc_function_endpoint_stub_count, [this]() { return _stub_map.size(); }); +} + +template <> +BrpcClientCache<PFunctionService_Stub>::~BrpcClientCache() { + DEREGISTER_HOOK_METRIC(brpc_function_endpoint_stub_count); +} } // namespace doris diff --git a/be/src/util/brpc_client_cache.h b/be/src/util/brpc_client_cache.h new file mode 100644 index 0000000..f310cd1 --- /dev/null +++ b/be/src/util/brpc_client_cache.h @@ -0,0 +1,150 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <parallel_hashmap/phmap.h> + +#include <memory> +#include <mutex> + +#include "common/config.h" +#include "gen_cpp/Types_types.h" // TNetworkAddress +#include "gen_cpp/function_service.pb.h" +#include "gen_cpp/internal_service.pb.h" +#include "service/brpc.h" +#include "util/doris_metrics.h" + +template <typename T> +using SubMap = phmap::parallel_flat_hash_map< + std::string, std::shared_ptr<T>, std::hash<std::string>, std::equal_to<std::string>, + std::allocator<std::pair<const std::string, std::shared_ptr<T>>>, 8, std::mutex>; +namespace doris { + +template <class T> +class BrpcClientCache { +public: + BrpcClientCache(); + virtual ~BrpcClientCache(); + + inline std::shared_ptr<T> get_client(const butil::EndPoint& endpoint) { + return get_client(butil::endpoint2str(endpoint).c_str()); + } + +#ifdef BE_TEST + virtual inline std::shared_ptr<T> get_client(const TNetworkAddress& taddr) { + std::string host_port = fmt::format("{}:{}", taddr.hostname, taddr.port); + return get_client(host_port); + } +#else + inline std::shared_ptr<T> get_client(const TNetworkAddress& taddr) { + std::string host_port = fmt::format("{}:{}", taddr.hostname, taddr.port); + return get_client(host_port); + } +#endif + + inline std::shared_ptr<T> get_client(const std::string& host, int port) { + std::string host_port = fmt::format("{}:{}", host, port); + return get_client(host_port); + } + + inline std::shared_ptr<T> get_client(const std::string& host_port) { + auto stub_ptr = _stub_map.find(host_port); + if (LIKELY(stub_ptr != _stub_map.end())) { + return stub_ptr->second; + } + // new one stub and insert into map + brpc::ChannelOptions options; + if constexpr (std::is_same_v<T, PFunctionService_Stub>) { + options.protocol = config::function_service_protocol; + } + std::unique_ptr<brpc::Channel> channel(new brpc::Channel()); + int ret_code = 0; + if (host_port.find("://") == std::string::npos) { + ret_code = channel->Init(host_port.c_str(), &options); + } else { + ret_code = + channel->Init(host_port.c_str(), config::rpc_load_balancer.c_str(), &options); + } + if (ret_code) { + return nullptr; + } + auto stub = std::make_shared<T>(channel.release(), + google::protobuf::Service::STUB_OWNS_CHANNEL); + _stub_map[host_port] = stub; + return stub; + } + + inline size_t size() { return _stub_map.size(); } + + inline void clear() { _stub_map.clear(); } + + inline size_t erase(const std::string& host_port) { return _stub_map.erase(host_port); } + + size_t erase(const std::string& host, int port) { + std::string host_port = fmt::format("{}:{}", host, port); + return erase(host_port); + } + + inline size_t erase(const butil::EndPoint& endpoint) { + return _stub_map.erase(butil::endpoint2str(endpoint).c_str()); + } + + inline bool exist(const std::string& host_port) { + return _stub_map.find(host_port) != _stub_map.end(); + } + + inline void get_all(std::vector<std::string>* endpoints) { + for (auto it = _stub_map.begin(); it != _stub_map.end(); ++it) { + endpoints->emplace_back(it->first.c_str()); + } + } + + inline bool available(std::shared_ptr<T> stub, const butil::EndPoint& endpoint) { + return available(stub, butil::endpoint2str(endpoint).c_str()); + } + + inline bool available(std::shared_ptr<T> stub, const std::string& host_port) { + if (!stub) { + LOG(WARNING) << "stub is null to: " << host_port; + return false; + } + PHandShakeRequest request; + PHandShakeResponse response; + brpc::Controller cntl; + stub->hand_shake(&cntl, &request, &response, nullptr); + if (!cntl.Failed()) { + return true; + } else { + LOG(WARNING) << "open brpc connection to " << host_port + << " failed: " << cntl.ErrorText(); + return false; + } + } + + inline bool available(std::shared_ptr<T> stub, const std::string& host, int port) { + std::string host_port = fmt::format("{}:{}", host, port); + return available(stub, host_port); + } + +private: + SubMap<T> _stub_map; +}; + +using InternalServiceClientCache = BrpcClientCache<PBackendService_Stub>; +using FunctionServiceClientCache = BrpcClientCache<PFunctionService_Stub>; +} // namespace doris diff --git a/be/src/util/brpc_stub_cache.h b/be/src/util/brpc_stub_cache.h deleted file mode 100644 index 21800f3..0000000 --- a/be/src/util/brpc_stub_cache.h +++ /dev/null @@ -1,159 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#pragma once - -#include <parallel_hashmap/phmap.h> - -#include <memory> -#include <mutex> - -#include "common/config.h" -#include "gen_cpp/Types_types.h" // TNetworkAddress -#include "gen_cpp/internal_service.pb.h" -#include "service/brpc.h" -#include "util/doris_metrics.h" - -namespace std { -template <> -struct hash<butil::EndPoint> { - std::size_t operator()(butil::EndPoint const& p) const { - return phmap::HashState().combine(0, butil::ip2int(p.ip), p.port); - } -}; -} // namespace std -using SubMap = phmap::parallel_flat_hash_map< - butil::EndPoint, std::shared_ptr<doris::PBackendService_Stub>, std::hash<butil::EndPoint>, - std::equal_to<butil::EndPoint>, - std::allocator< - std::pair<const butil::EndPoint, std::shared_ptr<doris::PBackendService_Stub>>>, - 8, std::mutex>; -namespace doris { - -class BrpcStubCache { -public: - BrpcStubCache(); - virtual ~BrpcStubCache(); - - inline std::shared_ptr<PBackendService_Stub> get_stub(const butil::EndPoint& endpoint) { - auto stub_ptr = _stub_map.find(endpoint); - if (LIKELY(stub_ptr != _stub_map.end())) { - return stub_ptr->second; - } - // new one stub and insert into map - brpc::ChannelOptions options; - std::unique_ptr<brpc::Channel> channel(new brpc::Channel()); - if (channel->Init(endpoint, &options)) { - return nullptr; - } - auto stub = std::make_shared<PBackendService_Stub>( - channel.release(), google::protobuf::Service::STUB_OWNS_CHANNEL); - _stub_map[endpoint] = stub; - return stub; - } - - virtual std::shared_ptr<PBackendService_Stub> get_stub(const TNetworkAddress& taddr) { - butil::EndPoint endpoint; - if (str2endpoint(taddr.hostname.c_str(), taddr.port, &endpoint)) { - LOG(WARNING) << "unknown endpoint, hostname=" << taddr.hostname - << ", port=" << taddr.port; - return nullptr; - } - return get_stub(endpoint); - } - - inline std::shared_ptr<PBackendService_Stub> get_stub(const std::string& host, int port) { - butil::EndPoint endpoint; - if (str2endpoint(host.c_str(), port, &endpoint)) { - LOG(WARNING) << "unknown endpoint, hostname=" << host << ", port=" << port; - return nullptr; - } - return get_stub(endpoint); - } - - inline size_t size() { return _stub_map.size(); } - - inline void clear() { _stub_map.clear(); } - - inline size_t erase(const std::string& host_port) { - butil::EndPoint endpoint; - if (str2endpoint(host_port.c_str(), &endpoint)) { - LOG(WARNING) << "unknown endpoint: " << host_port; - return 0; - } - return erase(endpoint); - } - - size_t erase(const std::string& host, int port) { - butil::EndPoint endpoint; - if (str2endpoint(host.c_str(), port, &endpoint)) { - LOG(WARNING) << "unknown endpoint, hostname=" << host << ", port=" << port; - return 0; - } - return erase(endpoint); - } - - inline size_t erase(const butil::EndPoint& endpoint) { return _stub_map.erase(endpoint); } - - inline bool exist(const std::string& host_port) { - butil::EndPoint endpoint; - if (str2endpoint(host_port.c_str(), &endpoint)) { - LOG(WARNING) << "unknown endpoint: " << host_port; - return false; - } - return _stub_map.find(endpoint) != _stub_map.end(); - } - - inline void get_all(std::vector<std::string>* endpoints) { - for (SubMap::const_iterator it = _stub_map.begin(); it != _stub_map.end(); ++it) { - endpoints->emplace_back(endpoint2str(it->first).c_str()); - } - } - - inline bool available(std::shared_ptr<PBackendService_Stub> stub, - const butil::EndPoint& endpoint) { - if (!stub) { - return false; - } - PHandShakeRequest request; - PHandShakeResponse response; - brpc::Controller cntl; - stub->hand_shake(&cntl, &request, &response, nullptr); - if (!cntl.Failed()) { - return true; - } else { - LOG(WARNING) << "open brpc connection to " << endpoint2str(endpoint).c_str() - << " failed: " << cntl.ErrorText(); - return false; - } - } - - inline bool available(std::shared_ptr<PBackendService_Stub> stub, const std::string& host, - int port) { - butil::EndPoint endpoint; - if (str2endpoint(host.c_str(), port, &endpoint)) { - LOG(WARNING) << "unknown endpoint, hostname=" << host; - return false; - } - return available(stub, endpoint); - } - -private: - SubMap _stub_map; -}; - -} // namespace doris diff --git a/be/src/util/doris_metrics.h b/be/src/util/doris_metrics.h index 67d60a3..8015dca 100644 --- a/be/src/util/doris_metrics.h +++ b/be/src/util/doris_metrics.h @@ -28,18 +28,19 @@ namespace doris { -#define REGISTER_ENTITY_HOOK_METRIC(entity, owner, metric, func) \ - owner->metric = (UIntGauge*)(entity->register_metric<UIntGauge>(&METRIC_##metric)); \ +#define REGISTER_ENTITY_HOOK_METRIC(entity, owner, metric, func) \ + owner->metric = (UIntGauge*)(entity->register_metric<UIntGauge>(&METRIC_##metric)); \ entity->register_hook(#metric, [&]() { owner->metric->set_value(func()); }); -#define REGISTER_HOOK_METRIC(metric, func) \ - REGISTER_ENTITY_HOOK_METRIC(DorisMetrics::instance()->server_entity(), DorisMetrics::instance(), metric, func) +#define REGISTER_HOOK_METRIC(metric, func) \ + REGISTER_ENTITY_HOOK_METRIC(DorisMetrics::instance()->server_entity(), \ + DorisMetrics::instance(), metric, func) -#define DEREGISTER_ENTITY_HOOK_METRIC(entity, name) \ - entity->deregister_metric(&METRIC_##name); \ +#define DEREGISTER_ENTITY_HOOK_METRIC(entity, name) \ + entity->deregister_metric(&METRIC_##name); \ entity->deregister_hook(#name); -#define DEREGISTER_HOOK_METRIC(name) \ +#define DEREGISTER_HOOK_METRIC(name) \ DEREGISTER_ENTITY_HOOK_METRIC(DorisMetrics::instance()->server_entity(), name) class DorisMetrics { @@ -177,6 +178,7 @@ public: UIntGauge* small_file_cache_count; UIntGauge* stream_load_pipe_count; UIntGauge* brpc_endpoint_stub_count; + UIntGauge* brpc_function_endpoint_stub_count; UIntGauge* tablet_writer_count; UIntGauge* compaction_mem_consumption; diff --git a/be/src/vec/CMakeLists.txt b/be/src/vec/CMakeLists.txt index 09e86e2..6201c67 100644 --- a/be/src/vec/CMakeLists.txt +++ b/be/src/vec/CMakeLists.txt @@ -141,6 +141,7 @@ set(VEC_FILES functions/function_date_or_datetime_to_string.cpp functions/function_datetime_string_to_string.cpp functions/function_grouping.cpp + functions/function_rpc.cpp olap/vgeneric_iterators.cpp olap/vcollect_iterator.cpp olap/block_reader.cpp diff --git a/be/src/vec/columns/column_decimal.h b/be/src/vec/columns/column_decimal.h index 017d891..b4b4a68 100644 --- a/be/src/vec/columns/column_decimal.h +++ b/be/src/vec/columns/column_decimal.h @@ -25,8 +25,8 @@ #include "vec/columns/column.h" #include "vec/columns/column_impl.h" #include "vec/columns/column_vector_helper.h" -#include "vec/common/typeid_cast.h" #include "vec/common/assert_cast.h" +#include "vec/common/typeid_cast.h" #include "vec/core/field.h" namespace doris::vectorized { @@ -97,7 +97,8 @@ public: data.push_back(static_cast<const Self&>(src).get_data()[n]); } - void insert_indices_from(const IColumn& src, const int* indices_begin, const int* indices_end) override { + void insert_indices_from(const IColumn& src, const int* indices_begin, + const int* indices_end) override { const Self& src_vec = assert_cast<const Self&>(src); data.reserve(size() + (indices_end - indices_begin)); for (auto x = indices_begin; x != indices_end; ++x) { @@ -226,4 +227,8 @@ ColumnPtr ColumnDecimal<T>::index_impl(const PaddedPODArray<Type>& indexes, size return res; } +using ColumnDecimal32 = ColumnDecimal<Decimal32>; +using ColumnDecimal64 = ColumnDecimal<Decimal64>; +using ColumnDecimal128 = ColumnDecimal<Decimal128>; + } // namespace doris::vectorized diff --git a/be/src/vec/exprs/vectorized_fn_call.cpp b/be/src/vec/exprs/vectorized_fn_call.cpp index deecc16..6f01a12 100644 --- a/be/src/vec/exprs/vectorized_fn_call.cpp +++ b/be/src/vec/exprs/vectorized_fn_call.cpp @@ -25,6 +25,7 @@ #include "udf/udf_internal.h" #include "vec/data_types/data_type_nullable.h" #include "vec/data_types/data_type_number.h" +#include "vec/functions/function_rpc.h" #include "vec/functions/simple_function_factory.h" namespace doris::vectorized { @@ -42,8 +43,13 @@ doris::Status VectorizedFnCall::prepare(doris::RuntimeState* state, argument_template.emplace_back(std::move(column), child->data_type(), child->expr_name()); child_expr_name.emplace_back(child->expr_name()); } - _function = SimpleFunctionFactory::instance().get_function(_fn.name.function_name, - argument_template, _data_type); + if (_fn.binary_type == TFunctionBinaryType::RPC) { + _function = RPCFnCall::create(_fn.name.function_name, _fn.hdfs_location, argument_template, + _data_type); + } else { + _function = SimpleFunctionFactory::instance().get_function(_fn.name.function_name, + argument_template, _data_type); + } if (_function == nullptr) { return Status::InternalError( fmt::format("Function {} is not implemented", _fn.name.function_name)); diff --git a/be/src/vec/functions/function_rpc.cpp b/be/src/vec/functions/function_rpc.cpp new file mode 100644 index 0000000..43d5a69 --- /dev/null +++ b/be/src/vec/functions/function_rpc.cpp @@ -0,0 +1,527 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "vec/functions/function_rpc.h" + +#include <fmt/format.h> + +#include <memory> + +#include "gen_cpp/function_service.pb.h" +#include "runtime/exec_env.h" +#include "runtime/user_function_cache.h" +#include "service/brpc.h" +#include "util/brpc_client_cache.h" +#include "vec/columns/column_vector.h" +#include "vec/core/block.h" +#include "vec/data_types/data_type_bitmap.h" +#include "vec/data_types/data_type_date.h" +#include "vec/data_types/data_type_date_time.h" +#include "vec/data_types/data_type_decimal.h" +#include "vec/data_types/data_type_nullable.h" +#include "vec/data_types/data_type_number.h" +#include "vec/data_types/data_type_string.h" + +namespace doris::vectorized { +RPCFnCall::RPCFnCall(const std::string& symbol, const std::string& server, + const DataTypes& argument_types, const DataTypePtr& return_type) + : _symbol(symbol), + _server(server), + _name(fmt::format("{}/{}", server, symbol)), + _argument_types(argument_types), + _return_type(return_type) {} +Status RPCFnCall::prepare(FunctionContext* context, FunctionContext::FunctionStateScope scope) { + _client = ExecEnv::GetInstance()->brpc_function_client_cache()->get_client(_server); + + if (_client == nullptr) { + return Status::InternalError("rpc env init error"); + } + return Status::OK(); +} + +template <bool nullable> +void convert_col_to_pvalue(const ColumnPtr& column, const DataTypePtr& data_type, PValues* arg, + size_t row_count) { + PGenericType* ptype = arg->mutable_type(); + switch (data_type->get_type_id()) { + case TypeIndex::UInt8: { + ptype->set_id(PGenericType::UINT8); + auto* values = arg->mutable_bool_value(); + values->Reserve(row_count); + const auto* col = check_and_get_column<ColumnUInt8>(column); + auto& data = col->get_data(); + values->Add(data.begin(), data.begin() + row_count); + break; + } + case TypeIndex::UInt16: { + ptype->set_id(PGenericType::UINT16); + auto* values = arg->mutable_uint32_value(); + values->Reserve(row_count); + const auto* col = check_and_get_column<ColumnUInt16>(column); + auto& data = col->get_data(); + values->Add(data.begin(), data.begin() + row_count); + break; + } + case TypeIndex::UInt32: { + ptype->set_id(PGenericType::UINT32); + auto* values = arg->mutable_uint32_value(); + values->Reserve(row_count); + const auto* col = check_and_get_column<ColumnUInt32>(column); + auto& data = col->get_data(); + values->Add(data.begin(), data.begin() + row_count); + break; + } + case TypeIndex::UInt64: { + ptype->set_id(PGenericType::UINT64); + auto* values = arg->mutable_uint64_value(); + values->Reserve(row_count); + const auto* col = check_and_get_column<ColumnUInt64>(column); + auto& data = col->get_data(); + values->Add(data.begin(), data.begin() + row_count); + break; + } + case TypeIndex::UInt128: { + ptype->set_id(PGenericType::UINT128); + arg->mutable_bytes_value()->Reserve(row_count); + for (size_t row_num = 0; row_num < row_count; ++row_num) { + if constexpr (nullable) { + if (column->is_null_at(row_num)) { + arg->add_bytes_value(nullptr); + } else { + StringRef data = column->get_data_at(row_num); + arg->add_bytes_value(data.data, data.size); + } + } else { + StringRef data = column->get_data_at(row_num); + arg->add_bytes_value(data.data, data.size); + } + } + break; + } + case TypeIndex::Int8: { + ptype->set_id(PGenericType::INT8); + auto* values = arg->mutable_int32_value(); + values->Reserve(row_count); + const auto* col = check_and_get_column<ColumnInt8>(column); + auto& data = col->get_data(); + values->Add(data.begin(), data.begin() + row_count); + break; + } + case TypeIndex::Int16: { + ptype->set_id(PGenericType::INT16); + auto* values = arg->mutable_int32_value(); + values->Reserve(row_count); + const auto* col = check_and_get_column<ColumnInt16>(column); + auto& data = col->get_data(); + values->Add(data.begin(), data.begin() + row_count); + break; + } + case TypeIndex::Int32: { + ptype->set_id(PGenericType::INT32); + auto* values = arg->mutable_int32_value(); + values->Reserve(row_count); + const auto* col = check_and_get_column<ColumnInt32>(column); + auto& data = col->get_data(); + values->Add(data.begin(), data.begin() + row_count); + break; + } + case TypeIndex::Int64: { + ptype->set_id(PGenericType::INT64); + auto* values = arg->mutable_int64_value(); + values->Reserve(row_count); + const auto* col = check_and_get_column<ColumnInt64>(column); + auto& data = col->get_data(); + values->Add(data.begin(), data.begin() + row_count); + break; + } + case TypeIndex::Int128: { + ptype->set_id(PGenericType::INT128); + arg->mutable_bytes_value()->Reserve(row_count); + for (size_t row_num = 0; row_num < row_count; ++row_num) { + if constexpr (nullable) { + if (column->is_null_at(row_num)) { + arg->add_bytes_value(nullptr); + } else { + StringRef data = column->get_data_at(row_num); + arg->add_bytes_value(data.data, data.size); + } + } else { + StringRef data = column->get_data_at(row_num); + arg->add_bytes_value(data.data, data.size); + } + } + break; + } + case TypeIndex::Float32: { + ptype->set_id(PGenericType::FLOAT); + auto* values = arg->mutable_float_value(); + values->Reserve(row_count); + const auto* col = check_and_get_column<ColumnFloat32>(column); + auto& data = col->get_data(); + values->Add(data.begin(), data.begin() + row_count); + break; + } + + case TypeIndex::Float64: { + ptype->set_id(PGenericType::DOUBLE); + auto* values = arg->mutable_double_value(); + values->Reserve(row_count); + const auto* col = check_and_get_column<ColumnFloat64>(column); + auto& data = col->get_data(); + values->Add(data.begin(), data.begin() + row_count); + break; + } + case TypeIndex::Decimal128: { + ptype->set_id(PGenericType::DECIMAL128); + auto dec_type = std::reinterpret_pointer_cast<const DataTypeDecimal<Decimal128>>(data_type); + ptype->mutable_decimal_type()->set_precision(dec_type->get_precision()); + ptype->mutable_decimal_type()->set_scale(dec_type->get_scale()); + arg->mutable_bytes_value()->Reserve(row_count); + for (size_t row_num = 0; row_num < row_count; ++row_num) { + if constexpr (nullable) { + if (column->is_null_at(row_num)) { + arg->add_bytes_value(nullptr); + } else { + StringRef data = column->get_data_at(row_num); + arg->add_bytes_value(data.data, data.size); + } + } else { + StringRef data = column->get_data_at(row_num); + arg->add_bytes_value(data.data, data.size); + } + } + break; + } + case TypeIndex::String: { + ptype->set_id(PGenericType::STRING); + arg->mutable_bytes_value()->Reserve(row_count); + for (size_t row_num = 0; row_num < row_count; ++row_num) { + if constexpr (nullable) { + if (column->is_null_at(row_num)) { + arg->add_string_value(nullptr); + } else { + StringRef data = column->get_data_at(row_num); + arg->add_string_value(data.to_string()); + } + } else { + StringRef data = column->get_data_at(row_num); + arg->add_string_value(data.to_string()); + } + } + break; + } + case TypeIndex::Date: { + ptype->set_id(PGenericType::DATE); + arg->mutable_datetime_value()->Reserve(row_count); + for (size_t row_num = 0; row_num < row_count; ++row_num) { + PDateTime* date_time = arg->add_datetime_value(); + if constexpr (nullable) { + if (!column->is_null_at(row_num)) { + VecDateTimeValue v = VecDateTimeValue(column->get_int(row_num)); + date_time->set_day(v.day()); + date_time->set_month(v.month()); + date_time->set_year(v.year()); + } + } else { + VecDateTimeValue v = VecDateTimeValue(column->get_int(row_num)); + date_time->set_day(v.day()); + date_time->set_month(v.month()); + date_time->set_year(v.year()); + } + } + break; + } + case TypeIndex::DateTime: { + ptype->set_id(PGenericType::DATETIME); + arg->mutable_datetime_value()->Reserve(row_count); + for (size_t row_num = 0; row_num < row_count; ++row_num) { + PDateTime* date_time = arg->add_datetime_value(); + if constexpr (nullable) { + if (!column->is_null_at(row_num)) { + VecDateTimeValue v = VecDateTimeValue(column->get_int(row_num)); + date_time->set_day(v.day()); + date_time->set_month(v.month()); + date_time->set_year(v.year()); + date_time->set_hour(v.hour()); + date_time->set_minute(v.minute()); + date_time->set_second(v.second()); + } + } else { + VecDateTimeValue v = VecDateTimeValue(column->get_int(row_num)); + date_time->set_day(v.day()); + date_time->set_month(v.month()); + date_time->set_year(v.year()); + date_time->set_hour(v.hour()); + date_time->set_minute(v.minute()); + date_time->set_second(v.second()); + } + } + break; + } + case TypeIndex::BitMap: { + ptype->set_id(PGenericType::BITMAP); + arg->mutable_bytes_value()->Reserve(row_count); + for (size_t row_num = 0; row_num < row_count; ++row_num) { + if constexpr (nullable) { + if (column->is_null_at(row_num)) { + arg->add_bytes_value(nullptr); + } else { + StringRef data = column->get_data_at(row_num); + arg->add_bytes_value(data.data, data.size); + } + } else { + StringRef data = column->get_data_at(row_num); + arg->add_bytes_value(data.data, data.size); + } + } + break; + } + default: + LOG(INFO) << "unknown type: " << data_type->get_name(); + ptype->set_id(PGenericType::UNKNOWN); + break; + } +} + +void convert_nullable_col_to_pvalue(const ColumnPtr& column, const DataTypePtr& data_type, + const ColumnUInt8& null_col, PValues* arg, size_t row_count) { + if (column->has_null(row_count)) { + auto* null_map = arg->mutable_null_map(); + null_map->Reserve(row_count); + const auto* col = check_and_get_column<ColumnUInt8>(null_col); + auto& data = col->get_data(); + null_map->Add(data.begin(), data.begin() + row_count); + convert_col_to_pvalue<true>(column, data_type, arg, row_count); + } else { + convert_col_to_pvalue<false>(column, data_type, arg, row_count); + } +} + +void convert_block_to_proto(Block& block, const ColumnNumbers& arguments, size_t input_rows_count, + PFunctionCallRequest* request) { + size_t row_count = std::min(block.rows(), input_rows_count); + for (size_t col_idx : arguments) { + PValues* arg = request->add_args(); + ColumnWithTypeAndName& column = block.get_by_position(col_idx); + arg->set_has_null(column.column->has_null(row_count)); + auto col = column.column->convert_to_full_column_if_const(); + if (auto* nullable = check_and_get_column<const ColumnNullable>(*col)) { + auto data_col = nullable->get_nested_column_ptr(); + auto& null_col = nullable->get_null_map_column(); + auto data_type = std::reinterpret_pointer_cast<const DataTypeNullable>(column.type); + convert_nullable_col_to_pvalue(data_col->convert_to_full_column_if_const(), + data_type->get_nested_type(), null_col, arg, row_count); + } else { + convert_col_to_pvalue<false>(col, column.type, arg, row_count); + } + } +} + +template <bool nullable> +void convert_to_column(MutableColumnPtr& column, const PValues& result) { + switch (result.type().id()) { + case PGenericType::UINT8: { + column->reserve(result.uint32_value_size()); + column->resize(result.uint32_value_size()); + auto& data = reinterpret_cast<ColumnUInt8*>(column.get())->get_data(); + for (int i = 0; i < result.uint32_value_size(); ++i) { + data[i] = result.uint32_value(i); + } + break; + } + case PGenericType::UINT16: { + column->reserve(result.uint32_value_size()); + column->resize(result.uint32_value_size()); + auto& data = reinterpret_cast<ColumnUInt16*>(column.get())->get_data(); + for (int i = 0; i < result.uint32_value_size(); ++i) { + data[i] = result.uint32_value(i); + } + break; + } + case PGenericType::UINT32: { + column->reserve(result.uint32_value_size()); + column->resize(result.uint32_value_size()); + auto& data = reinterpret_cast<ColumnUInt32*>(column.get())->get_data(); + for (int i = 0; i < result.uint32_value_size(); ++i) { + data[i] = result.uint32_value(i); + } + break; + } + case PGenericType::UINT64: { + column->reserve(result.uint64_value_size()); + column->resize(result.uint64_value_size()); + auto& data = reinterpret_cast<ColumnUInt64*>(column.get())->get_data(); + for (int i = 0; i < result.uint64_value_size(); ++i) { + data[i] = result.uint64_value(i); + } + break; + } + case PGenericType::INT8: { + column->reserve(result.int32_value_size()); + column->resize(result.int32_value_size()); + auto& data = reinterpret_cast<ColumnInt16*>(column.get())->get_data(); + for (int i = 0; i < result.int32_value_size(); ++i) { + data[i] = result.int32_value(i); + } + break; + } + case PGenericType::INT16: { + column->reserve(result.int32_value_size()); + column->resize(result.int32_value_size()); + auto& data = reinterpret_cast<ColumnInt16*>(column.get())->get_data(); + for (int i = 0; i < result.int32_value_size(); ++i) { + data[i] = result.int32_value(i); + } + break; + } + case PGenericType::INT32: { + column->reserve(result.int32_value_size()); + column->resize(result.int32_value_size()); + auto& data = reinterpret_cast<ColumnInt32*>(column.get())->get_data(); + for (int i = 0; i < result.int32_value_size(); ++i) { + data[i] = result.int32_value(i); + } + break; + } + case PGenericType::INT64: { + column->reserve(result.int64_value_size()); + column->resize(result.int64_value_size()); + auto& data = reinterpret_cast<ColumnInt64*>(column.get())->get_data(); + for (int i = 0; i < result.int64_value_size(); ++i) { + data[i] = result.int64_value(i); + } + break; + } + case PGenericType::DATE: + case PGenericType::DATETIME: { + column->reserve(result.datetime_value_size()); + column->resize(result.datetime_value_size()); + auto& data = reinterpret_cast<ColumnInt64*>(column.get())->get_data(); + for (int i = 0; i < result.datetime_value_size(); ++i) { + VecDateTimeValue v; + PDateTime pv = result.datetime_value(i); + v.set_time(pv.year(), pv.month(), pv.day(), pv.hour(), pv.minute(), pv.minute()); + data[i] = binary_cast<VecDateTimeValue, Int64>(v); + } + break; + } + case PGenericType::FLOAT: { + column->reserve(result.float_value_size()); + column->resize(result.float_value_size()); + auto& data = reinterpret_cast<ColumnFloat32*>(column.get())->get_data(); + for (int i = 0; i < result.float_value_size(); ++i) { + data[i] = result.float_value(i); + } + break; + } + case PGenericType::DOUBLE: { + column->reserve(result.double_value_size()); + column->resize(result.double_value_size()); + auto& data = reinterpret_cast<ColumnFloat64*>(column.get())->get_data(); + for (int i = 0; i < result.double_value_size(); ++i) { + data[i] = result.double_value(i); + } + break; + } + case PGenericType::INT128: { + column->reserve(result.bytes_value_size()); + column->resize(result.bytes_value_size()); + auto& data = reinterpret_cast<ColumnInt128*>(column.get())->get_data(); + for (int i = 0; i < result.bytes_value_size(); ++i) { + data[i] = *(int128_t*)(result.bytes_value(i).c_str()); + } + break; + } + case PGenericType::STRING: { + column->reserve(result.string_value_size()); + for (int i = 0; i < result.string_value_size(); ++i) { + column->insert_data(result.string_value(i).c_str(), result.string_value(i).size()); + } + break; + } + case PGenericType::DECIMAL128: { + column->reserve(result.bytes_value_size()); + column->resize(result.bytes_value_size()); + auto& data = reinterpret_cast<ColumnDecimal128*>(column.get())->get_data(); + for (int i = 0; i < result.bytes_value_size(); ++i) { + data[i] = *(int128_t*)(result.bytes_value(i).c_str()); + } + break; + } + case PGenericType::BITMAP: { + column->reserve(result.bytes_value_size()); + for (int i = 0; i < result.bytes_value_size(); ++i) { + column->insert_data(result.bytes_value(i).c_str(), result.bytes_value(i).size()); + } + break; + } + default: { + LOG(WARNING) << "unknown PGenericType: " << result.type().DebugString(); + break; + } + } +} + +void convert_to_block(Block& block, const PValues& result, size_t pos) { + auto data_type = block.get_data_type(pos); + if (data_type->is_nullable()) { + auto null_type = std::reinterpret_pointer_cast<const DataTypeNullable>(data_type); + auto data_col = null_type->get_nested_type()->create_column(); + convert_to_column<true>(data_col, result); + auto null_col = ColumnUInt8::create(data_col->size(), 0); + auto& null_map_data = null_col->get_data(); + null_col->reserve(data_col->size()); + null_col->resize(data_col->size()); + if (result.has_null()) { + for (int i = 0; i < data_col->size(); ++i) { + null_map_data[i] = result.null_map(i); + } + } else { + for (int i = 0; i < data_col->size(); ++i) { + null_map_data[i] = false; + } + } + block.replace_by_position( + pos, std::move(ColumnNullable::create(std::move(data_col), std::move(null_col)))); + } else { + auto column = data_type->create_column(); + convert_to_column<false>(column, result); + block.replace_by_position(pos, std::move(column)); + } +} + +Status RPCFnCall::execute(FunctionContext* context, Block& block, const ColumnNumbers& arguments, + size_t result, size_t input_rows_count, bool dry_run) { + PFunctionCallRequest request; + PFunctionCallResponse response; + request.set_function_name(_symbol); + convert_block_to_proto(block, arguments, input_rows_count, &request); + brpc::Controller cntl; + _client->fn_call(&cntl, &request, &response, nullptr); + if (cntl.Failed()) { + return Status::InternalError( + fmt::format("call to rpc function {} failed: {}", _symbol, cntl.ErrorText()) + .c_str()); + } + if (response.status().status_code() != 0) { + return Status::InternalError(fmt::format("call to rpc function {} failed: {}", _symbol, + response.status().DebugString())); + } + convert_to_block(block, response.result(), result); + return Status::OK(); +} +} // namespace doris::vectorized diff --git a/be/src/vec/functions/function_rpc.h b/be/src/vec/functions/function_rpc.h new file mode 100644 index 0000000..2c7535a --- /dev/null +++ b/be/src/vec/functions/function_rpc.h @@ -0,0 +1,68 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "vec/functions/function.h" + +namespace doris { +class PFunctionService_Stub; + +namespace vectorized { +class RPCFnCall : public IFunctionBase { +public: + RPCFnCall(const std::string& symbol, const std::string& server, const DataTypes& argument_types, + const DataTypePtr& return_type); + static FunctionBasePtr create(const std::string& symbol, const std::string& server, + const ColumnsWithTypeAndName& argument_types, + const DataTypePtr& return_type) { + DataTypes data_types(argument_types.size()); + for (size_t i = 0; i < argument_types.size(); ++i) { + data_types[i] = argument_types[i].type; + } + return std::make_shared<RPCFnCall>(symbol, server, data_types, return_type); + } + + /// Get the main function name. + String get_name() const override { return _name; }; + + const DataTypes& get_argument_types() const override { return _argument_types; }; + const DataTypePtr& get_return_type() const override { return _return_type; }; + + PreparedFunctionPtr prepare(FunctionContext* context, const Block& sample_block, + const ColumnNumbers& arguments, size_t result) const override { + return nullptr; + } + + Status prepare(FunctionContext* context, FunctionContext::FunctionStateScope scope) override; + + Status execute(FunctionContext* context, Block& block, const ColumnNumbers& arguments, + size_t result, size_t input_rows_count, bool dry_run = false) override; + + bool is_deterministic() const override { return false; } + + bool is_deterministic_in_scope_of_query() const override { return false; } + +private: + std::string _symbol; + std::string _server; + std::string _name; + DataTypes _argument_types; + DataTypePtr _return_type; + std::shared_ptr<PFunctionService_Stub> _client = nullptr; +}; + +} // namespace vectorized +} // namespace doris diff --git a/be/src/vec/sink/vdata_stream_sender.cpp b/be/src/vec/sink/vdata_stream_sender.cpp index e3e081d..295891e 100644 --- a/be/src/vec/sink/vdata_stream_sender.cpp +++ b/be/src/vec/sink/vdata_stream_sender.cpp @@ -54,13 +54,13 @@ Status VDataStreamSender::Channel::init(RuntimeState* state) { _brpc_request.set_be_number(_be_number); _brpc_timeout_ms = std::min(3600, state->query_options().query_timeout) * 1000; - _brpc_stub = state->exec_env()->brpc_stub_cache()->get_stub(_brpc_dest_addr); + _brpc_stub = state->exec_env()->brpc_internal_client_cache()->get_client(_brpc_dest_addr); if (_brpc_dest_addr.hostname == BackendOptions::get_localhost()) { - _brpc_stub = - state->exec_env()->brpc_stub_cache()->get_stub("127.0.0.1", _brpc_dest_addr.port); + _brpc_stub = state->exec_env()->brpc_internal_client_cache()->get_client( + "127.0.0.1", _brpc_dest_addr.port); } else { - _brpc_stub = state->exec_env()->brpc_stub_cache()->get_stub(_brpc_dest_addr); + _brpc_stub = state->exec_env()->brpc_internal_client_cache()->get_client(_brpc_dest_addr); } // In bucket shuffle join will set fragment_instance_id (-1, -1) diff --git a/be/src/vec/sink/vdata_stream_sender.h b/be/src/vec/sink/vdata_stream_sender.h index 223bf28..6ed99bd 100644 --- a/be/src/vec/sink/vdata_stream_sender.h +++ b/be/src/vec/sink/vdata_stream_sender.h @@ -25,7 +25,7 @@ #include "runtime/descriptors.h" #include "service/backend_options.h" #include "service/brpc.h" -#include "util/brpc_stub_cache.h" +#include "util/brpc_client_cache.h" #include "util/network_util.h" #include "util/ref_count_closure.h" #include "util/uid_util.h" @@ -81,7 +81,8 @@ private: } template <typename Channels, typename HashVals> - Status channel_add_rows(Channels& channels, int num_channels, const HashVals& hash_vals, int rows, Block* block); + Status channel_add_rows(Channels& channels, int num_channels, const HashVals& hash_vals, + int rows, Block* block); struct hash_128 { uint64_t high; @@ -159,13 +160,14 @@ public: _brpc_dest_addr(brpc_dest), _is_transfer_chain(is_transfer_chain), _send_query_statistics_with_every_batch(send_query_statistics_with_every_batch) { - std::string localhost = BackendOptions::get_localhost(); - _is_local = (_brpc_dest_addr.hostname == localhost) && (_brpc_dest_addr.port == config::brpc_port); - if (_is_local) { - LOG(INFO) << "will use local Exchange, dest_node_id is : "<<_dest_node_id; - } - } - + std::string localhost = BackendOptions::get_localhost(); + _is_local = (_brpc_dest_addr.hostname == localhost) && + (_brpc_dest_addr.port == config::brpc_port); + if (_is_local) { + LOG(INFO) << "will use local Exchange, dest_node_id is : " << _dest_node_id; + } + } + virtual ~Channel() { if (_closure != nullptr && _closure->unref()) { delete _closure; @@ -235,7 +237,6 @@ private: return Status::OK(); } - private: // Serialize _batch into _thrift_batch and send via send_batch(). // Returns send_batch() status. @@ -276,7 +277,8 @@ private: }; template <typename Channels, typename HashVals> -Status VDataStreamSender::channel_add_rows(Channels& channels, int num_channels, const HashVals& hash_vals, int rows, Block* block) { +Status VDataStreamSender::channel_add_rows(Channels& channels, int num_channels, + const HashVals& hash_vals, int rows, Block* block) { std::vector<int> channel2rows[num_channels]; for (int i = 0; i < rows; i++) { diff --git a/be/test/exec/tablet_sink_test.cpp b/be/test/exec/tablet_sink_test.cpp index 44c5fbd..3d55699 100644 --- a/be/test/exec/tablet_sink_test.cpp +++ b/be/test/exec/tablet_sink_test.cpp @@ -34,7 +34,7 @@ #include "runtime/types.h" #include "runtime/tuple_row.h" #include "service/brpc.h" -#include "util/brpc_stub_cache.h" +#include "util/brpc_client_cache.h" #include "util/cpu_info.h" #include "util/debug/leakcheck_disabler.h" #include "util/proto_util.h" @@ -54,7 +54,8 @@ public: _env->_thread_mgr = new ThreadResourceMgr(); _env->_master_info = new TMasterInfo(); _env->_load_stream_mgr = new LoadStreamMgr(); - _env->_brpc_stub_cache = new BrpcStubCache(); + _env->_internal_client_cache = new BrpcClientCache<PBackendService_Stub>(); + _env->_function_client_cache = new BrpcClientCache<PFunctionService_Stub>(); _env->_buffer_reservation = new ReservationTracker(); ThreadPoolBuilder("SendBatchThreadPool") .set_min_threads(1) @@ -66,7 +67,8 @@ public: } void TearDown() override { - SAFE_DELETE(_env->_brpc_stub_cache); + SAFE_DELETE(_env->_internal_client_cache); + SAFE_DELETE(_env->_function_client_cache); SAFE_DELETE(_env->_load_stream_mgr); SAFE_DELETE(_env->_master_info); SAFE_DELETE(_env->_thread_mgr); diff --git a/be/test/http/stream_load_test.cpp b/be/test/http/stream_load_test.cpp index 0ea97a0..fc3435f 100644 --- a/be/test/http/stream_load_test.cpp +++ b/be/test/http/stream_load_test.cpp @@ -30,7 +30,7 @@ #include "runtime/stream_load/load_stream_mgr.h" #include "runtime/stream_load/stream_load_executor.h" #include "runtime/thread_resource_mgr.h" -#include "util/brpc_stub_cache.h" +#include "util/brpc_client_cache.h" #include "util/cpu_info.h" class mg_connection; @@ -74,14 +74,17 @@ public: _env._thread_mgr = new ThreadResourceMgr(); _env._master_info = new TMasterInfo(); _env._load_stream_mgr = new LoadStreamMgr(); - _env._brpc_stub_cache = new BrpcStubCache(); + _env._internal_client_cache = new BrpcClientCache<PBackendService_Stub>(); + _env._function_client_cache = new BrpcClientCache<PFunctionService_Stub>(); _env._stream_load_executor = new StreamLoadExecutor(&_env); _evhttp_req = evhttp_request_new(nullptr, nullptr); } void TearDown() override { - delete _env._brpc_stub_cache; - _env._brpc_stub_cache = nullptr; + delete _env._internal_client_cache; + _env._internal_client_cache = nullptr; + delete _env._function_client_cache; + _env._function_client_cache = nullptr; delete _env._load_stream_mgr; _env._load_stream_mgr = nullptr; delete _env._master_info; diff --git a/be/test/util/CMakeLists.txt b/be/test/util/CMakeLists.txt index e9f75a1..afa332c 100644 --- a/be/test/util/CMakeLists.txt +++ b/be/test/util/CMakeLists.txt @@ -19,7 +19,7 @@ set(EXECUTABLE_OUTPUT_PATH "${BUILD_DIR}/test/util") ADD_BE_TEST(bit_util_test) -ADD_BE_TEST(brpc_stub_cache_test) +ADD_BE_TEST(brpc_client_cache_test) ADD_BE_TEST(path_trie_test) ADD_BE_TEST(coding_test) ADD_BE_TEST(crc32c_test) diff --git a/be/test/util/brpc_stub_cache_test.cpp b/be/test/util/brpc_client_cache_test.cpp similarity index 73% rename from be/test/util/brpc_stub_cache_test.cpp rename to be/test/util/brpc_client_cache_test.cpp index cf68cc1..c6ece74 100644 --- a/be/test/util/brpc_stub_cache_test.cpp +++ b/be/test/util/brpc_client_cache_test.cpp @@ -15,40 +15,40 @@ // specific language governing permissions and limitations // under the License. -#include "util/brpc_stub_cache.h" +#include "util/brpc_client_cache.h" #include <gtest/gtest.h> namespace doris { -class BrpcStubCacheTest : public testing::Test { +class BrpcClientCacheTest : public testing::Test { public: - BrpcStubCacheTest() {} - virtual ~BrpcStubCacheTest() {} + BrpcClientCacheTest() {} + virtual ~BrpcClientCacheTest() {} }; -TEST_F(BrpcStubCacheTest, normal) { - BrpcStubCache cache; +TEST_F(BrpcClientCacheTest, normal) { + BrpcClientCache<PBackendService_Stub> cache; TNetworkAddress address; address.hostname = "127.0.0.1"; address.port = 123; - auto stub1 = cache.get_stub(address); + auto stub1 = cache.get_client(address); ASSERT_NE(nullptr, stub1); address.port = 124; - auto stub2 = cache.get_stub(address); + auto stub2 = cache.get_client(address); ASSERT_NE(nullptr, stub2); ASSERT_NE(stub1, stub2); address.port = 123; - auto stub3 = cache.get_stub(address); + auto stub3 = cache.get_client(address); ASSERT_EQ(stub1, stub3); } -TEST_F(BrpcStubCacheTest, invalid) { - BrpcStubCache cache; +TEST_F(BrpcClientCacheTest, invalid) { + BrpcClientCache<PBackendService_Stub> cache; TNetworkAddress address; address.hostname = "invalid.cm.invalid"; address.port = 123; - auto stub1 = cache.get_stub(address); + auto stub1 = cache.get_client(address); ASSERT_EQ(nullptr, stub1); } diff --git a/be/test/vec/runtime/vdata_stream_test.cpp b/be/test/vec/runtime/vdata_stream_test.cpp index cc4d429..5fef761 100644 --- a/be/test/vec/runtime/vdata_stream_test.cpp +++ b/be/test/vec/runtime/vdata_stream_test.cpp @@ -65,18 +65,19 @@ private: std::unique_ptr<PBackendService> _service; }; -class MockBrpcStubCache : public BrpcStubCache { +template <class T> +class MockBrpcClientCache : public BrpcClientCache<T> { public: - MockBrpcStubCache(google::protobuf::RpcChannel* channel) { + MockBrpcClientCache(google::protobuf::RpcChannel* channel) { _channel.reset(channel); - _stub.reset(new PBackendService_Stub(channel)); + _stub.reset(new T(channel)); } - virtual ~MockBrpcStubCache() = default; - virtual std::shared_ptr<PBackendService_Stub> get_stub(const TNetworkAddress&) { return _stub; } + virtual ~MockBrpcClientCache() = default; + virtual std::shared_ptr<T> get_client(const TNetworkAddress&) { return _stub; } private: std::unique_ptr<google::protobuf::RpcChannel> _channel; - std::shared_ptr<PBackendService_Stub> _stub; + std::shared_ptr<T> _stub; }; class VDataStreamTest : public testing::Test { @@ -107,8 +108,8 @@ TEST_F(VDataStreamTest, BasicTest) { mock_service->stream_mgr = &_instance; MockChannel* channel = new MockChannel(std::move(mock_service)); - runtime_stat._exec_env->_brpc_stub_cache = - _object_pool.add(new MockBrpcStubCache(std::move(channel))); + runtime_stat._exec_env->_internal_client_cache = + _object_pool.add(new MockBrpcClientCache<PBackendService_Stub>(std::move(channel))); TUniqueId uid; PlanNodeId nid = 1; diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/CreateFunctionStmt.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/CreateFunctionStmt.java index 6e376ad..a0e2ccb 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/CreateFunctionStmt.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/CreateFunctionStmt.java @@ -22,21 +22,30 @@ import org.apache.doris.catalog.AliasFunction; import org.apache.doris.catalog.Catalog; import org.apache.doris.catalog.Function; import org.apache.doris.catalog.ScalarFunction; +import org.apache.doris.catalog.ScalarType; import org.apache.doris.catalog.Type; import org.apache.doris.common.AnalysisException; +import org.apache.doris.common.Config; import org.apache.doris.common.ErrorCode; import org.apache.doris.common.ErrorReport; import org.apache.doris.common.FeConstants; import org.apache.doris.common.UserException; import org.apache.doris.common.util.Util; import org.apache.doris.mysql.privilege.PrivPredicate; +import org.apache.doris.proto.FunctionService; +import org.apache.doris.proto.PFunctionServiceGrpc; +import org.apache.doris.proto.Types; import org.apache.doris.qe.ConnectContext; +import org.apache.doris.thrift.TFunctionBinaryType; import com.google.common.base.Strings; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSortedMap; import org.apache.commons.codec.binary.Hex; +import org.apache.commons.lang3.StringUtils; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import java.io.IOException; import java.io.InputStream; @@ -45,8 +54,12 @@ import java.security.NoSuchAlgorithmException; import java.util.List; import java.util.Map; +import io.grpc.ManagedChannel; +import io.grpc.netty.NettyChannelBuilder; + // create a user define function public class CreateFunctionStmt extends DdlStmt { + private final static Logger LOG = LogManager.getLogger(CreateFunctionStmt.class); public static final String OBJECT_FILE_KEY = "object_file"; public static final String SYMBOL_KEY = "symbol"; public static final String PREPARE_SYMBOL_KEY = "prepare_fn"; @@ -59,6 +72,7 @@ public class CreateFunctionStmt extends DdlStmt { public static final String FINALIZE_KEY = "finalize_fn"; public static final String GET_VALUE_KEY = "get_value_fn"; public static final String REMOVE_KEY = "remove_fn"; + public static final String BINARY_TYPE = "type"; private final FunctionName functionName; private final boolean isAggregate; @@ -69,11 +83,12 @@ public class CreateFunctionStmt extends DdlStmt { private final Map<String, String> properties; private final List<String> parameters; private final Expr originFunction; + TFunctionBinaryType binaryType = TFunctionBinaryType.NATIVE; // needed item set after analyzed private String objectFile; private Function function; - private String checksum; + private String checksum = ""; // timeout for both connection and read. 10 seconds is long enough. private static final int HTTP_TIMEOUT_MS = 10000; @@ -111,8 +126,13 @@ public class CreateFunctionStmt extends DdlStmt { this.properties = ImmutableSortedMap.of(); } - public FunctionName getFunctionName() { return functionName; } - public Function getFunction() { return function; } + public FunctionName getFunctionName() { + return functionName; + } + + public Function getFunction() { + return function; + } public Expr getOriginFunction() { return originFunction; @@ -156,26 +176,32 @@ public class CreateFunctionStmt extends DdlStmt { intermediateType = returnType; } + String type = properties.getOrDefault(BINARY_TYPE, "NATIVE"); + binaryType = getFunctionBinaryType(type); + if (binaryType == null) { + throw new AnalysisException("unknown function type"); + } + objectFile = properties.get(OBJECT_FILE_KEY); if (Strings.isNullOrEmpty(objectFile)) { throw new AnalysisException("No 'object_file' in properties"); } - try { - computeObjectChecksum(); - } catch (IOException | NoSuchAlgorithmException e) { - throw new AnalysisException("cannot to compute object's checksum"); - } - - String md5sum = properties.get(MD5_CHECKSUM); - if (md5sum != null && !md5sum.equalsIgnoreCase(checksum)) { - throw new AnalysisException("library's checksum is not equal with input, checksum=" + checksum); + if (binaryType != TFunctionBinaryType.RPC) { + try { + computeObjectChecksum(); + } catch (IOException | NoSuchAlgorithmException e) { + throw new AnalysisException("cannot to compute object's checksum"); + } + String md5sum = properties.get(MD5_CHECKSUM); + if (md5sum != null && !md5sum.equalsIgnoreCase(checksum)) { + throw new AnalysisException("library's checksum is not equal with input, checksum=" + checksum); + } } } private void computeObjectChecksum() throws IOException, NoSuchAlgorithmException { if (FeConstants.runningUnitTest) { // skip checking checksum when running ut - checksum = ""; return; } @@ -196,6 +222,9 @@ public class CreateFunctionStmt extends DdlStmt { } private void analyzeUda() throws AnalysisException { + if (binaryType == TFunctionBinaryType.RPC) { + throw new AnalysisException("RPC UDAF is not supported."); + } AggregateFunction.AggregateFunctionBuilder builder = AggregateFunction.AggregateFunctionBuilder.createUdfBuilder(); builder.name(functionName).argsType(argsDef.getArgTypes()).retType(returnType.getType()). @@ -227,13 +256,111 @@ public class CreateFunctionStmt extends DdlStmt { } String prepareFnSymbol = properties.get(PREPARE_SYMBOL_KEY); String closeFnSymbol = properties.get(CLOSE_SYMBOL_KEY); - function = ScalarFunction.createUdf( + // TODO(yangzhg) support check function in FE when function service behind load balancer + // the format for load balance can ref https://github.com/apache/incubator-brpc/blob/master/docs/en/client.md#connect-to-a-cluster + if (binaryType == TFunctionBinaryType.RPC && !objectFile.contains("://")) { + if (StringUtils.isNotBlank(prepareFnSymbol) || StringUtils.isNotBlank(closeFnSymbol)) { + throw new AnalysisException(" prepare and close in RPC UDF are not supported."); + } + String[] url = objectFile.split(":"); + if (url.length != 2) { + throw new AnalysisException("function server address invalid."); + } + String host = url[0]; + int port = Integer.valueOf(url[1]); + ManagedChannel channel = NettyChannelBuilder.forAddress(host, port) + .flowControlWindow(Config.grpc_max_message_size_bytes) + .maxInboundMessageSize(Config.grpc_max_message_size_bytes) + .enableRetry().maxRetryAttempts(3) + .usePlaintext().build(); + PFunctionServiceGrpc.PFunctionServiceBlockingStub stub = PFunctionServiceGrpc.newBlockingStub(channel); + FunctionService.PCheckFunctionRequest.Builder builder = FunctionService.PCheckFunctionRequest.newBuilder(); + builder.getFunctionBuilder().setFunctionName(functionName.getFunction()); + for (Type arg : argsDef.getArgTypes()) { + builder.getFunctionBuilder().addInputs(convertToPParameterType(arg)); + } + builder.getFunctionBuilder().setOutput(convertToPParameterType(returnType.getType())); + FunctionService.PCheckFunctionResponse response = stub.checkFn(builder.build()); + if (response.getStatus().getStatusCode() != 0) { + throw new AnalysisException("cannot access function server:" + response.getStatus()); + } + } + function = ScalarFunction.createUdf(binaryType, functionName, argsDef.getArgTypes(), returnType.getType(), argsDef.isVariadic(), objectFile, symbol, prepareFnSymbol, closeFnSymbol); function.setChecksum(checksum); } + private Types.PGenericType convertToPParameterType(Type arg) throws AnalysisException { + Types.PGenericType.Builder typeBuilder = Types.PGenericType.newBuilder(); + switch (arg.getPrimitiveType()) { + case INVALID_TYPE: + typeBuilder.setId(Types.PGenericType.TypeId.UNKNOWN); + break; + case BOOLEAN: + typeBuilder.setId(Types.PGenericType.TypeId.BOOLEAN); + break; + case SMALLINT: + typeBuilder.setId(Types.PGenericType.TypeId.INT16); + break; + case TINYINT: + typeBuilder.setId(Types.PGenericType.TypeId.INT8); + break; + case INT: + typeBuilder.setId(Types.PGenericType.TypeId.INT32); + break; + case BIGINT: + typeBuilder.setId(Types.PGenericType.TypeId.INT64); + break; + case FLOAT: + typeBuilder.setId(Types.PGenericType.TypeId.FLOAT); + break; + case DOUBLE: + typeBuilder.setId(Types.PGenericType.TypeId.DOUBLE); + break; + case CHAR: + case VARCHAR: + typeBuilder.setId(Types.PGenericType.TypeId.STRING); + break; + case HLL: + typeBuilder.setId(Types.PGenericType.TypeId.HLL); + break; + case BITMAP: + typeBuilder.setId(Types.PGenericType.TypeId.BITMAP); + break; + case DATE: + typeBuilder.setId(Types.PGenericType.TypeId.DATE); + break; + case DATETIME: + case TIME: + typeBuilder.setId(Types.PGenericType.TypeId.DATETIME); + break; + case DECIMALV2: + typeBuilder.setId(Types.PGenericType.TypeId.DECIMAL128) + .getDecimalTypeBuilder() + .setPrecision(((ScalarType) arg).getScalarPrecision()) + .setScale(((ScalarType) arg).getScalarScale()); + break; + case LARGEINT: + typeBuilder.setId(Types.PGenericType.TypeId.INT128); + break; + default: + throw new AnalysisException("type " + arg.getPrimitiveType().toString() + " is not supported"); + } + return typeBuilder.build(); + } + + private TFunctionBinaryType getFunctionBinaryType(String type) { + TFunctionBinaryType binaryType = null; + try { + binaryType = TFunctionBinaryType.valueOf(type); + } catch (IllegalArgumentException e) { + // ignore enum Exception + } + return binaryType; + } + private void analyzeAliasFunction() throws AnalysisException { function = AliasFunction.createFunction(functionName, argsDef.getArgTypes(), Type.VARCHAR, argsDef.isVariadic(), parameters, originFunction); @@ -279,8 +406,8 @@ public class CreateFunctionStmt extends DdlStmt { } return stringBuilder.toString(); } - - @Override + + @Override public RedirectStatus getRedirectStatus() { return RedirectStatus.FORWARD_WITH_SYNC; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/ScalarFunction.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/ScalarFunction.java index 5f216d3..308bda0 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/ScalarFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/ScalarFunction.java @@ -312,11 +312,12 @@ public class ScalarFunction extends Function { } public static ScalarFunction createUdf( + TFunctionBinaryType binaryType, FunctionName name, Type[] args, Type returnType, boolean isVariadic, String objectFile, String symbol, String prepareFnSymbol, String closeFnSymbol) { - ScalarFunction fn = new ScalarFunction(name, Arrays.asList(args), returnType, isVariadic, - TFunctionBinaryType.NATIVE, true, false); + ScalarFunction fn = new ScalarFunction(name, Arrays.asList(args), returnType, isVariadic, binaryType, + true, false); fn.symbolName = symbol; fn.prepareFnSymbol = prepareFnSymbol; fn.closeFnSymbol = closeFnSymbol; diff --git a/fe/fe-core/src/main/java/org/apache/doris/common/Status.java b/fe/fe-core/src/main/java/org/apache/doris/common/Status.java index 7d6b7c6..1104cc4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/common/Status.java +++ b/fe/fe-core/src/main/java/org/apache/doris/common/Status.java @@ -17,7 +17,7 @@ package org.apache.doris.common; -import org.apache.doris.proto.Status.PStatus; +import org.apache.doris.proto.Types.PStatus; import org.apache.doris.thrift.TStatus; import org.apache.doris.thrift.TStatusCode; diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/Coordinator.java b/fe/fe-core/src/main/java/org/apache/doris/qe/Coordinator.java index 407bbea..27b072d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/qe/Coordinator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/qe/Coordinator.java @@ -2018,7 +2018,7 @@ public class Coordinator { public InternalService.PExecPlanFragmentResult get() { InternalService.PExecPlanFragmentResult result = InternalService.PExecPlanFragmentResult .newBuilder() - .setStatus(org.apache.doris.proto.Status.PStatus.newBuilder() + .setStatus(org.apache.doris.proto.Types.PStatus.newBuilder() .addErrorMsgs(e.getMessage()) .setStatusCode(TStatusCode.THRIFT_RPC_ERROR.getValue()) .build()) diff --git a/fe/fe-core/src/test/java/org/apache/doris/load/sync/canal/CanalSyncDataTest.java b/fe/fe-core/src/test/java/org/apache/doris/load/sync/canal/CanalSyncDataTest.java index 70815a3..a3051c6 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/load/sync/canal/CanalSyncDataTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/load/sync/canal/CanalSyncDataTest.java @@ -24,7 +24,6 @@ import org.apache.doris.common.AnalysisException; import org.apache.doris.common.Config; import org.apache.doris.planner.StreamLoadPlanner; import org.apache.doris.proto.InternalService; -import org.apache.doris.proto.Status; import org.apache.doris.proto.Types; import org.apache.doris.resource.Tag; import org.apache.doris.rpc.BackendServiceProxy; @@ -97,22 +96,22 @@ public class CanalSyncDataTest { SystemInfoService systemInfoService; InternalService.PExecPlanFragmentResult beginOkResult = InternalService.PExecPlanFragmentResult.newBuilder() - .setStatus(Status.PStatus.newBuilder().setStatusCode(0).build()).build(); // begin txn OK + .setStatus(Types.PStatus.newBuilder().setStatusCode(0).build()).build(); // begin txn OK InternalService.PExecPlanFragmentResult beginFailResult = InternalService.PExecPlanFragmentResult.newBuilder() - .setStatus(Status.PStatus.newBuilder().setStatusCode(1).build()).build(); // begin txn CANCELLED + .setStatus(Types.PStatus.newBuilder().setStatusCode(1).build()).build(); // begin txn CANCELLED InternalService.PCommitResult commitOkResult = InternalService.PCommitResult.newBuilder() - .setStatus(Status.PStatus.newBuilder().setStatusCode(0).build()).build(); // commit txn OK + .setStatus(Types.PStatus.newBuilder().setStatusCode(0).build()).build(); // commit txn OK InternalService.PCommitResult commitFailResult = InternalService.PCommitResult.newBuilder() - .setStatus(Status.PStatus.newBuilder().setStatusCode(1).build()).build(); // commit txn CANCELLED + .setStatus(Types.PStatus.newBuilder().setStatusCode(1).build()).build(); // commit txn CANCELLED InternalService.PRollbackResult abortOKResult = InternalService.PRollbackResult.newBuilder() - .setStatus(Status.PStatus.newBuilder().setStatusCode(0).build()).build(); // abort txn OK + .setStatus(Types.PStatus.newBuilder().setStatusCode(0).build()).build(); // abort txn OK InternalService.PSendDataResult sendDataOKResult = InternalService.PSendDataResult.newBuilder() - .setStatus(Status.PStatus.newBuilder().setStatusCode(0).build()).build(); // send data OK + .setStatus(Types.PStatus.newBuilder().setStatusCode(0).build()).build(); // send data OK @Before public void setUp() throws Exception { diff --git a/fe/fe-core/src/test/java/org/apache/doris/utframe/MockedBackendFactory.java b/fe/fe-core/src/test/java/org/apache/doris/utframe/MockedBackendFactory.java index b04b54a..42dab10 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/utframe/MockedBackendFactory.java +++ b/fe/fe-core/src/test/java/org/apache/doris/utframe/MockedBackendFactory.java @@ -21,7 +21,7 @@ import org.apache.doris.common.ClientPool; import org.apache.doris.proto.Data; import org.apache.doris.proto.InternalService; import org.apache.doris.proto.PBackendServiceGrpc; -import org.apache.doris.proto.Status; +import org.apache.doris.proto.Types; import org.apache.doris.thrift.BackendService; import org.apache.doris.thrift.FrontendService; import org.apache.doris.thrift.HeartbeatService; @@ -326,7 +326,7 @@ public class MockedBackendFactory { @Override public void transmitData(InternalService.PTransmitDataParams request, StreamObserver<InternalService.PTransmitDataResult> responseObserver) { responseObserver.onNext(InternalService.PTransmitDataResult.newBuilder() - .setStatus(Status.PStatus.newBuilder().setStatusCode(0)).build()); + .setStatus(Types.PStatus.newBuilder().setStatusCode(0)).build()); responseObserver.onCompleted(); } @@ -334,7 +334,7 @@ public class MockedBackendFactory { public void execPlanFragment(InternalService.PExecPlanFragmentRequest request, StreamObserver<InternalService.PExecPlanFragmentResult> responseObserver) { System.out.println("get exec_plan_fragment request"); responseObserver.onNext(InternalService.PExecPlanFragmentResult.newBuilder() - .setStatus(Status.PStatus.newBuilder().setStatusCode(0)).build()); + .setStatus(Types.PStatus.newBuilder().setStatusCode(0)).build()); responseObserver.onCompleted(); } @@ -342,7 +342,7 @@ public class MockedBackendFactory { public void cancelPlanFragment(InternalService.PCancelPlanFragmentRequest request, StreamObserver<InternalService.PCancelPlanFragmentResult> responseObserver) { System.out.println("get cancel_plan_fragment request"); responseObserver.onNext(InternalService.PCancelPlanFragmentResult.newBuilder() - .setStatus(Status.PStatus.newBuilder().setStatusCode(0)).build()); + .setStatus(Types.PStatus.newBuilder().setStatusCode(0)).build()); responseObserver.onCompleted(); } @@ -350,7 +350,7 @@ public class MockedBackendFactory { public void fetchData(InternalService.PFetchDataRequest request, StreamObserver<InternalService.PFetchDataResult> responseObserver) { System.out.println("get fetch_data request"); responseObserver.onNext(InternalService.PFetchDataResult.newBuilder() - .setStatus(Status.PStatus.newBuilder().setStatusCode(0)) + .setStatus(Types.PStatus.newBuilder().setStatusCode(0)) .setQueryStatistics(Data.PQueryStatistics.newBuilder() .setScanRows(0L) .setScanBytes(0L)) @@ -382,7 +382,7 @@ public class MockedBackendFactory { public void getInfo(InternalService.PProxyRequest request, StreamObserver<InternalService.PProxyResult> responseObserver) { System.out.println("get get_info request"); responseObserver.onNext(InternalService.PProxyResult.newBuilder() - .setStatus(Status.PStatus.newBuilder().setStatusCode(0)).build()); + .setStatus(Types.PStatus.newBuilder().setStatusCode(0)).build()); responseObserver.onCompleted(); } diff --git a/gensrc/proto/function_service.proto b/gensrc/proto/function_service.proto new file mode 100644 index 0000000..561be9f --- /dev/null +++ b/gensrc/proto/function_service.proto @@ -0,0 +1,63 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +syntax="proto2"; + +package doris; +option java_package = "org.apache.doris.proto"; +option cc_generic_services = true; + +import "types.proto"; + +message PRequestContext { + optional string id = 1; + optional PFunctionContext function_context = 2; +} + +message PFunctionCallRequest { + optional string function_name = 1; + repeated PValues args = 2; + optional PRequestContext context = 3; +} + +message PFunctionCallResponse { + optional PValues result = 1; + optional PStatus status = 2; +} + +message PCheckFunctionRequest { + enum MatchType { + IDENTICAL = 0; + INDISTINGUISHABLE = 1; + SUPERTYPE_OF = 2; + NONSTRICT_SUPERTYPE_OF = 3; + MATCHABLE = 4; + } + optional PFunction function = 1; + optional MatchType match_type = 2; +} + +message PCheckFunctionResponse { + optional PStatus status = 1; +} + +service PFunctionService { + rpc fn_call(PFunctionCallRequest) returns (PFunctionCallResponse); + rpc check_fn(PCheckFunctionRequest) returns (PCheckFunctionResponse); + rpc hand_shake(PHandShakeRequest) returns (PHandShakeResponse); +} + diff --git a/gensrc/proto/internal_service.proto b/gensrc/proto/internal_service.proto index d01a5fe..41a0dce 100644 --- a/gensrc/proto/internal_service.proto +++ b/gensrc/proto/internal_service.proto @@ -22,7 +22,6 @@ option java_package = "org.apache.doris.proto"; import "data.proto"; import "descriptors.proto"; -import "status.proto"; import "types.proto"; option cc_generic_services = true; @@ -430,15 +429,6 @@ message PResetRPCChannelResponse { repeated string channels = 2; }; -message PHandShakeRequest { - optional string hello = 1; -} - -message PHandShakeResponse { - optional PStatus status = 1; - optional string hello = 2; -} - service PBackendService { rpc transmit_data(PTransmitDataParams) returns (PTransmitDataResult); rpc exec_plan_fragment(PExecPlanFragmentRequest) returns (PExecPlanFragmentResult); diff --git a/gensrc/proto/status.proto b/gensrc/proto/status.proto deleted file mode 100644 index d1e9e7d..0000000 --- a/gensrc/proto/status.proto +++ /dev/null @@ -1,27 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -syntax="proto2"; - -package doris; -option java_package = "org.apache.doris.proto"; - -message PStatus { - required int32 status_code = 1; - repeated string error_msgs = 2; -}; - diff --git a/gensrc/proto/types.proto b/gensrc/proto/types.proto index 762229d..f7bff5c 100644 --- a/gensrc/proto/types.proto +++ b/gensrc/proto/types.proto @@ -20,6 +20,10 @@ syntax="proto2"; package doris; option java_package = "org.apache.doris.proto"; +message PStatus { + required int32 status_code = 1; + repeated string error_msgs = 2; +}; message PScalarType { // TPrimitiveType, use int32 to avoid redefine Enum required int32 type = 1; @@ -63,3 +67,150 @@ message PUniqueId { required int64 lo = 2; }; +message PGenericType { + enum TypeId { + UINT8 = 0; + UINT16 = 1; + UINT32 = 2; + UINT64 = 3; + UINT128 = 4; + UINT256 = 5; + INT8 = 6; + INT16 = 7; + INT32 = 8; + INT64 = 9; + INT128 = 10; + INT256 = 11; + FLOAT = 12; + DOUBLE = 13; + BOOLEAN = 14; + DATE = 15; + DATETIME = 16; + HLL = 17; + BITMAP = 18; + LIST = 19; + MAP = 20; + STRUCT =21; + STRING = 22; + DECIMAL32 = 23; + DECIMAL64 = 24; + DECIMAL128 = 25; + BYTES = 26; + NOTHING = 27; + UNKNOWN = 999; + } + required TypeId id = 2; + optional PList list_type = 11; + optional PMap map_type = 12; + optional PStruct struct_type = 13; + optional PDecimal decimal_type = 14; +} + +message PList { + required PGenericType element_type = 1; +} + +message PMap { + required PGenericType key_type = 1; + required PGenericType value_type = 2; +} + +message PField { + required PGenericType type = 1; + optional string name = 2; + optional string comment = 3; +} + +message PStruct { + repeated PField fields = 1; + required string name = 2; +} + +message PDecimal { + required uint32 precision = 1; + required uint32 scale = 2; +} + +message PDateTime { + optional int32 year = 1; + optional int32 month = 2; + optional int32 day = 3; + optional int32 hour = 4; + optional int32 minute = 5; + optional int32 second = 6; + optional int32 microsecond = 7; +} + +message PValue { + required PGenericType type = 1; + optional bool is_null = 2 [default = false]; + optional double double_value = 3; + optional float float_value = 4; + optional int32 int32_value = 5; + optional int64 int64_value = 6; + optional uint32 uint32_value = 7; + optional uint64 uint64_value = 8; + optional bool bool_value = 9; + optional string string_value = 10; + optional bytes bytes_value = 11; + optional PDateTime datetime_value = 12; +} + +message PValues { + required PGenericType type = 1; + optional bool has_null = 2 [default = false]; + repeated bool null_map = 3; + repeated double double_value = 4; + repeated float float_value = 5; + repeated int32 int32_value = 6; + repeated int64 int64_value = 7; + repeated uint32 uint32_value = 8; + repeated uint64 uint64_value = 9; + repeated bool bool_value = 10; + repeated string string_value = 11; + repeated bytes bytes_value = 12; + repeated PDateTime datetime_value = 13; +} + +// this mesage may not used for now +message PFunction { + enum FunctionType { + UDF = 0; + // not supported now + UDAF = 1; + UDTF = 2; + } + message Property { + required string key = 1; + required string val = 2; + }; + required string function_name = 1; + repeated PGenericType inputs = 2; + optional PGenericType output = 3; + optional FunctionType type = 4 [default = UDF]; + optional bool variadic = 5; + repeated Property properties = 6; +} + +message PFunctionContext { + optional string version = 1 [default = "V2_0"]; + repeated PValue staging_input_vals = 2; + repeated PValue constant_args = 3; + optional string error_msg = 4; + optional PUniqueId query_id = 5; + optional bytes thread_local_fn_state = 6; + optional bytes fragment_local_fn_state = 7; + optional string string_result = 8; + optional int64 num_updates = 9; + optional int64 num_removes = 10; + optional int64 num_warnings = 11; +} + +message PHandShakeRequest { + optional string hello = 1; +} + +message PHandShakeResponse { + optional PStatus status = 1; + optional string hello = 2; +} diff --git a/gensrc/thrift/Types.thrift b/gensrc/thrift/Types.thrift index 7dc77f3..c1c8487 100644 --- a/gensrc/thrift/Types.thrift +++ b/gensrc/thrift/Types.thrift @@ -254,7 +254,7 @@ enum TFunctionType { } enum TFunctionBinaryType { - // Palo builtin. We can either run this interpreted or via codegen + // Doris builtin. We can either run this interpreted or via codegen // depending on the query option. BUILTIN, @@ -266,6 +266,9 @@ enum TFunctionBinaryType { // Native-interface, precompiled to IR; loaded from *.ll IR, + + // call udfs by rpc service + RPC, } // Represents a fully qualified function name. diff --git a/run-be-ut.sh b/run-be-ut.sh index 51f27f5..904d197 100755 --- a/run-be-ut.sh +++ b/run-be-ut.sh @@ -135,6 +135,7 @@ ${CMAKE_CMD} -G "${GENERATOR}" \ -DCMAKE_BUILD_TYPE="${CMAKE_BUILD_TYPE}" \ -DMAKE_TEST=ON \ -DGLIBC_COMPATIBILITY="${GLIBC_COMPATIBILITY}" \ + -DBUILD_META_TOOL=OFF \ -DWITH_MYSQL=OFF \ ${CMAKE_USE_CCACHE} ../ ${BUILD_SYSTEM} -j ${PARALLEL} $RUN_FILE --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org