This is an automated email from the ASF dual-hosted git repository.
yiguolei pushed a commit to branch vector-index-dev
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/vector-index-dev by this push:
new ba9ec4b46a2 step forward on vector search (#51374)
ba9ec4b46a2 is described below
commit ba9ec4b46a2aeec84b12e2d1ca905ad580d50012
Author: zhiqiang <[email protected]>
AuthorDate: Fri May 30 12:02:29 2025 +0800
step forward on vector search (#51374)
## Fix
1. VirtualColumnIterator::seek_to_ordinal should bounds-check against _max_ordinal instead of _size.
2. Fix a data race in AnnIndexReader::load_index by running the load through DorisCallOnce.
## Functionality
1. Add l2_distance_approximate & inner_product_approximate.
2. The xxx_approximate functions above are pushed down to the ANN index,
while l2_distance/inner_product perform an exhaustive search.
## Refactor
1. Rename AnnTopNDescriptor to AnnTopNRuntime
---
be/src/olap/iterators.h | 4 +-
be/src/olap/rowset/beta_rowset_reader.cpp | 2 +-
be/src/olap/rowset/rowset_reader_context.h | 4 +-
be/src/olap/rowset/segment_v2/ann_index_reader.cpp | 27 ++++----
be/src/olap/rowset/segment_v2/ann_index_reader.h | 4 ++
be/src/olap/rowset/segment_v2/segment_iterator.cpp | 45 ++++++++------
be/src/olap/rowset/segment_v2/segment_iterator.h | 4 +-
.../rowset/segment_v2/virtual_column_iterator.cpp | 7 ++-
.../rowset/segment_v2/virtual_column_iterator.h | 2 +-
be/src/olap/tablet_reader.cpp | 2 +-
be/src/olap/tablet_reader.h | 2 +-
be/src/pipeline/exec/olap_scan_operator.cpp | 13 ++--
be/src/pipeline/exec/operator.cpp | 11 ----
be/src/pipeline/exec/operator.h | 4 +-
be/src/pipeline/exec/scan_operator.cpp | 10 +--
be/src/vec/exec/scan/olap_scanner.cpp | 6 +-
be/src/vec/exec/scan/olap_scanner.h | 2 +-
...ann_topn_predicate.cpp => ann_topn_runtime.cpp} | 40 ++++++------
.../{vann_topn_predicate.h => ann_topn_runtime.h} | 8 +--
be/src/vec/exprs/vectorized_fn_call.cpp | 12 +++-
be/src/vec/exprs/vexpr.cpp | 2 +-
be/src/vec/exprs/virtual_slot_ref.h | 7 +--
.../vec/functions/array/function_array_distance.h | 2 -
.../array/function_array_distance_approximate.cpp | 29 +++++++++
...nce.h => function_array_distance_approximate.h} | 52 +++++-----------
.../functions/array/function_array_register.cpp | 2 +
.../olap/vector_search/ann_range_search_test.cpp | 6 +-
.../vector_search/ann_topn_descriptor_test.cpp | 28 ++++-----
.../olap/vector_search/faiss_vector_index_test.cpp | 2 +-
be/test/olap/vector_search/vector_search_utils.h | 4 +-
.../doris/catalog/BuiltinScalarFunctions.java | 4 ++
.../rewrite/PushDownVectorTopNIntoOlapScan.java | 10 +--
.../PushDownVirtualColumnsIntoOlapScan.java | 7 +--
.../functions/scalar/InnerProductApproximate.java | 71 ++++++++++++++++++++++
.../functions/scalar/L2DistanceApproximate.java | 71 ++++++++++++++++++++++
.../expressions/visitor/ScalarFunctionVisitor.java | 10 +++
36 files changed, 340 insertions(+), 176 deletions(-)
diff --git a/be/src/olap/iterators.h b/be/src/olap/iterators.h
index f683e55b1e4..7e182b371db 100644
--- a/be/src/olap/iterators.h
+++ b/be/src/olap/iterators.h
@@ -29,7 +29,7 @@
#include "olap/tablet_schema.h"
#include "runtime/runtime_state.h"
#include "vec/core/block.h"
-#include "vec/exprs/vann_topn_predicate.h"
+#include "vec/exprs/ann_topn_runtime.h"
#include "vec/exprs/vexpr.h"
namespace doris {
@@ -122,7 +122,7 @@ public:
size_t topn_limit = 0;
std::map<ColumnId, vectorized::VExprContextSPtr> virtual_column_exprs;
- std::shared_ptr<vectorized::AnnTopNDescriptor> ann_topn_descriptor;
+ std::shared_ptr<vectorized::AnnTopNRuntime> ann_topn_runtime;
std::map<ColumnId, size_t> vir_cid_to_idx_in_block;
std::map<size_t, vectorized::DataTypePtr> vir_col_idx_to_type;
};
diff --git a/be/src/olap/rowset/beta_rowset_reader.cpp
b/be/src/olap/rowset/beta_rowset_reader.cpp
index e12c89d056f..8c48572227e 100644
--- a/be/src/olap/rowset/beta_rowset_reader.cpp
+++ b/be/src/olap/rowset/beta_rowset_reader.cpp
@@ -102,7 +102,7 @@ Status
BetaRowsetReader::get_segment_iterators(RowsetReaderContext* read_context
_read_options.remaining_conjunct_roots =
_read_context->remaining_conjunct_roots;
_read_options.common_expr_ctxs_push_down =
_read_context->common_expr_ctxs_push_down;
_read_options.virtual_column_exprs = _read_context->virtual_column_exprs;
- _read_options.ann_topn_descriptor = _read_context->ann_topn_descriptor;
+ _read_options.ann_topn_runtime = _read_context->ann_topn_runtime;
_read_options.vir_cid_to_idx_in_block =
_read_context->vir_cid_to_idx_in_block;
_read_options.vir_col_idx_to_type = _read_context->vir_col_idx_to_type;
_read_options.rowset_id = _rowset->rowset_id();
diff --git a/be/src/olap/rowset/rowset_reader_context.h
b/be/src/olap/rowset/rowset_reader_context.h
index 8e9a62f5f90..d3f8110daae 100644
--- a/be/src/olap/rowset/rowset_reader_context.h
+++ b/be/src/olap/rowset/rowset_reader_context.h
@@ -23,7 +23,7 @@
#include "olap/olap_common.h"
#include "olap/rowid_conversion.h"
#include "runtime/runtime_state.h"
-#include "vec/exprs/vann_topn_predicate.h"
+#include "vec/exprs/ann_topn_runtime.h"
#include "vec/exprs/vexpr.h"
#include "vec/exprs/vexpr_context.h"
@@ -88,7 +88,7 @@ struct RowsetReaderContext {
std::map<ColumnId, size_t> vir_cid_to_idx_in_block;
std::map<size_t, vectorized::DataTypePtr> vir_col_idx_to_type;
- std::shared_ptr<vectorized::AnnTopNDescriptor> ann_topn_descriptor;
+ std::shared_ptr<vectorized::AnnTopNRuntime> ann_topn_runtime;
};
} // namespace doris
diff --git a/be/src/olap/rowset/segment_v2/ann_index_reader.cpp
b/be/src/olap/rowset/segment_v2/ann_index_reader.cpp
index 64637a72566..d9d90dfbf04 100644
--- a/be/src/olap/rowset/segment_v2/ann_index_reader.cpp
+++ b/be/src/olap/rowset/segment_v2/ann_index_reader.cpp
@@ -22,9 +22,11 @@
#include "ann_index_iterator.h"
#include "common/config.h"
+#include "io/io_common.h"
#include "olap/rowset/segment_v2/index_file_reader.h"
#include "olap/rowset/segment_v2/inverted_index_compound_reader.h"
#include "runtime/runtime_state.h"
+#include "util/once.h"
#include "vector/faiss_vector_index.h"
#include "vector/vector_index.h"
@@ -63,20 +65,24 @@ Status AnnIndexReader::new_iterator(const io::IOContext&
io_ctx, OlapReaderStati
}
Status AnnIndexReader::load_index(io::IOContext* io_ctx) {
- Result<std::unique_ptr<DorisCompoundReader>> compound_dir =
- _index_file_reader->open(&_index_meta, io_ctx);
- if (!compound_dir.has_value()) {
- return Status::IOError("Failed to open index file: {}",
compound_dir.error().to_string());
- }
- _vector_index = std::make_unique<FaissVectorIndex>();
-
- RETURN_IF_ERROR(_vector_index->load(compound_dir->get()));
- return Status::OK();
+ return _load_index_once.call([&]() {
+
RETURN_IF_ERROR(_index_file_reader->init(config::inverted_index_read_buffer_size,
io_ctx));
+
+ Result<std::unique_ptr<DorisCompoundReader>> compound_dir =
+ _index_file_reader->open(&_index_meta, io_ctx);
+ if (!compound_dir.has_value()) {
+ return Status::IOError("Failed to open index file: {}",
+ compound_dir.error().to_string());
+ }
+ _vector_index = std::make_unique<FaissVectorIndex>();
+
+ RETURN_IF_ERROR(_vector_index->load(compound_dir->get()));
+ return Status::OK();
+ });
}
Status AnnIndexReader::query(io::IOContext* io_ctx, AnnIndexParam* param) {
#ifndef BE_TEST
-
RETURN_IF_ERROR(_index_file_reader->init(config::inverted_index_read_buffer_size,
io_ctx));
RETURN_IF_ERROR(load_index(io_ctx));
#endif
DCHECK(_vector_index != nullptr);
@@ -110,7 +116,6 @@ Status AnnIndexReader::range_search(const
RangeSearchParams& params,
const VectorSearchUserParams&
custom_params,
RangeSearchResult* result, io::IOContext*
io_ctx) {
#ifndef BE_TEST
-
RETURN_IF_ERROR(_index_file_reader->init(config::inverted_index_read_buffer_size,
io_ctx));
RETURN_IF_ERROR(load_index(io_ctx));
#endif
DCHECK(_vector_index != nullptr);
diff --git a/be/src/olap/rowset/segment_v2/ann_index_reader.h
b/be/src/olap/rowset/segment_v2/ann_index_reader.h
index a12c0a508e4..e557d9c96b3 100644
--- a/be/src/olap/rowset/segment_v2/ann_index_reader.h
+++ b/be/src/olap/rowset/segment_v2/ann_index_reader.h
@@ -20,6 +20,7 @@
#include "olap/rowset/segment_v2/index_reader.h"
#include "olap/tablet_schema.h"
#include "runtime/runtime_state.h"
+#include "util/once.h"
#include "vector/vector_index.h"
namespace doris::segment_v2 {
@@ -53,6 +54,7 @@ public:
Status new_iterator(const io::IOContext& io_ctx, OlapReaderStatistics*
stats,
RuntimeState* runtime_state,
std::unique_ptr<IndexIterator>* iterator) override;
+
VectorIndex::Metric get_metric_type() const { return _metric_type; }
private:
@@ -62,6 +64,8 @@ private:
// TODO: Use integer.
std::string _index_type;
VectorIndex::Metric _metric_type;
+
+ DorisCallOnce<Status> _load_index_once;
};
using AnnIndexReaderPtr = std::shared_ptr<AnnIndexReader>;
diff --git a/be/src/olap/rowset/segment_v2/segment_iterator.cpp
b/be/src/olap/rowset/segment_v2/segment_iterator.cpp
index d77c72d8ffb..26a258ad865 100644
--- a/be/src/olap/rowset/segment_v2/segment_iterator.cpp
+++ b/be/src/olap/rowset/segment_v2/segment_iterator.cpp
@@ -91,7 +91,7 @@
#include "vec/data_types/data_type.h"
#include "vec/data_types/data_type_factory.hpp"
#include "vec/data_types/data_type_number.h"
-#include "vec/exprs/vann_topn_predicate.h"
+#include "vec/exprs/ann_topn_runtime.h"
#include "vec/exprs/vexpr.h"
#include "vec/exprs/vexpr_context.h"
#include "vec/exprs/virtual_slot_ref.h"
@@ -319,7 +319,7 @@ Status SegmentIterator::_init_impl(const
StorageReadOptions& opts) {
opts.vir_cid_to_idx_in_block.size());
_virtual_column_exprs = _opts.virtual_column_exprs;
- _ann_topn_descriptor = _opts.ann_topn_descriptor;
+ _ann_topn_runtime = _opts.ann_topn_runtime;
_vir_cid_to_idx_in_block = _opts.vir_cid_to_idx_in_block;
RETURN_IF_ERROR(init_iterators());
@@ -372,9 +372,9 @@ Status SegmentIterator::_init_impl(const
StorageReadOptions& opts) {
RETURN_IF_ERROR(_construct_compound_expr_context());
_enable_common_expr_pushdown = !_common_expr_ctxs_push_down.empty();
LOG_INFO(
- "Segment iterator init, virtual_column_exprs size: {}, has
ann_topn_descriptor: {}, "
+ "Segment iterator init, virtual_column_exprs size: {}, has
ann_topn_runtime: {}, "
"_vir_cid_to_idx_in_block size: {}, common_expr_pushdown size: {}",
- _opts.virtual_column_exprs.size(), _opts.ann_topn_descriptor !=
nullptr,
+ _opts.virtual_column_exprs.size(), _opts.ann_topn_runtime !=
nullptr,
_opts.vir_cid_to_idx_in_block.size(),
_common_expr_ctxs_push_down.size());
_initialize_predicate_results();
return Status::OK();
@@ -616,14 +616,13 @@ Status
SegmentIterator::_get_row_ranges_by_column_conditions() {
}
Status SegmentIterator::_apply_ann_topn_predicate() {
- if (_ann_topn_descriptor == nullptr) {
+ if (_ann_topn_runtime == nullptr) {
return Status::OK();
}
- LOG_INFO("Ann topn descriptor: {}", _ann_topn_descriptor->debug_string());
- size_t src_col_idx = _ann_topn_descriptor->get_src_column_idx();
+ LOG_INFO("Try apply ann topn: {}", _ann_topn_runtime->debug_string());
+ size_t src_col_idx = _ann_topn_runtime->get_src_column_idx();
ColumnId src_cid = _schema->column_id(src_col_idx);
- LOG_INFO("Ann topn src column id: {}", src_cid);
IndexIterator* ann_index_iterator = _index_iterators[src_cid].get();
if (ann_index_iterator == nullptr || !_common_expr_ctxs_push_down.empty()
||
@@ -641,32 +640,41 @@ Status SegmentIterator::_apply_ann_topn_predicate() {
auto ann_index_reader = dynamic_cast<AnnIndexReader*>(index_reader.get());
DCHECK(ann_index_reader != nullptr);
if (ann_index_reader->get_metric_type() ==
VectorIndex::Metric::INNER_PRODUCT) {
- if (_ann_topn_descriptor->is_asc()) {
- LOG_INFO("asc topn for inner product can not be evaluated by ann
index");
+ if (_ann_topn_runtime->is_asc()) {
+ LOG_INFO("Asc topn for inner product can not be evaluated by ann
index");
return Status::OK();
}
} else {
- if (!_ann_topn_descriptor->is_asc()) {
- LOG_INFO("desc topn for l2/cosine can not be evaluated by ann
index");
+ if (!_ann_topn_runtime->is_asc()) {
+ LOG_INFO("Desc topn for l2/cosine can not be evaluated by ann
index");
return Status::OK();
}
}
- if (ann_index_reader->get_metric_type() !=
_ann_topn_descriptor->get_metric_type()) {
+ if (ann_index_reader->get_metric_type() !=
_ann_topn_runtime->get_metric_type()) {
LOG_INFO(
"Ann topn metric type {} not match index metric type {}, can
not be evaluated by "
"ann index",
-
VectorIndex::metric_to_string(_ann_topn_descriptor->get_metric_type()),
+
VectorIndex::metric_to_string(_ann_topn_runtime->get_metric_type()),
VectorIndex::metric_to_string(ann_index_reader->get_metric_type()));
return Status::OK();
}
size_t pre_size = _row_bitmap.cardinality();
- size_t dst_col_idx = _ann_topn_descriptor->get_dest_column_idx();
+ size_t rows_of_semgnet = _segment->num_rows();
+ if (pre_size < rows_of_semgnet * 0.3) {
+ LOG_INFO(
+ "Ann topn predicate input rows {} < 30% of segment rows {},
will not use ann index "
+ "to "
+ "filter",
+ pre_size, rows_of_semgnet);
+ return Status::OK();
+ }
+ const size_t dst_col_idx = _ann_topn_runtime->get_dest_column_idx();
vectorized::IColumn::MutablePtr result_column;
std::unique_ptr<std::vector<uint64_t>> result_row_ids;
- RETURN_IF_ERROR(_ann_topn_descriptor->evaluate_vector_ann_search(
- ann_index_iterator, _row_bitmap, result_column, result_row_ids));
+
RETURN_IF_ERROR(_ann_topn_runtime->evaluate_vector_ann_search(ann_index_iterator,
_row_bitmap,
+
result_column, result_row_ids));
// TODO: 处理 nullable
LOG_INFO("Ann topn filtered {} - {} = {} rows", pre_size,
_row_bitmap.cardinality(),
pre_size - _row_bitmap.cardinality());
@@ -1246,6 +1254,7 @@ Status SegmentIterator::_init_index_iterators() {
return Status::OK();
}
+ // Inverted index iterators
for (auto cid : _schema->column_ids()) {
// Use segment’s own index_meta, for compatibility with future
indexing needs to default to lowercase.
if (_index_iterators[cid] == nullptr) {
@@ -1262,7 +1271,7 @@ Status SegmentIterator::_init_index_iterators() {
}
}
- // TODO: tablet_schema 管理 index_meta 的逻辑很乱很奇怪,需要重构
+ // Ann index iterators
for (auto cid : _schema->column_ids()) {
if (_index_iterators[cid] == nullptr) {
const auto& column = _opts.tablet_schema->column(cid);
diff --git a/be/src/olap/rowset/segment_v2/segment_iterator.h
b/be/src/olap/rowset/segment_v2/segment_iterator.h
index bb868d6278a..66418807ed9 100644
--- a/be/src/olap/rowset/segment_v2/segment_iterator.h
+++ b/be/src/olap/rowset/segment_v2/segment_iterator.h
@@ -52,7 +52,7 @@
#include "vec/core/column_with_type_and_name.h"
#include "vec/core/columns_with_type_and_name.h"
#include "vec/data_types/data_type.h"
-#include "vec/exprs/vann_topn_predicate.h"
+#include "vec/exprs/ann_topn_runtime.h"
#include "vec/exprs/vexpr_fwd.h"
namespace doris {
@@ -480,7 +480,7 @@ private:
std::unordered_map<ColumnId, std::unordered_map<const vectorized::VExpr*,
bool>>
_common_expr_inverted_index_status;
- std::shared_ptr<vectorized::AnnTopNDescriptor> _ann_topn_descriptor;
+ std::shared_ptr<vectorized::AnnTopNRuntime> _ann_topn_runtime;
std::map<ColumnId, vectorized::VExprContextSPtr> _virtual_column_exprs;
std::map<ColumnId, size_t> _vir_cid_to_idx_in_block;
diff --git a/be/src/olap/rowset/segment_v2/virtual_column_iterator.cpp
b/be/src/olap/rowset/segment_v2/virtual_column_iterator.cpp
index 73285504624..624cc484717 100644
--- a/be/src/olap/rowset/segment_v2/virtual_column_iterator.cpp
+++ b/be/src/olap/rowset/segment_v2/virtual_column_iterator.cpp
@@ -56,7 +56,7 @@ void
VirtualColumnIterator::prepare_materialization(vectorized::IColumn::Ptr col
// orders: [1,2,4,5,7,10]
std::sort(order.begin(), order.end(), [&](size_t a, size_t b) { return a <
b; });
- LOG_INFO("Sorted order {}", fmt::join(order, ", "));
+ _max_ordinal = order[n - 1];
// 2. scatter column
auto scattered_column = column->clone_empty();
// We need a mapping from global row id to local index in the materialized
column.
@@ -101,8 +101,9 @@ Status VirtualColumnIterator::seek_to_ordinal(ordinal_t
ord_idx) {
return Status::OK();
}
- if (ord_idx >= _size) {
- return Status::InternalError("Seek to ordinal out of range: {} out of
{}", ord_idx, _size);
+ if (ord_idx >= _max_ordinal) {
+ return Status::InternalError("Seek to ordinal out of range: {} out of
{}", ord_idx,
+ _max_ordinal);
}
_current_ordinal = ord_idx;
diff --git a/be/src/olap/rowset/segment_v2/virtual_column_iterator.h
b/be/src/olap/rowset/segment_v2/virtual_column_iterator.h
index f8c5f360716..cfdd59745d8 100644
--- a/be/src/olap/rowset/segment_v2/virtual_column_iterator.h
+++ b/be/src/olap/rowset/segment_v2/virtual_column_iterator.h
@@ -59,7 +59,7 @@ private:
std::map<uint64_t, uint64_t> _row_id_to_idx;
doris::vectorized::IColumn::Filter _filter;
size_t _size = 0;
-
+ size_t _max_ordinal = 0;
ordinal_t _current_ordinal = 0;
};
diff --git a/be/src/olap/tablet_reader.cpp b/be/src/olap/tablet_reader.cpp
index f09f0ce20d0..2d427c63ad1 100644
--- a/be/src/olap/tablet_reader.cpp
+++ b/be/src/olap/tablet_reader.cpp
@@ -261,7 +261,7 @@ Status TabletReader::_capture_rs_readers(const
ReaderParams& read_params) {
_reader_context.remaining_conjunct_roots =
read_params.remaining_conjunct_roots;
_reader_context.common_expr_ctxs_push_down =
read_params.common_expr_ctxs_push_down;
_reader_context.virtual_column_exprs = read_params.virtual_column_exprs;
- _reader_context.ann_topn_descriptor = read_params.ann_topn_descriptor;
+ _reader_context.ann_topn_runtime = read_params.ann_topn_runtime;
_reader_context.vir_cid_to_idx_in_block =
read_params.vir_cid_to_idx_in_block;
_reader_context.vir_col_idx_to_type = read_params.vir_col_idx_to_type;
_reader_context.output_columns = &read_params.output_columns;
diff --git a/be/src/olap/tablet_reader.h b/be/src/olap/tablet_reader.h
index 825842423c7..4faa1b67fb4 100644
--- a/be/src/olap/tablet_reader.h
+++ b/be/src/olap/tablet_reader.h
@@ -196,7 +196,7 @@ public:
int64_t batch_size = -1;
std::map<ColumnId, vectorized::VExprContextSPtr> virtual_column_exprs;
- std::shared_ptr<vectorized::AnnTopNDescriptor> ann_topn_descriptor;
+ std::shared_ptr<vectorized::AnnTopNRuntime> ann_topn_runtime;
std::map<ColumnId, size_t> vir_cid_to_idx_in_block;
std::map<size_t, vectorized::DataTypePtr> vir_col_idx_to_type;
};
diff --git a/be/src/pipeline/exec/olap_scan_operator.cpp
b/be/src/pipeline/exec/olap_scan_operator.cpp
index 7861d95ec9d..3d45e931e6e 100644
--- a/be/src/pipeline/exec/olap_scan_operator.cpp
+++ b/be/src/pipeline/exec/olap_scan_operator.cpp
@@ -38,7 +38,7 @@
#include "util/runtime_profile.h"
#include "util/to_string.h"
#include "vec/exec/scan/olap_scanner.h"
-#include "vec/exprs/vann_topn_predicate.h"
+#include "vec/exprs/ann_topn_runtime.h"
#include "vec/exprs/vectorized_fn_call.h"
#include "vec/exprs/vexpr.h"
#include "vec/exprs/vexpr_context.h"
@@ -358,7 +358,6 @@ Status
OlapScanLocalState::_init_scanners(std::list<vectorized::ScannerSPtr>* sc
state()->query_options().resource_limit.__isset.cpu_limit;
RETURN_IF_ERROR(hold_tablets());
- LOG_INFO("ScanNode is_preaggregation: {}",
p._olap_scan_node.is_preaggregation);
if (enable_parallel_scan && !p._should_run_serial && !has_cpu_limit &&
p._push_down_agg_type == TPushAggOp::NONE &&
(_storage_no_merge() || p._olap_scan_node.is_preaggregation)) {
@@ -560,8 +559,6 @@ Status OlapScanLocalState::init(RuntimeState* state,
LocalStateInfo& info) {
const TOlapScanNode& olap_scan_node =
_parent->cast<OlapScanOperatorX>()._olap_scan_node;
if (olap_scan_node.__isset.ann_sort_info ||
olap_scan_node.__isset.ann_sort_limit) {
- LOG_INFO("Ann sort info: {}",
-
apache::thrift::ThriftDebugString(olap_scan_node.ann_sort_info));
DCHECK(olap_scan_node.__isset.ann_sort_info);
DCHECK(olap_scan_node.__isset.ann_sort_limit);
DCHECK(olap_scan_node.ann_sort_info.ordering_exprs.size() == 1);
@@ -574,8 +571,8 @@ Status OlapScanLocalState::init(RuntimeState* state,
LocalStateInfo& info) {
const size_t limit = olap_scan_node.ann_sort_limit;
std::shared_ptr<vectorized::VExprContext> ordering_expr_ctx;
RETURN_IF_ERROR(vectorized::VExpr::create_expr_tree(ordering_expr,
ordering_expr_ctx));
- _ann_topn_descriptor =
- vectorized::AnnTopNDescriptor::create_shared(asc, limit,
ordering_expr_ctx);
+ _ann_topn_runtime =
+ vectorized::AnnTopNRuntime::create_shared(asc, limit,
ordering_expr_ctx);
}
return ScanLocalState<OlapScanLocalState>::init(state, info);
@@ -613,8 +610,8 @@ Status OlapScanLocalState::open(RuntimeState* state) {
}
}
- if (_ann_topn_descriptor) {
- RETURN_IF_ERROR(_ann_topn_descriptor->prepare(state,
p.intermediate_row_desc()));
+ if (_ann_topn_runtime) {
+ RETURN_IF_ERROR(_ann_topn_runtime->prepare(state,
p.intermediate_row_desc()));
}
RETURN_IF_ERROR(ScanLocalState<OlapScanLocalState>::open(state));
diff --git a/be/src/pipeline/exec/operator.cpp
b/be/src/pipeline/exec/operator.cpp
index 4059de9c2a9..adc2fde01ee 100644
--- a/be/src/pipeline/exec/operator.cpp
+++ b/be/src/pipeline/exec/operator.cpp
@@ -192,25 +192,14 @@ Status OperatorXBase::init(const TPlanNode& tnode,
RuntimeState* /*state*/) {
auto substr = node_name.substr(0, node_name.find("_NODE"));
_op_name = substr + "_OPERATOR";
- LOG_INFO("Conjunct size of {} is {}", _op_name,
- tnode.__isset.conjuncts ? tnode.conjuncts.size() : 0);
-
if (tnode.__isset.vconjunct) {
vectorized::VExprContextSPtr context;
RETURN_IF_ERROR(vectorized::VExpr::create_expr_tree(tnode.vconjunct,
context));
- // LOG_INFO("Conjunct of {} is\n{}", _op_name,
- // apache::thrift::ThriftDebugString(tnode.vconjunct));
_conjuncts.emplace_back(context);
} else if (tnode.__isset.conjuncts) {
for (auto& conjunct : tnode.conjuncts) {
vectorized::VExprContextSPtr context;
RETURN_IF_ERROR(vectorized::VExpr::create_expr_tree(conjunct,
context));
- // LOG_INFO("Conjunct of {} is\n{}", _op_name,
- // apache::thrift::ThriftDebugString(conjunct));
- // // Write the conjunct to a file for debugging
- // doris::vectorized::write_to_json(
- //
"/mnt/disk4/hezhiqiang/workspace/doris/cmaster/RELEASE/be1", "conjunct.json",
- // conjunct);
_conjuncts.emplace_back(context);
}
}
diff --git a/be/src/pipeline/exec/operator.h b/be/src/pipeline/exec/operator.h
index 721924cf25c..19105c957db 100644
--- a/be/src/pipeline/exec/operator.h
+++ b/be/src/pipeline/exec/operator.h
@@ -50,7 +50,7 @@ class RuntimeState;
class TDataSink;
namespace vectorized {
class AsyncResultWriter;
-class AnnTopNDescriptor;
+class AnnTopNRuntime;
} // namespace vectorized
} // namespace doris
@@ -245,7 +245,7 @@ protected:
RuntimeState* _state = nullptr;
vectorized::VExprContextSPtrs _conjuncts;
vectorized::VExprContextSPtrs _projections;
- std::shared_ptr<vectorized::AnnTopNDescriptor> _ann_topn_descriptor;
+ std::shared_ptr<vectorized::AnnTopNRuntime> _ann_topn_runtime;
// Used in common subexpression elimination to compute intermediate
results.
std::vector<vectorized::VExprContextSPtrs> _intermediate_projections;
diff --git a/be/src/pipeline/exec/scan_operator.cpp
b/be/src/pipeline/exec/scan_operator.cpp
index 30ad7d02928..9ad39866fe5 100644
--- a/be/src/pipeline/exec/scan_operator.cpp
+++ b/be/src/pipeline/exec/scan_operator.cpp
@@ -1247,15 +1247,7 @@ template <typename LocalStateType>
Status ScanOperatorX<LocalStateType>::prepare(RuntimeState* state) {
_input_tuple_desc =
state->desc_tbl().get_tuple_descriptor(_input_tuple_id);
_output_tuple_desc =
state->desc_tbl().get_tuple_descriptor(_output_tuple_id);
- LOG_INFO(
- "ScanOperator, _input_tuple_id:{}, _input_tuple_desc.slots:{},
_output_tuple_id:{}, "
- "_output_tuple_desc.slots:{}",
- _input_tuple_id,
- _input_tuple_desc == nullptr ? -1
- :
static_cast<int32>(_input_tuple_desc->slots().size()),
- _output_tuple_id,
- _output_tuple_desc == nullptr ? -1
- :
static_cast<int32>(_output_tuple_desc->slots().size()));
+
RETURN_IF_ERROR(OperatorX<LocalStateType>::prepare(state));
const auto slots = _output_tuple_desc->slots();
diff --git a/be/src/vec/exec/scan/olap_scanner.cpp
b/be/src/vec/exec/scan/olap_scanner.cpp
index 9c6197cce0e..c4847c49ded 100644
--- a/be/src/vec/exec/scan/olap_scanner.cpp
+++ b/be/src/vec/exec/scan/olap_scanner.cpp
@@ -98,7 +98,7 @@ OlapScanner::OlapScanner(pipeline::ScanLocalStateBase*
parent, OlapScanner::Para
.filter_block_conjuncts {},
.key_group_cluster_key_idxes {},
.virtual_column_exprs {},
- .ann_topn_descriptor {},
+ .ann_topn_runtime {},
.vir_cid_to_idx_in_block {},
.vir_col_idx_to_type {},
}) {
@@ -161,7 +161,7 @@ Status OlapScanner::init() {
_slot_id_to_index_in_block = local_state->_slot_id_to_index_in_block;
_slot_id_to_col_type = local_state->_slot_id_to_col_type;
- _ann_topn_descriptor = local_state->_ann_topn_descriptor;
+ _ann_topn_runtime = local_state->_ann_topn_runtime;
// set limit to reduce end of rowset and segment mem use
_tablet_reader = std::make_unique<BlockReader>();
@@ -321,7 +321,7 @@ Status OlapScanner::_init_tablet_reader_params(
_tablet_reader_params.common_expr_ctxs_push_down =
_common_expr_ctxs_push_down;
_tablet_reader_params.virtual_column_exprs = _virtual_column_exprs;
- _tablet_reader_params.ann_topn_descriptor = _ann_topn_descriptor;
+ _tablet_reader_params.ann_topn_runtime = _ann_topn_runtime;
_tablet_reader_params.vir_cid_to_idx_in_block = _vir_cid_to_idx_in_block;
_tablet_reader_params.vir_col_idx_to_type = _vir_col_idx_to_type;
_tablet_reader_params.output_columns =
diff --git a/be/src/vec/exec/scan/olap_scanner.h
b/be/src/vec/exec/scan/olap_scanner.h
index f6895662c88..efa3f029f21 100644
--- a/be/src/vec/exec/scan/olap_scanner.h
+++ b/be/src/vec/exec/scan/olap_scanner.h
@@ -118,7 +118,7 @@ public:
// The idx of vir_col in block to its data type.
std::map<size_t, vectorized::DataTypePtr> _vir_col_idx_to_type;
- std::shared_ptr<vectorized::AnnTopNDescriptor> _ann_topn_descriptor;
+ std::shared_ptr<vectorized::AnnTopNRuntime> _ann_topn_runtime;
VectorSearchUserParams _vector_search_params;
};
diff --git a/be/src/vec/exprs/vann_topn_predicate.cpp
b/be/src/vec/exprs/ann_topn_runtime.cpp
similarity index 82%
rename from be/src/vec/exprs/vann_topn_predicate.cpp
rename to be/src/vec/exprs/ann_topn_runtime.cpp
index 1f8a9bd7047..5f86a2a9241 100644
--- a/be/src/vec/exprs/vann_topn_predicate.cpp
+++ b/be/src/vec/exprs/ann_topn_runtime.cpp
@@ -15,11 +15,10 @@
// specific language governing permissions and limitations
// under the License.
-#include "vec/exprs/vann_topn_predicate.h"
+#include "vec/exprs/ann_topn_runtime.h"
#include <cstdint>
#include <memory>
-#include <sstream>
#include <string>
#include "common/logging.h"
@@ -36,10 +35,11 @@
#include "vec/exprs/vexpr_fwd.h"
#include "vec/exprs/virtual_slot_ref.h"
#include "vec/exprs/vslot_ref.h"
+#include "vec/functions/array/function_array_distance_approximate.h"
namespace doris::vectorized {
-Status AnnTopNDescriptor::prepare(RuntimeState* state, const RowDescriptor&
row_desc) {
+Status AnnTopNRuntime::prepare(RuntimeState* state, const RowDescriptor&
row_desc) {
RETURN_IF_ERROR(_order_by_expr_ctx->prepare(state, row_desc));
RETURN_IF_ERROR(_order_by_expr_ctx->open(state));
@@ -104,22 +104,26 @@ Status AnnTopNDescriptor::prepare(RuntimeState* state,
const RowDescriptor& row_
_query_array = array_literal->get_column_ptr();
_user_params = state->get_vector_search_params();
- std::set<std::string> distance_func_names = {vectorized::L2Distance::name,
-
vectorized::InnerProduct::name};
+ std::set<std::string> distance_func_names =
{vectorized::L2DistanceApproximate::name,
+
vectorized::InnerProductApproximate::name};
if (distance_func_names.contains(distance_fn_call->function_name()) ==
false) {
return Status::InternalError("Ann topn expr expect distance function,
got {}",
distance_fn_call->function_name());
}
+ std::string metric_name = distance_fn_call->function_name();
+ // Strip the "_approximate" suffix
+ metric_name = metric_name.substr(0, metric_name.size() - 12);
- _metric_type =
segment_v2::VectorIndex::string_to_metric(distance_fn_call->function_name());
- VLOG_DEBUG << "AnnTopNDescriptor: {}" << this->debug_string();
+ _metric_type = segment_v2::VectorIndex::string_to_metric(metric_name);
+
+ VLOG_DEBUG << "AnnTopNRuntime: {}" << this->debug_string();
return Status::OK();
}
-Status AnnTopNDescriptor::evaluate_vector_ann_search(
- segment_v2::IndexIterator* ann_index_iterator, roaring::Roaring&
roaring,
- vectorized::IColumn::MutablePtr& result_column,
- std::unique_ptr<std::vector<uint64_t>>& row_ids) {
+Status AnnTopNRuntime::evaluate_vector_ann_search(segment_v2::IndexIterator*
ann_index_iterator,
+ roaring::Roaring& roaring,
+
vectorized::IColumn::MutablePtr& result_column,
+
std::unique_ptr<std::vector<uint64_t>>& row_ids) {
DCHECK(ann_index_iterator != nullptr);
segment_v2::AnnIndexIterator* ann_index_iterator_casted =
dynamic_cast<segment_v2::AnnIndexIterator*>(ann_index_iterator);
@@ -169,12 +173,12 @@ Status AnnTopNDescriptor::evaluate_vector_ann_search(
return Status::OK();
}
-std::string AnnTopNDescriptor::debug_string() const {
- return "AnnTopNDescriptor: limit=" + std::to_string(_limit) +
- ", src_col_idx=" + std::to_string(_src_column_idx) +
- ", dest_col_idx=" + std::to_string(_dest_column_idx) + ", asc=" +
std::to_string(_asc) +
- ", user_params=" + _user_params.to_string() +
- ", metric_type=" +
segment_v2::VectorIndex::metric_to_string(_metric_type) +
- ", order_by_expr=" + _order_by_expr_ctx->root()->debug_string();
+std::string AnnTopNRuntime::debug_string() const {
+ return fmt::format(
+ "AnnTopNRuntime: limit={}, src_col_idx={}, dest_col_idx={},
asc={}, user_params={}, "
+ "metric_type={}, order_by_expr={}",
+ _limit, _src_column_idx, _dest_column_idx, _asc,
_user_params.to_string(),
+ segment_v2::VectorIndex::metric_to_string(_metric_type),
+ _order_by_expr_ctx->root()->debug_string());
}
} // namespace doris::vectorized
\ No newline at end of file
diff --git a/be/src/vec/exprs/vann_topn_predicate.h
b/be/src/vec/exprs/ann_topn_runtime.h
similarity index 91%
rename from be/src/vec/exprs/vann_topn_predicate.h
rename to be/src/vec/exprs/ann_topn_runtime.h
index 842d054cfcd..f270799cad2 100644
--- a/be/src/vec/exprs/vann_topn_predicate.h
+++ b/be/src/vec/exprs/ann_topn_runtime.h
@@ -29,11 +29,11 @@
namespace doris::vectorized {
-class AnnTopNDescriptor {
- ENABLE_FACTORY_CREATOR(AnnTopNDescriptor);
+class AnnTopNRuntime {
+ ENABLE_FACTORY_CREATOR(AnnTopNRuntime);
public:
- AnnTopNDescriptor(bool asc, size_t limit, VExprContextSPtr
order_by_expr_ctx)
+ AnnTopNRuntime(bool asc, size_t limit, VExprContextSPtr order_by_expr_ctx)
: _asc(asc), _limit(limit), _order_by_expr_ctx(order_by_expr_ctx)
{};
Status prepare(RuntimeState* state, const RowDescriptor& row_desc);
@@ -59,7 +59,7 @@ private:
// order by distance(xxx, [1,2])
VExprContextSPtr _order_by_expr_ctx;
- std::string _name = "AnnTopNDescriptor";
+ std::string _name = "ann_topn_runtime";
size_t _src_column_idx = -1;
size_t _dest_column_idx = -1;
segment_v2::VectorIndex::Metric _metric_type;
diff --git a/be/src/vec/exprs/vectorized_fn_call.cpp
b/be/src/vec/exprs/vectorized_fn_call.cpp
index c6fa1a45d55..49bfef7bc55 100644
--- a/be/src/vec/exprs/vectorized_fn_call.cpp
+++ b/be/src/vec/exprs/vectorized_fn_call.cpp
@@ -51,6 +51,7 @@
#include "vec/exprs/virtual_slot_ref.h"
#include "vec/exprs/vliteral.h"
#include "vec/functions/array/function_array_distance.h"
+#include "vec/functions/array/function_array_distance_approximate.h"
#include "vec/functions/function_agg_state.h"
#include "vec/functions/function_fake.h"
#include "vec/functions/function_java_udf.h"
@@ -382,13 +383,18 @@ Status VectorizedFnCall::prepare_ann_range_search(
}
// Now left child is a function call, we need to check if it is a distance
function
- std::set<std::string> distance_functions = {L2Distance::name,
InnerProduct::name};
+ std::set<std::string> distance_functions = {L2DistanceApproximate::name,
+ InnerProductApproximate::name};
if (distance_functions.find(function_call->_function_name) ==
distance_functions.end()) {
- LOG_INFO("Left child is not a distance function. Got {}",
function_call->_function_name);
+ LOG_INFO("Left child is not a approximate distance function. Got {}",
+ function_call->_function_name);
return Status::OK();
} else {
+ // Strip the _approximate suffix.
+ std::string metric_name = function_call->_function_name;
+ metric_name = metric_name.substr(0, metric_name.size() - 12);
_ann_range_search_params.metric_type =
-
segment_v2::VectorIndex::string_to_metric(function_call->_function_name);
+ segment_v2::VectorIndex::string_to_metric(metric_name);
}
if (function_call->get_num_children() != 2) {
diff --git a/be/src/vec/exprs/vexpr.cpp b/be/src/vec/exprs/vexpr.cpp
index f3269dfca36..21eef1c2d9e 100644
--- a/be/src/vec/exprs/vexpr.cpp
+++ b/be/src/vec/exprs/vexpr.cpp
@@ -41,7 +41,7 @@
#include "vec/data_types/data_type_factory.hpp"
#include "vec/data_types/data_type_nullable.h"
#include "vec/data_types/data_type_number.h"
-#include "vec/exprs/vann_topn_predicate.h"
+#include "vec/exprs/ann_topn_runtime.h"
#include "vec/exprs/varray_literal.h"
#include "vec/exprs/vcase_expr.h"
#include "vec/exprs/vcast_expr.h"
diff --git a/be/src/vec/exprs/virtual_slot_ref.h
b/be/src/vec/exprs/virtual_slot_ref.h
index 9bd184cd5b7..5a45267082b 100644
--- a/be/src/vec/exprs/virtual_slot_ref.h
+++ b/be/src/vec/exprs/virtual_slot_ref.h
@@ -16,7 +16,6 @@
// under the License.
#pragma once
-#include <cstdint>
#include "vec/exprs/vexpr.h"
@@ -45,8 +44,6 @@ public:
column_ids.insert(_column_id);
}
std::shared_ptr<VExpr> get_virtual_column_expr() const { return
_virtual_column_expr; }
- // void prepare_virtual_slots(const std::map<SlotId,
vectorized::VExprContextSPtr>&
- // _slot_id_to_virtual_column_expr)
override;
/*
select * from tbl where distance_function(columnA, ArrayLiterat) > 100
@@ -56,7 +53,7 @@ public:
BINARY_PRED
|---------------------------------------|
| |
- FUNCTION_CALL(l2_distance) IntLiteral
+ FUNCTION_CALL(l2_distance_approximate) IntLiteral
|
|-----------------------|
| |
@@ -70,7 +67,7 @@ public:
| |
VIRTUAL_SLOT_REF IntLiteral
|
- FUNCTION_CALL(l2_distance)
+ FUNCTION_CALL(l2_distance_approximate)
|
|-----------------------|
| |
diff --git a/be/src/vec/functions/array/function_array_distance.h
b/be/src/vec/functions/array/function_array_distance.h
index 8a2d533fad0..28b0df28d7f 100644
--- a/be/src/vec/functions/array/function_array_distance.h
+++ b/be/src/vec/functions/array/function_array_distance.h
@@ -96,8 +96,6 @@ public:
Status execute_impl(FunctionContext* context, Block& block, const
ColumnNumbers& arguments,
uint32_t result, size_t input_rows_count) const
override {
- // LOG_INFO("Function {} is executed with {} rows, stack {}",
get_name(), input_rows_count,
- // doris::get_stack_trace());
const auto& arg1 = block.get_by_position(arguments[0]);
const auto& arg2 = block.get_by_position(arguments[1]);
if (!_check_input_type(arg1.type) || !_check_input_type(arg2.type)) {
diff --git a/be/src/vec/functions/array/function_array_distance_approximate.cpp
b/be/src/vec/functions/array/function_array_distance_approximate.cpp
new file mode 100644
index 00000000000..cd5e4ee2eee
--- /dev/null
+++ b/be/src/vec/functions/array/function_array_distance_approximate.cpp
@@ -0,0 +1,29 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "vec/functions/array/function_array_distance_approximate.h"
+
+#include "vec/functions/simple_function_factory.h"
+
+namespace doris::vectorized {
+
+void register_function_array_distance_approximate(SimpleFunctionFactory&
factory) {
+
factory.register_function<FunctionArrayDistanceApproximate<L2DistanceApproximate>>();
+
factory.register_function<FunctionArrayDistanceApproximate<InnerProductApproximate>>();
+}
+
+} // namespace doris::vectorized
diff --git a/be/src/vec/functions/array/function_array_distance.h
b/be/src/vec/functions/array/function_array_distance_approximate.h
similarity index 83%
copy from be/src/vec/functions/array/function_array_distance.h
copy to be/src/vec/functions/array/function_array_distance_approximate.h
index 8a2d533fad0..5e4415f0243 100644
--- a/be/src/vec/functions/array/function_array_distance.h
+++ b/be/src/vec/functions/array/function_array_distance_approximate.h
@@ -32,29 +32,23 @@
namespace doris::vectorized {
-class L1Distance {
+class L2DistanceApproximate {
public:
- static constexpr auto name = "l1_distance";
+ static constexpr auto name = "l2_distance_approximate";
struct State {
double sum = 0;
+ size_t count = 0;
};
- static void accumulate(State& state, double x, double y) { state.sum +=
fabs(x - y); }
- static double finalize(const State& state) { return state.sum; }
-};
-
-class L2Distance {
-public:
- static constexpr auto name = "l2_distance";
- struct State {
- double sum = 0;
- };
- static void accumulate(State& state, double x, double y) { state.sum += (x
- y) * (x - y); }
- static double finalize(const State& state) { return sqrt(state.sum); }
+ static void accumulate(State& state, double x, double y) {
+ state.sum += (x - y) * (x - y);
+ state.count++;
+ }
+ static double finalize(const State& state) { return sqrt(state.sum /
state.count); }
};
-class InnerProduct {
+class InnerProductApproximate {
public:
- static constexpr auto name = "inner_product";
+ static constexpr auto name = "inner_product_approximate";
struct State {
double sum = 0;
};
@@ -62,30 +56,14 @@ public:
static double finalize(const State& state) { return state.sum; }
};
-class CosineDistance {
-public:
- static constexpr auto name = "cosine_distance";
- struct State {
- double dot_prod = 0;
- double squared_x = 0;
- double squared_y = 0;
- };
- static void accumulate(State& state, double x, double y) {
- state.dot_prod += x * y;
- state.squared_x += x * x;
- state.squared_y += y * y;
- }
- static double finalize(const State& state) {
- return 1 - state.dot_prod / sqrt(state.squared_x * state.squared_y);
- }
-};
-
template <typename DistanceImpl>
-class FunctionArrayDistance : public IFunction {
+class FunctionArrayDistanceApproximate : public IFunction {
public:
static constexpr auto name = DistanceImpl::name;
String get_name() const override { return name; }
- static FunctionPtr create() { return
std::make_shared<FunctionArrayDistance<DistanceImpl>>(); }
+ static FunctionPtr create() {
+ return
std::make_shared<FunctionArrayDistanceApproximate<DistanceImpl>>();
+ }
bool is_variadic() const override { return false; }
size_t get_number_of_arguments() const override { return 2; }
bool use_default_implementation_for_nulls() const override { return false;
}
@@ -96,8 +74,6 @@ public:
Status execute_impl(FunctionContext* context, Block& block, const
ColumnNumbers& arguments,
uint32_t result, size_t input_rows_count) const
override {
- // LOG_INFO("Function {} is executed with {} rows, stack {}",
get_name(), input_rows_count,
- // doris::get_stack_trace());
const auto& arg1 = block.get_by_position(arguments[0]);
const auto& arg2 = block.get_by_position(arguments[1]);
if (!_check_input_type(arg1.type) || !_check_input_type(arg2.type)) {
diff --git a/be/src/vec/functions/array/function_array_register.cpp
b/be/src/vec/functions/array/function_array_register.cpp
index aa92e89128f..b9fceca8a76 100644
--- a/be/src/vec/functions/array/function_array_register.cpp
+++ b/be/src/vec/functions/array/function_array_register.cpp
@@ -28,6 +28,7 @@ void register_function_array_element(SimpleFunctionFactory&);
void register_function_array_index(SimpleFunctionFactory&);
void register_function_array_aggregation(SimpleFunctionFactory&);
void register_function_array_distance(SimpleFunctionFactory&);
+void register_function_array_distance_approximate(SimpleFunctionFactory&);
void register_function_array_distinct(SimpleFunctionFactory&);
void register_function_array_remove(SimpleFunctionFactory&);
void register_function_array_sort(SimpleFunctionFactory&);
@@ -66,6 +67,7 @@ void register_function_array(SimpleFunctionFactory& factory) {
register_function_array_index(factory);
register_function_array_aggregation(factory);
register_function_array_distance(factory);
+ register_function_array_distance_approximate(factory);
register_function_array_distinct(factory);
register_function_array_remove(factory);
register_function_array_sort(factory);
diff --git a/be/test/olap/vector_search/ann_range_search_test.cpp
b/be/test/olap/vector_search/ann_range_search_test.cpp
index 475dd2dad17..f73d1df2606 100644
--- a/be/test/olap/vector_search/ann_range_search_test.cpp
+++ b/be/test/olap/vector_search/ann_range_search_test.cpp
@@ -300,7 +300,7 @@ TDescriptorTable {
20: output_scale (i32) = -1,
26: fn (struct) = TFunction {
01: name (struct) = TFunctionName {
- 02: function_name (string) = "l2_distance",
+ 02: function_name (string) = "l2_distance_approximate",
},
02: binary_type (i32) = 0,
03: arg_types (list) = list<struct>[2] {
@@ -353,7 +353,7 @@ TDescriptorTable {
03: byte_size (i64) = -1,
},
05: has_var_args (bool) = false,
- 07: signature (string) = "l2_distance(array<double>,
array<double>)",
+ 07: signature (string) = "l2_distance_approximate(array<double>,
array<double>)",
09: scalar_fn (struct) = TScalarFunction {
01: symbol (string) = "",
},
@@ -810,7 +810,7 @@ TDescriptorTable {
}
*/
const std::string thrift_table_desc =
-
R"xxx({"1":{"lst":["rec",8,{"1":{"i32":0},"2":{"i32":0},"3":{"rec":{"1":{"lst":["rec",1,{"1":{"i32":0},"2":{"rec":{"1":{"i32":5}}}}]},"3":{"i64":-1}}},"4":{"i32":-1},"5":{"i32":-1},"6":{"i32":0},"7":{"i32":0},"8":{"str":"siteid"},"9":{"i32":0},"10":{"tf":1},"11":{"i32":0},"12":{"tf":1},"13":{"tf":1},"14":{"tf":0},"16":{"str":"10"},"17":{"i32":5}},{"1":{"i32":1},"2":{"i32":0},"3":{"rec":{"1":{"lst":["rec",2,{"1":{"i32":1},"4":{"tf":1},"5":{"lst":["tf",1,1]}},{"1":{"i32":0},"2":{"r
[...]
+
R"xxx({"1":{"lst":["rec",8,{"1":{"i32":0},"2":{"i32":0},"3":{"rec":{"1":{"lst":["rec",1,{"1":{"i32":0},"2":{"rec":{"1":{"i32":5}}}}]},"3":{"i64":-1}}},"4":{"i32":-1},"5":{"i32":-1},"6":{"i32":0},"7":{"i32":0},"8":{"str":"siteid"},"9":{"i32":0},"10":{"tf":1},"11":{"i32":0},"12":{"tf":1},"13":{"tf":1},"14":{"tf":0},"16":{"str":"10"},"17":{"i32":5}},{"1":{"i32":1},"2":{"i32":0},"3":{"rec":{"1":{"lst":["rec",2,{"1":{"i32":1},"4":{"tf":1},"5":{"lst":["tf",1,1]}},{"1":{"i32":0},"2":{"r
[...]
TEST_F(VectorSearchTest, TestPrepareAnnRangeSearch) {
TExpr texpr = read_from_json<TExpr>(ann_range_search_thrift);
diff --git a/be/test/olap/vector_search/ann_topn_descriptor_test.cpp
b/be/test/olap/vector_search/ann_topn_descriptor_test.cpp
index e6bc46c5d97..a562c7d7686 100644
--- a/be/test/olap/vector_search/ann_topn_descriptor_test.cpp
+++ b/be/test/olap/vector_search/ann_topn_descriptor_test.cpp
@@ -26,7 +26,7 @@
#include <iostream>
#include <memory>
-#include "vec/exprs/vann_topn_predicate.h"
+#include "vec/exprs/ann_topn_runtime.h"
#include "vec/exprs/virtual_slot_ref.h"
#include "vector_search_utils.h"
@@ -36,7 +36,7 @@ using ::testing::Return;
namespace doris::vectorized {
-TEST_F(VectorSearchTest, AnnTopNDescriptorConstructor) {
+TEST_F(VectorSearchTest, AnnTopNRuntimeConstructor) {
int limit = 10;
std::shared_ptr<VExprContext> distanc_calcu_fn_call_ctx;
auto distance_function_call_thrift =
read_from_json<TExpr>(_distance_function_call_thrift);
@@ -59,17 +59,17 @@ TEST_F(VectorSearchTest, AnnTopNDescriptorConstructor) {
std::shared_ptr<VirtualSlotRef> v =
std::dynamic_pointer_cast<VirtualSlotRef>(virtual_slot_expr_ctx->root());
if (v == nullptr) {
- LOG(FATAL) << "VAnnTopNDescriptor::SetUp() failed";
+ LOG(FATAL) << "VAnnTopNRuntime::SetUp() failed";
}
v->set_virtual_column_expr(distanc_calcu_fn_call_ctx->root());
- std::shared_ptr<AnnTopNDescriptor> predicate;
- predicate = AnnTopNDescriptor::create_shared(true, limit,
virtual_slot_expr_ctx);
- ASSERT_TRUE(predicate != nullptr) <<
"AnnTopNDescriptor::create_shared(true,) failed";
+ std::shared_ptr<AnnTopNRuntime> predicate;
+ predicate = AnnTopNRuntime::create_shared(true, limit,
virtual_slot_expr_ctx);
+ ASSERT_TRUE(predicate != nullptr) << "AnnTopNRuntime::create_shared(true,)
failed";
}
-TEST_F(VectorSearchTest, AnnTopNDescriptorPrepare) {
+TEST_F(VectorSearchTest, AnnTopNRuntimePrepare) {
int limit = 10;
std::shared_ptr<VExprContext> distanc_calcu_fn_call_ctx;
auto distance_function_call_thrift =
read_from_json<TExpr>(_distance_function_call_thrift);
@@ -81,12 +81,12 @@ TEST_F(VectorSearchTest, AnnTopNDescriptorPrepare) {
std::shared_ptr<VirtualSlotRef> v =
std::dynamic_pointer_cast<VirtualSlotRef>(virtual_slot_expr_ctx->root());
if (v == nullptr) {
- LOG(FATAL) << "VAnnTopNDescriptor::SetUp() failed";
+ LOG(FATAL) << "VAnnTopNRuntime::SetUp() failed";
}
v->set_virtual_column_expr(distanc_calcu_fn_call_ctx->root());
- std::shared_ptr<AnnTopNDescriptor> predicate;
- predicate = AnnTopNDescriptor::create_shared(true, limit,
virtual_slot_expr_ctx);
+ std::shared_ptr<AnnTopNRuntime> predicate;
+ predicate = AnnTopNRuntime::create_shared(true, limit,
virtual_slot_expr_ctx);
st = predicate->prepare(&_runtime_state, _row_desc);
ASSERT_TRUE(st.ok()) << fmt::format("st: {}, expr {}", st.to_string(),
predicate->get_order_by_expr_ctx()->root()->debug_string());
@@ -94,7 +94,7 @@ TEST_F(VectorSearchTest, AnnTopNDescriptorPrepare) {
std::cout << "predicate: " << predicate->debug_string() << std::endl;
}
-TEST_F(VectorSearchTest, AnnTopNDescriptorEvaluateTopN) {
+TEST_F(VectorSearchTest, AnnTopNRuntimeEvaluateTopN) {
int limit = 10;
std::shared_ptr<VExprContext> distanc_calcu_fn_call_ctx;
auto distance_function_call_thrift =
read_from_json<TExpr>(_distance_function_call_thrift);
@@ -106,12 +106,12 @@ TEST_F(VectorSearchTest, AnnTopNDescriptorEvaluateTopN) {
std::shared_ptr<VirtualSlotRef> v =
std::dynamic_pointer_cast<VirtualSlotRef>(virtual_slot_expr_ctx->root());
if (v == nullptr) {
- LOG(FATAL) << "VAnnTopNDescriptor::SetUp() failed";
+ LOG(FATAL) << "VAnnTopNRuntime::SetUp() failed";
}
v->set_virtual_column_expr(distanc_calcu_fn_call_ctx->root());
- std::shared_ptr<AnnTopNDescriptor> predicate;
- predicate = AnnTopNDescriptor::create_shared(true, limit,
virtual_slot_expr_ctx);
+ std::shared_ptr<AnnTopNRuntime> predicate;
+ predicate = AnnTopNRuntime::create_shared(true, limit,
virtual_slot_expr_ctx);
st = predicate->prepare(&_runtime_state, _row_desc);
ASSERT_TRUE(st.ok()) << fmt::format("st: {}, expr {}", st.to_string(),
predicate->get_order_by_expr_ctx()->root()->debug_string());
diff --git a/be/test/olap/vector_search/faiss_vector_index_test.cpp
b/be/test/olap/vector_search/faiss_vector_index_test.cpp
index 684c3bf6cf4..7c70e174f0a 100644
--- a/be/test/olap/vector_search/faiss_vector_index_test.cpp
+++ b/be/test/olap/vector_search/faiss_vector_index_test.cpp
@@ -604,7 +604,7 @@ TEST_F(VectorSearchTest, RangeSearchEmptyResult) {
vector_search_utils::create_doris_index(vector_search_utils::IndexType::HNSW,
d, m);
std::vector<float> vectors;
- // Create 1000 vectors and make sure their l2_distance with
[1,2,3,4,5,6,7,8,9,10] is less than 100
+ // Create 1000 vectors and make sure their l2_distance_approximate
with [1,2,3,4,5,6,7,8,9,10] is less than 100
// [1,2,3,4,5,6,7,8,9,10]
// [2,3,4,5,6,7,8,9,10,1]
// [3,4,5,6,7,8,9,10,1,2]
diff --git a/be/test/olap/vector_search/vector_search_utils.h
b/be/test/olap/vector_search/vector_search_utils.h
index bd4b02ad0a7..a7e3d5fdeaf 100644
--- a/be/test/olap/vector_search/vector_search_utils.h
+++ b/be/test/olap/vector_search/vector_search_utils.h
@@ -264,7 +264,7 @@ private:
[0] TExprNode {
num_children = 2
fn = TFunctionName {
- name = "l2_distance"
+ name = "l2_distance_approximate"
}
},
[1] TExprNode {
@@ -325,6 +325,6 @@ private:
},
*/
const std::string _distance_function_call_thrift =
-
R"xxx({"1":{"lst":["rec",12,{"1":{"i32":20},"2":{"rec":{"1":{"lst":["rec",1,{"1":{"i32":0},"2":{"rec":{"1":{"i32":8}}}}]},"3":{"i64":-1}}},"4":{"i32":2},"20":{"i32":-1},"26":{"rec":{"1":{"rec":{"2":{"str":"l2_distance"}}},"2":{"i32":0},"3":{"lst":["rec",2,{"1":{"lst":["rec",2,{"1":{"i32":1},"4":{"tf":1},"5":{"lst":["tf",1,1]}},{"1":{"i32":0},"2":{"rec":{"1":{"i32":8}}}}]},"3":{"i64":-1}},{"1":{"lst":["rec",2,{"1":{"i32":1},"4":{"tf":1},"5":{"lst":["tf",1,1]}},{"1":{"i32":0},"
[...]
+
R"xxx({"1":{"lst":["rec",12,{"1":{"i32":20},"2":{"rec":{"1":{"lst":["rec",1,{"1":{"i32":0},"2":{"rec":{"1":{"i32":8}}}}]},"3":{"i64":-1}}},"4":{"i32":2},"20":{"i32":-1},"26":{"rec":{"1":{"rec":{"2":{"str":"l2_distance_approximate"}}},"2":{"i32":0},"3":{"lst":["rec",2,{"1":{"lst":["rec",2,{"1":{"i32":1},"4":{"tf":1},"5":{"lst":["tf",1,1]}},{"1":{"i32":0},"2":{"rec":{"1":{"i32":8}}}}]},"3":{"i64":-1}},{"1":{"lst":["rec",2,{"1":{"i32":1},"4":{"tf":1},"5":{"lst":["tf",1,1]}},{"1"
[...]
};
} // namespace doris::vectorized
\ No newline at end of file
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java
b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java
index eb053c99531..24bf05c2ec5 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java
@@ -222,6 +222,7 @@ import
org.apache.doris.nereids.trees.expressions.functions.scalar.If;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Ignore;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Initcap;
import
org.apache.doris.nereids.trees.expressions.functions.scalar.InnerProduct;
+import
org.apache.doris.nereids.trees.expressions.functions.scalar.InnerProductApproximate;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Instr;
import org.apache.doris.nereids.trees.expressions.functions.scalar.InttoUuid;
import
org.apache.doris.nereids.trees.expressions.functions.scalar.Ipv4CIDRToRange;
@@ -277,6 +278,7 @@ import
org.apache.doris.nereids.trees.expressions.functions.scalar.JsonbParseNul
import org.apache.doris.nereids.trees.expressions.functions.scalar.JsonbType;
import org.apache.doris.nereids.trees.expressions.functions.scalar.JsonbValid;
import org.apache.doris.nereids.trees.expressions.functions.scalar.L1Distance;
+import
org.apache.doris.nereids.trees.expressions.functions.scalar.L2DistanceApproximate;
import org.apache.doris.nereids.trees.expressions.functions.scalar.L2Distance;
import org.apache.doris.nereids.trees.expressions.functions.scalar.LastDay;
import org.apache.doris.nereids.trees.expressions.functions.scalar.LastQueryId;
@@ -705,6 +707,7 @@ public class BuiltinScalarFunctions implements
FunctionHelper {
scalar(If.class, "if"),
scalar(Ignore.class, "ignore"),
scalar(Initcap.class, "initcap"),
+ scalar(InnerProductApproximate.class, "inner_product_approximate"),
scalar(InnerProduct.class, "inner_product"),
scalar(Instr.class, "instr"),
scalar(InttoUuid.class, "int_to_uuid"),
@@ -782,6 +785,7 @@ public class BuiltinScalarFunctions implements
FunctionHelper {
scalar(JsonContains.class, "json_contains"),
scalar(JsonKeys.class, "json_keys", "jsonb_keys"),
scalar(L1Distance.class, "l1_distance"),
+ scalar(L2DistanceApproximate.class, "l2_distance_approximate"),
scalar(L2Distance.class, "l2_distance"),
scalar(LastDay.class, "last_day"),
scalar(Least.class, "least"),
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownVectorTopNIntoOlapScan.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownVectorTopNIntoOlapScan.java
index ac5dc07aab4..7573eb8ade2 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownVectorTopNIntoOlapScan.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownVectorTopNIntoOlapScan.java
@@ -28,7 +28,7 @@ import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
-import org.apache.doris.nereids.trees.expressions.functions.scalar.L2Distance;
+import
org.apache.doris.nereids.trees.expressions.functions.scalar.L2DistanceApproximate;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
@@ -88,15 +88,15 @@ public class PushDownVectorTopNIntoOlapScan implements
RewriteRuleFactory {
if (orderKeyExpr == null) {
return null;
}
- if (!(orderKeyExpr instanceof L2Distance)) {
+ if (!(orderKeyExpr instanceof L2DistanceApproximate)) {
return null;
}
- L2Distance l2Distance = (L2Distance) orderKeyExpr;
- Expression left = l2Distance.left();
+ L2DistanceApproximate l2DistanceApproximate = (L2DistanceApproximate)
orderKeyExpr;
+ Expression left = l2DistanceApproximate.left();
while (left instanceof Cast) {
left = ((Cast) left).child();
}
- if (!(left instanceof SlotReference && ((L2Distance)
orderKeyExpr).right().isConstant())) {
+ if (!(left instanceof SlotReference && ((L2DistanceApproximate)
orderKeyExpr).right().isConstant())) {
return null;
}
SlotReference leftInput = (SlotReference) left;
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownVirtualColumnsIntoOlapScan.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownVirtualColumnsIntoOlapScan.java
index 1bfd71e920a..0d25d8ba508 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownVirtualColumnsIntoOlapScan.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownVirtualColumnsIntoOlapScan.java
@@ -22,7 +22,8 @@ import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
-import org.apache.doris.nereids.trees.expressions.functions.scalar.L2Distance;
+import
org.apache.doris.nereids.trees.expressions.functions.scalar.InnerProductApproximate;
+import
org.apache.doris.nereids.trees.expressions.functions.scalar.L2DistanceApproximate;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
@@ -71,10 +72,8 @@ public class PushDownVirtualColumnsIntoOlapScan implements
RewriteRuleFactory {
Map<Expression, Expression> replaceMap = Maps.newHashMap();
ImmutableList.Builder<NamedExpression> virtualColumnsBuilder =
ImmutableList.builder();
for (Expression conjunct : filter.getConjuncts()) {
- // Set<Expression> l2Distances =
conjunct.collect(L2Distance.class::isInstance);
- // Set<Expression> innerProducts =
conjunct.collect(InnerProduct.class::isInstance);
Set<Expression> distanceFunctions = conjunct.collect(
- e -> e instanceof L2Distance || e instanceof InnerProduct);
+ e -> e instanceof L2DistanceApproximate || e instanceof
InnerProductApproximate);
for (Expression distanceFunction : distanceFunctions) {
if (replaceMap.containsKey(distanceFunction)) {
continue;
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/InnerProductApproximate.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/InnerProductApproximate.java
new file mode 100644
index 00000000000..bce8e038e78
--- /dev/null
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/InnerProductApproximate.java
@@ -0,0 +1,71 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.trees.expressions.functions.scalar;
+
+import org.apache.doris.catalog.FunctionSignature;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable;
+import org.apache.doris.nereids.trees.expressions.functions.ComputePrecisionForArrayItemAgg;
+import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
+import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression;
+import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
+import org.apache.doris.nereids.types.ArrayType;
+import org.apache.doris.nereids.types.DoubleType;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
+
+import java.util.List;
+
+/**
+ * inner_product_approximate function: takes two array-of-double arguments and returns a double.
+ */
+public class InnerProductApproximate extends ScalarFunction implements ExplicitlyCastableSignature,
+        ComputePrecisionForArrayItemAgg, BinaryExpression, AlwaysNullable {
+
+    public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
+            FunctionSignature.ret(DoubleType.INSTANCE)
+                    .args(ArrayType.of(DoubleType.INSTANCE), ArrayType.of(DoubleType.INSTANCE))
+    );
+
+    /**
+     * constructor with 2 arguments.
+     */
+    public InnerProductApproximate(Expression arg0, Expression arg1) {
+        super("inner_product_approximate", arg0, arg1);
+    }
+
+    /**
+     * withChildren: rebuilds this function from exactly two children (checked by the precondition).
+     */
+    @Override
+    public InnerProductApproximate withChildren(List<Expression> children) {
+        Preconditions.checkArgument(children.size() == 2);
+        return new InnerProductApproximate(children.get(0), children.get(1));
+    }
+
+    @Override
+    public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
+        return visitor.visitInnerProductApproximate(this, context);
+    }
+
+    @Override
+    public List<FunctionSignature> getSignatures() {
+        return SIGNATURES;
+    }
+}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/L2DistanceApproximate.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/L2DistanceApproximate.java
new file mode 100644
index 00000000000..dff59b15580
--- /dev/null
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/L2DistanceApproximate.java
@@ -0,0 +1,71 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.trees.expressions.functions.scalar;
+
+import org.apache.doris.catalog.FunctionSignature;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable;
+import
org.apache.doris.nereids.trees.expressions.functions.ComputePrecisionForArrayItemAgg;
+import
org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
+import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression;
+import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
+import org.apache.doris.nereids.types.ArrayType;
+import org.apache.doris.nereids.types.DoubleType;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
+
+import java.util.List;
+
+/**
+ * l2_distance_approximate function: takes two array-of-double arguments and returns a double.
+ */
+public class L2DistanceApproximate extends ScalarFunction implements
ExplicitlyCastableSignature,
+ ComputePrecisionForArrayItemAgg, BinaryExpression, AlwaysNullable {
+
+ public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
+ FunctionSignature.ret(DoubleType.INSTANCE)
+ .args(ArrayType.of(DoubleType.INSTANCE),
ArrayType.of(DoubleType.INSTANCE))
+ );
+
+ /**
+ * constructor with 2 arguments.
+ */
+ public L2DistanceApproximate(Expression arg0, Expression arg1) {
+ super("l2_distance_approximate", arg0, arg1);
+ }
+
+ /**
+ * withChildren: rebuilds this function from exactly two children (checked by the precondition).
+ */
+ @Override
+ public L2DistanceApproximate withChildren(List<Expression> children) {
+ Preconditions.checkArgument(children.size() == 2);
+ return new L2DistanceApproximate(children.get(0), children.get(1));
+ }
+
+ @Override
+ public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
+ return visitor.visitL2DistanceApproximate(this, context);
+ }
+
+ @Override
+ public List<FunctionSignature> getSignatures() {
+ return SIGNATURES;
+ }
+}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java
index df18f9d7e62..b0d3694f5ea 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java
@@ -225,6 +225,7 @@ import
org.apache.doris.nereids.trees.expressions.functions.scalar.HoursSub;
import org.apache.doris.nereids.trees.expressions.functions.scalar.If;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Ignore;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Initcap;
+import
org.apache.doris.nereids.trees.expressions.functions.scalar.InnerProductApproximate;
import
org.apache.doris.nereids.trees.expressions.functions.scalar.InnerProduct;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Instr;
import org.apache.doris.nereids.trees.expressions.functions.scalar.InttoUuid;
@@ -281,6 +282,7 @@ import
org.apache.doris.nereids.trees.expressions.functions.scalar.JsonbParseNul
import org.apache.doris.nereids.trees.expressions.functions.scalar.JsonbType;
import org.apache.doris.nereids.trees.expressions.functions.scalar.JsonbValid;
import org.apache.doris.nereids.trees.expressions.functions.scalar.L1Distance;
+import
org.apache.doris.nereids.trees.expressions.functions.scalar.L2DistanceApproximate;
import org.apache.doris.nereids.trees.expressions.functions.scalar.L2Distance;
import org.apache.doris.nereids.trees.expressions.functions.scalar.LastDay;
import org.apache.doris.nereids.trees.expressions.functions.scalar.LastQueryId;
@@ -1351,6 +1353,10 @@ public interface ScalarFunctionVisitor<R, C> {
return visitScalarFunction(innerProduct, context);
}
+ default R visitInnerProductApproximate(InnerProductApproximate
innerProductApproximate, C context) {
+ return visitScalarFunction(innerProductApproximate, context);
+ }
+
default R visitInstr(Instr instr, C context) {
return visitScalarFunction(instr, context);
}
@@ -1571,6 +1577,10 @@ public interface ScalarFunctionVisitor<R, C> {
return visitScalarFunction(l2Distance, context);
}
+ default R visitL2DistanceApproximate(L2DistanceApproximate
l2DistanceApproximate, C context) {
+ return visitScalarFunction(l2DistanceApproximate, context);
+ }
+
default R visitLastDay(LastDay lastDay, C context) {
return visitScalarFunction(lastDay, context);
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]