This is an automated email from the ASF dual-hosted git repository.
yiguolei pushed a commit to branch vector-index-dev
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/vector-index-dev by this push:
new 1b671c9b06b [vector search] Step forward on stability and
functionality (#51213)
1b671c9b06b is described below
commit 1b671c9b06bf965fd01a152340202ffd85890975
Author: zhiqiang <[email protected]>
AuthorDate: Tue May 27 11:38:16 2025 +0800
[vector search] Step forward on stability and functionality (#51213)
A huge step forward on stability and functionality.
### Functionality
1. Search parameters such as `ef_search` can be passed to the index as
session variables. This behavior is the same as in pgvector and the DuckDB
vector search plug-in.
2. Correct processing for order by desc. Fall back to brute-force search
when necessary.
3. Support using inner product as index metric and order by
inner_product.
4. When the metric used in the SQL does not match the index metric, fall back to brute force.
### Stability
1. More unit tests.
2. Virtual column iterator.
3. According to a custom script, the results of range search, topn search &
compound search are almost the same as those of native faiss. The overlap
rate of the results is more than 90%. The remaining difference of up to 10%
is introduced by the batch insert mode of native faiss.
---
be/src/olap/rowset/beta_rowset_reader.cpp | 7 +
.../olap/rowset/segment_v2/ann_index_iterator.cpp | 2 +-
be/src/olap/rowset/segment_v2/ann_index_iterator.h | 8 +-
be/src/olap/rowset/segment_v2/ann_index_reader.cpp | 31 ++-
be/src/olap/rowset/segment_v2/ann_index_reader.h | 9 +-
be/src/olap/rowset/segment_v2/ann_index_writer.cpp | 32 +--
be/src/olap/rowset/segment_v2/ann_index_writer.h | 1 +
be/src/olap/rowset/segment_v2/segment_iterator.cpp | 234 +++++++++++++++------
be/src/olap/rowset/segment_v2/segment_iterator.h | 9 +-
.../rowset/segment_v2/virtual_column_iterator.cpp | 44 +++-
.../rowset/segment_v2/virtual_column_iterator.h | 7 +-
be/src/pipeline/exec/olap_scan_operator.cpp | 6 +-
be/src/pipeline/exec/operator.cpp | 8 +-
be/src/runtime/descriptors.cpp | 20 ++
be/src/runtime/runtime_state.h | 8 +-
be/src/vec/core/block.cpp | 1 +
be/src/vec/exec/scan/olap_scanner.cpp | 4 +-
be/src/vec/exec/scan/olap_scanner.h | 3 +
be/src/vec/exprs/ann_range_search_params.h | 21 +-
be/src/vec/exprs/vann_topn_predicate.cpp | 17 +-
be/src/vec/exprs/vann_topn_predicate.h | 11 +-
be/src/vec/exprs/vectorized_fn_call.cpp | 42 +++-
be/src/vec/exprs/vectorized_fn_call.h | 6 +-
be/src/vec/exprs/vexpr.cpp | 4 +-
be/src/vec/exprs/vexpr.h | 2 +-
be/src/vec/exprs/vexpr_context.cpp | 4 +-
be/src/vec/exprs/vexpr_context.h | 3 +-
be/src/vec/exprs/virtual_slot_ref.cpp | 2 +-
.../vec/functions/array/function_array_distance.h | 3 +
be/src/vec/runtime/vector_search_user_params.cpp | 35 +++
be/src/vec/runtime/vector_search_user_params.h | 31 +++
be/src/vector/faiss_vector_index.cpp | 54 ++++-
be/src/vector/faiss_vector_index.h | 24 ++-
be/src/vector/vector_index.h | 19 +-
.../olap/vector_search/ann_index_reader_test.cpp | 98 +++++++--
.../olap/vector_search/ann_range_search_test.cpp | 39 ++--
.../vector_search/ann_topn_descriptor_test.cpp | 8 +-
.../olap/vector_search/faiss_vector_index_test.cpp | 26 +--
be/test/olap/vector_search/vector_search_utils.cpp | 11 +
be/test/olap/vector_search/vector_search_utils.h | 37 +---
.../vector_search/virtual_column_iterator_test.cpp | 76 ++++++-
.../PushDownVirtualColumnsIntoOlapScan.java | 19 +-
.../trees/plans/commands/info/IndexDefinition.java | 21 +-
.../java/org/apache/doris/qe/SessionVariable.java | 25 +++
gensrc/thrift/PaloInternalService.thrift | 4 +
45 files changed, 807 insertions(+), 269 deletions(-)
diff --git a/be/src/olap/rowset/beta_rowset_reader.cpp
b/be/src/olap/rowset/beta_rowset_reader.cpp
index 66a44e7864e..e12c89d056f 100644
--- a/be/src/olap/rowset/beta_rowset_reader.cpp
+++ b/be/src/olap/rowset/beta_rowset_reader.cpp
@@ -146,6 +146,9 @@ Status
BetaRowsetReader::get_segment_iterators(RowsetReaderContext* read_context
_read_options.column_predicates.insert(_read_options.column_predicates.end(),
_read_context->predicates->begin(),
_read_context->predicates->end());
+ LOG_INFO("Rowset reader, read options column predicates size: {}",
+ _read_options.column_predicates.size());
+
for (auto pred : *(_read_context->predicates)) {
if (_read_options.col_id_to_predicates.count(pred->column_id()) <
1) {
_read_options.col_id_to_predicates.insert(
@@ -185,6 +188,10 @@ Status
BetaRowsetReader::get_segment_iterators(RowsetReaderContext* read_context
_read_options.column_predicates.insert(_read_options.column_predicates.end(),
_read_context->value_predicates->begin(),
_read_context->value_predicates->end());
+ LOG_INFO(
+ "Rowset reader, read options add value predicates, column
predicates size now: "
+ "{}",
+ _read_options.column_predicates.size());
for (auto pred : *(_read_context->value_predicates)) {
if
(_read_options.col_id_to_predicates.count(pred->column_id()) < 1) {
_read_options.col_id_to_predicates.insert(
diff --git a/be/src/olap/rowset/segment_v2/ann_index_iterator.cpp
b/be/src/olap/rowset/segment_v2/ann_index_iterator.cpp
index 6a50032e2fb..3b37e3cabcb 100644
--- a/be/src/olap/rowset/segment_v2/ann_index_iterator.cpp
+++ b/be/src/olap/rowset/segment_v2/ann_index_iterator.cpp
@@ -37,7 +37,7 @@ Status AnnIndexIterator::read_from_index(const IndexParam&
param) {
}
Status AnnIndexIterator::range_search(const RangeSearchParams& params,
- const CustomSearchParams& custom_params,
+ const VectorSearchUserParams&
custom_params,
RangeSearchResult* result) {
if (_ann_reader == nullptr) {
return Status::Error<ErrorCode::INDEX_INVALID_PARAMETERS>("_ann_reader
is null");
diff --git a/be/src/olap/rowset/segment_v2/ann_index_iterator.h
b/be/src/olap/rowset/segment_v2/ann_index_iterator.h
index 0972c69307e..82a4113cacb 100644
--- a/be/src/olap/rowset/segment_v2/ann_index_iterator.h
+++ b/be/src/olap/rowset/segment_v2/ann_index_iterator.h
@@ -23,6 +23,7 @@
#include "gutil/integral_types.h"
#include "olap/rowset/segment_v2/ann_index_reader.h"
#include "olap/rowset/segment_v2/index_iterator.h"
+#include "runtime/runtime_state.h"
namespace doris::segment_v2 {
@@ -30,6 +31,7 @@ struct AnnIndexParam {
const float* query_value;
const size_t query_value_size;
size_t limit;
+ doris::VectorSearchUserParams _user_params;
roaring::Roaring* roaring;
std::unique_ptr<std::vector<float>> distance = nullptr;
std::unique_ptr<std::vector<uint64_t>> row_ids = nullptr;
@@ -48,10 +50,6 @@ struct RangeSearchParams {
virtual ~RangeSearchParams() = default;
};
-struct CustomSearchParams {
- int ef_search = 16;
-};
-
struct RangeSearchResult {
std::shared_ptr<roaring::Roaring> roaring;
std::unique_ptr<std::vector<uint64_t>> row_ids;
@@ -80,7 +78,7 @@ public:
bool has_null() override { return true; }
MOCK_FUNCTION Status range_search(const RangeSearchParams& params,
- const CustomSearchParams& custom_params,
+ const VectorSearchUserParams&
custom_params,
RangeSearchResult* result);
private:
diff --git a/be/src/olap/rowset/segment_v2/ann_index_reader.cpp
b/be/src/olap/rowset/segment_v2/ann_index_reader.cpp
index 0222597ea32..64637a72566 100644
--- a/be/src/olap/rowset/segment_v2/ann_index_reader.cpp
+++ b/be/src/olap/rowset/segment_v2/ann_index_reader.cpp
@@ -24,6 +24,7 @@
#include "common/config.h"
#include "olap/rowset/segment_v2/index_file_reader.h"
#include "olap/rowset/segment_v2/inverted_index_compound_reader.h"
+#include "runtime/runtime_state.h"
#include "vector/faiss_vector_index.h"
#include "vector/vector_index.h"
@@ -49,6 +50,9 @@ AnnIndexReader::AnnIndexReader(const TabletIndex* index_meta,
auto it = index_properties.find("index_type");
DCHECK(it != index_properties.end());
_index_type = it->second;
+ it = index_properties.find("metric_type");
+ DCHECK(it != index_properties.end());
+ _metric_type = VectorIndex::string_to_metric(it->second);
}
Status AnnIndexReader::new_iterator(const io::IOContext& io_ctx,
OlapReaderStatistics* stats,
@@ -71,16 +75,27 @@ Status AnnIndexReader::load_index(io::IOContext* io_ctx) {
}
Status AnnIndexReader::query(io::IOContext* io_ctx, AnnIndexParam* param) {
+#ifndef BE_TEST
RETURN_IF_ERROR(_index_file_reader->init(config::inverted_index_read_buffer_size,
io_ctx));
RETURN_IF_ERROR(load_index(io_ctx));
+#endif
DCHECK(_vector_index != nullptr);
const float* query_vec = param->query_value;
const int limit = param->limit;
- IndexSearchParameters index_search_params;
IndexSearchResult index_search_result;
- index_search_params.roaring = param->roaring;
- RETURN_IF_ERROR(_vector_index->ann_topn_search(query_vec, limit,
index_search_params,
- index_search_result));
+ if (_index_type == "hnsw") {
+ HNSWSearchParameters hnsw_search_params;
+ hnsw_search_params.roaring = param->roaring;
+ hnsw_search_params.ef_search = param->_user_params.hnsw_ef_search;
+ hnsw_search_params.check_relative_distance =
+ param->_user_params.hnsw_check_relative_distance;
+ hnsw_search_params.bounded_queue =
param->_user_params.hnsw_bounded_queue;
+ RETURN_IF_ERROR(_vector_index->ann_topn_search(query_vec, limit,
hnsw_search_params,
+ index_search_result));
+ } else {
+ throw Status::NotSupported("Unsupported index type: {}", _index_type);
+ }
+
DCHECK(index_search_result.roaring != nullptr);
DCHECK(index_search_result.distances != nullptr);
DCHECK(index_search_result.row_ids != nullptr);
@@ -92,17 +107,21 @@ Status AnnIndexReader::query(io::IOContext* io_ctx,
AnnIndexParam* param) {
}
Status AnnIndexReader::range_search(const RangeSearchParams& params,
- const CustomSearchParams& custom_params,
+ const VectorSearchUserParams&
custom_params,
RangeSearchResult* result, io::IOContext*
io_ctx) {
+#ifndef BE_TEST
RETURN_IF_ERROR(_index_file_reader->init(config::inverted_index_read_buffer_size,
io_ctx));
RETURN_IF_ERROR(load_index(io_ctx));
+#endif
DCHECK(_vector_index != nullptr);
IndexSearchResult search_result;
std::unique_ptr<IndexSearchParameters> search_param = nullptr;
if (_index_type == "hnsw") {
auto hnsw_param = std::make_unique<HNSWSearchParameters>();
- hnsw_param->ef_search = custom_params.ef_search;
+ hnsw_param->ef_search = custom_params.hnsw_ef_search;
+ hnsw_param->check_relative_distance =
custom_params.hnsw_check_relative_distance;
+ hnsw_param->bounded_queue = custom_params.hnsw_bounded_queue;
search_param = std::move(hnsw_param);
} else {
throw Status::NotSupported("Unsupported index type: {}", _index_type);
diff --git a/be/src/olap/rowset/segment_v2/ann_index_reader.h
b/be/src/olap/rowset/segment_v2/ann_index_reader.h
index 69ea9c91c0b..a12c0a508e4 100644
--- a/be/src/olap/rowset/segment_v2/ann_index_reader.h
+++ b/be/src/olap/rowset/segment_v2/ann_index_reader.h
@@ -19,13 +19,13 @@
#include "olap/rowset/segment_v2/index_reader.h"
#include "olap/tablet_schema.h"
+#include "runtime/runtime_state.h"
#include "vector/vector_index.h"
namespace doris::segment_v2 {
struct AnnIndexParam;
struct RangeSearchParams;
-struct CustomSearchParams;
struct RangeSearchResult;
class IndexFileReader;
@@ -44,14 +44,16 @@ public:
Status query(io::IOContext* io_ctx, AnnIndexParam* param);
- Status range_search(const RangeSearchParams& params, const
CustomSearchParams& custom_params,
- RangeSearchResult* result, io::IOContext* io_ctx =
nullptr);
+ Status range_search(const RangeSearchParams& params,
+ const VectorSearchUserParams& custom_params,
RangeSearchResult* result,
+ io::IOContext* io_ctx = nullptr);
uint64_t get_index_id() const override { return _index_meta.index_id(); }
Status new_iterator(const io::IOContext& io_ctx, OlapReaderStatistics*
stats,
RuntimeState* runtime_state,
std::unique_ptr<IndexIterator>* iterator) override;
+ VectorIndex::Metric get_metric_type() const { return _metric_type; }
private:
TabletIndex _index_meta;
@@ -59,6 +61,7 @@ private:
std::unique_ptr<VectorIndex> _vector_index;
// TODO: Use integer.
std::string _index_type;
+ VectorIndex::Metric _metric_type;
};
using AnnIndexReaderPtr = std::shared_ptr<AnnIndexReader>;
diff --git a/be/src/olap/rowset/segment_v2/ann_index_writer.cpp
b/be/src/olap/rowset/segment_v2/ann_index_writer.cpp
index 452d9dfc9a2..0f8c6cd1d7c 100644
--- a/be/src/olap/rowset/segment_v2/ann_index_writer.cpp
+++ b/be/src/olap/rowset/segment_v2/ann_index_writer.cpp
@@ -19,6 +19,7 @@
#include <cstddef>
#include <memory>
+#include <string>
#include "olap/rowset/segment_v2/inverted_index_fs_directory.h"
@@ -57,21 +58,22 @@ Status AnnIndexColumnWriter::init() {
_vector_index = nullptr;
const auto& properties = _index_meta->properties();
- std::string index_type = get_or_default(properties, INDEX_TYPE, "");
- if (index_type == "hnsw") {
- std::shared_ptr<FaissVectorIndex> faiss_index =
std::make_shared<FaissVectorIndex>();
- FaissBuildParameter builderParameter;
- builderParameter.index_type =
FaissBuildParameter::string_to_index_type("hnsw");
- builderParameter.d = std::stoi(get_or_default(properties, DIM, "512"));
- builderParameter.m = std::stoi(get_or_default(properties, MAX_DEGREE,
"32"));
- builderParameter.quantilizer =
FaissBuildParameter::string_to_quantilizer(
- get_or_default(properties, QUANTILIZER, "flat"));
- faiss_index->set_build_params(builderParameter);
- _vector_index = faiss_index;
- } else {
- return Status::NotSupported("Unsupported index type: " + index_type);
- }
-
+ const std::string index_type = get_or_default(properties, INDEX_TYPE,
"hnsw");
+ const std::string metric_type = get_or_default(properties, METRIC_TYPE,
"l2");
+ const std::string quantilizer = get_or_default(properties, QUANTILIZER,
"flat");
+ FaissBuildParameter builderParameter;
+ std::shared_ptr<FaissVectorIndex> faiss_index =
std::make_shared<FaissVectorIndex>();
+ builderParameter.index_type =
FaissBuildParameter::string_to_index_type(index_type);
+ builderParameter.d = std::stoi(get_or_default(properties, DIM, "512"));
+ builderParameter.m = std::stoi(get_or_default(properties, MAX_DEGREE,
"32"));
+ builderParameter.pq_m = std::stoi(get_or_default(properties, PQ_M, "-1"));
// -1 means not set
+
+ builderParameter.metric_type =
FaissBuildParameter::string_to_metric_type(metric_type);
+ builderParameter.quantilizer =
FaissBuildParameter::string_to_quantilizer(quantilizer);
+
+ faiss_index->set_build_params(builderParameter);
+
+ _vector_index = faiss_index;
return Status::OK();
}
diff --git a/be/src/olap/rowset/segment_v2/ann_index_writer.h
b/be/src/olap/rowset/segment_v2/ann_index_writer.h
index cb8f9316fc9..d674fb12648 100644
--- a/be/src/olap/rowset/segment_v2/ann_index_writer.h
+++ b/be/src/olap/rowset/segment_v2/ann_index_writer.h
@@ -50,6 +50,7 @@ public:
static constexpr const char* INDEX_TYPE = "index_type";
static constexpr const char* METRIC_TYPE = "metric_type";
static constexpr const char* QUANTILIZER = "quantilizer";
+ static constexpr const char* PQ_M = "pq_m";
static constexpr const char* DIM = "dim";
static constexpr const char* MAX_DEGREE = "max_degree";
diff --git a/be/src/olap/rowset/segment_v2/segment_iterator.cpp
b/be/src/olap/rowset/segment_v2/segment_iterator.cpp
index 00a15be3e23..3075af6122c 100644
--- a/be/src/olap/rowset/segment_v2/segment_iterator.cpp
+++ b/be/src/olap/rowset/segment_v2/segment_iterator.cpp
@@ -24,6 +24,7 @@
#include <algorithm>
#include <boost/iterator/iterator_facade.hpp>
+#include <cstddef>
#include <cstdint>
#include <iterator>
#include <memory>
@@ -98,6 +99,7 @@
#include "vec/exprs/vslot_ref.h"
#include "vec/functions/array/function_array_index.h"
#include "vec/json/path_in_data.h"
+#include "vector/vector_index.h"
namespace doris {
using namespace ErrorCode;
@@ -299,6 +301,7 @@ Status SegmentIterator::_init_impl(const
StorageReadOptions& opts) {
}
_col_predicates.emplace_back(predicate);
}
+ LOG_INFO("Segment iterator init, column predicates size: {}",
_col_predicates.size());
_tablet_id = opts.tablet_id;
// Read options will not change, so that just resize here
_block_rowids.resize(_opts.block_row_max);
@@ -632,6 +635,32 @@ Status SegmentIterator::_apply_ann_topn_predicate() {
!_col_predicates.empty());
return Status::OK();
}
+
+ // Process asc & desc according to the type of metric
+ auto index_reader = ann_index_iterator->get_reader();
+ auto ann_index_reader = dynamic_cast<AnnIndexReader*>(index_reader.get());
+ DCHECK(ann_index_reader != nullptr);
+ if (ann_index_reader->get_metric_type() ==
VectorIndex::Metric::INNER_PRODUCT) {
+ if (_ann_topn_descriptor->is_asc()) {
+ LOG_INFO("asc topn for inner product can not be evaluated by ann
index");
+ return Status::OK();
+ }
+ } else {
+ if (!_ann_topn_descriptor->is_asc()) {
+ LOG_INFO("desc topn for l2/cosine can not be evaluated by ann
index");
+ return Status::OK();
+ }
+ }
+
+ if (ann_index_reader->get_metric_type() !=
_ann_topn_descriptor->get_metric_type()) {
+ LOG_INFO(
+ "Ann topn metric type {} not match index metric type {}, can
not be evaluated by "
+ "ann index",
+
VectorIndex::metric_to_string(_ann_topn_descriptor->get_metric_type()),
+
VectorIndex::metric_to_string(ann_index_reader->get_metric_type()));
+ return Status::OK();
+ }
+
size_t pre_size = _row_bitmap.cardinality();
size_t dst_col_idx = _ann_topn_descriptor->get_dest_column_idx();
vectorized::IColumn::MutablePtr result_column;
@@ -647,6 +676,8 @@ Status SegmentIterator::_apply_ann_topn_predicate() {
DCHECK(column_iter != nullptr);
VirtualColumnIterator* virtual_column_iter =
dynamic_cast<VirtualColumnIterator*>(column_iter);
DCHECK(virtual_column_iter != nullptr);
+ LOG_INFO("Virtual column iterator, column_idx {}, is materialized with {}
rows", dst_col_idx,
+ result_row_ids->size());
virtual_column_iter->prepare_materialization(std::move(result_column),
std::move(result_row_ids));
return Status::OK();
@@ -936,6 +967,7 @@ Status SegmentIterator::_apply_index_expr() {
++it;
}
}
+ // TODO:remove expr root from _remaining_conjunct_roots
return Status::OK();
}
@@ -1040,6 +1072,10 @@ bool SegmentIterator::_need_read_data(ColumnId cid) {
_opts.enable_unique_key_merge_on_write)))) {
return true;
}
+ if (this->_vir_cid_to_idx_in_block.contains(cid)) {
+ return true;
+ }
+
// if there is delete predicate, we always need to read data
if (_has_delete_predicate(cid)) {
return true;
@@ -1450,7 +1486,7 @@ Status SegmentIterator::_vec_init_lazy_materialization() {
_is_pred_column.resize(_schema->columns().size(), false);
// including short/vec/delete pred
- std::set<ColumnId> pred_column_ids;
+ std::set<ColumnId> cols_read_by_column_predicate;
_lazy_materialization_read = false;
std::set<ColumnId> del_cond_id_set;
@@ -1490,7 +1526,7 @@ Status SegmentIterator::_vec_init_lazy_materialization() {
for (auto* predicate : _col_predicates) {
auto cid = predicate->column_id();
_is_pred_column[cid] = true;
- pred_column_ids.insert(cid);
+ cols_read_by_column_predicate.insert(cid);
// check pred using short eval or vec eval
if (_can_evaluated_by_vectorized(predicate)) {
@@ -1508,7 +1544,7 @@ Status SegmentIterator::_vec_init_lazy_materialization() {
// handle delete_condition
if (!del_cond_id_set.empty()) {
short_cir_pred_col_id_set.insert(del_cond_id_set.begin(),
del_cond_id_set.end());
- pred_column_ids.insert(del_cond_id_set.begin(),
del_cond_id_set.end());
+ cols_read_by_column_predicate.insert(del_cond_id_set.begin(),
del_cond_id_set.end());
for (auto cid : del_cond_id_set) {
_is_pred_column[cid] = true;
@@ -1566,7 +1602,7 @@ Status SegmentIterator::_vec_init_lazy_materialization() {
// all columns are lazy materialization columns without non predicte
column.
// If common expr pushdown exists, and expr column is not contained in
lazy materialization columns,
// add to second read column, which will be read after lazy materialization
- if (_schema->column_ids().size() > pred_column_ids.size()) {
+ if (_schema->column_ids().size() > cols_read_by_column_predicate.size()) {
// pred_column_ids maybe empty, so that could not set
_lazy_materialization_read = true here
// has to check there is at least one predicate column
for (auto cid : _schema->column_ids()) {
@@ -1574,10 +1610,10 @@ Status
SegmentIterator::_vec_init_lazy_materialization() {
if (_is_need_vec_eval || _is_need_short_eval) {
_lazy_materialization_read = true;
}
- if (!_is_common_expr_column[cid]) {
- _non_predicate_columns.push_back(cid);
+ if (_is_common_expr_column[cid]) {
+ _cols_read_by_common_expr.push_back(cid);
} else {
- _non_predicate_column_ids.push_back(cid);
+ _cols_not_included_by_any_predicates.push_back(cid);
}
}
}
@@ -1586,16 +1622,20 @@ Status
SegmentIterator::_vec_init_lazy_materialization() {
// Step 4: fill first read columns
if (_lazy_materialization_read) {
// insert pred cid to first_read_columns
- for (auto cid : pred_column_ids) {
- _predicate_column_ids.push_back(cid);
+ for (auto cid : cols_read_by_column_predicate) {
+ _cols_read_by_column_predicate.push_back(cid);
}
- } else if (!_is_need_vec_eval && !_is_need_short_eval &&
- !_is_need_expr_eval) { // no pred exists, just read and output
column
+ } else if (!_is_need_vec_eval && !_is_need_short_eval &&
!_is_need_expr_eval) {
+ // no pred exists, just read and output column
+ // 这代码也很迷惑啊,既然没有任何谓词列,那就不要改变流程啊,就按照正常的输出 non-predicates-columns 就好了啊
+ // 为什么要强行把所有的列当作 predicate 列去处理呢
for (int i = 0; i < _schema->num_column_ids(); i++) {
auto cid = _schema->column_id(i);
- _predicate_column_ids.push_back(cid);
+ _cols_read_by_column_predicate.push_back(cid);
}
} else {
+ // 不延迟物化,但是有谓词
+ // 说明除了 column_predicates 的列之外,还有其他列需要读
if (_is_need_vec_eval || _is_need_short_eval) {
// TODO To refactor, because we suppose lazy materialization is
better performance.
// pred exits, but we can eliminate lazy materialization
@@ -1605,12 +1645,12 @@ Status
SegmentIterator::_vec_init_lazy_materialization() {
_short_cir_pred_column_ids.end());
pred_id_set.insert(_vec_pred_column_ids.begin(),
_vec_pred_column_ids.end());
- DCHECK(_non_predicate_column_ids.empty());
+ DCHECK(_cols_read_by_common_expr.empty());
// _non_predicate_column_ids must be empty. Otherwise
_lazy_materialization_read must not false.
for (int i = 0; i < _schema->num_column_ids(); i++) {
auto cid = _schema->column_id(i);
if (pred_id_set.find(cid) != pred_id_set.end()) {
- _predicate_column_ids.push_back(cid);
+ _cols_read_by_column_predicate.push_back(cid);
}
// In the past, if schema columns > pred columns, the
_lazy_materialization_read maybe == false, but
// we make sure using _lazy_materialization_read= true now, so
these logic may never happens. I comment
@@ -1624,21 +1664,22 @@ Status
SegmentIterator::_vec_init_lazy_materialization() {
} else if (_is_need_expr_eval) {
DCHECK(!_is_need_vec_eval && !_is_need_short_eval);
for (auto cid : _common_expr_columns) {
- _predicate_column_ids.push_back(cid);
+ // 这代码太 track 了,很迷糊啊,完全概念混到一起了
+ _cols_read_by_column_predicate.push_back(cid);
}
}
}
LOG_INFO(
- "Laze materialization end. "
+ "Laze materialization init end. "
"lazy_materialization_read: {}, "
- "predicate_column_ids: [{}], "
- "non_predicate_columns: [{}], "
- "non_predicate_column_ids: [{}], "
+ "_cols_read_by_column_predicate: [{}], "
+ "_cols_not_included_by_any_predicates: [{}], "
+ "_cols_read_by_common_expr: [{}], "
"columns_to_filter: [{}]",
- _lazy_materialization_read, fmt::join(_predicate_column_ids, ","),
- fmt::join(_non_predicate_columns, ","),
fmt::join(_non_predicate_column_ids, ","),
- fmt::join(_columns_to_filter, ","));
+ _lazy_materialization_read,
fmt::join(_cols_read_by_column_predicate, ","),
+ fmt::join(_cols_not_included_by_any_predicates, ","),
+ fmt::join(_cols_read_by_common_expr, ","),
fmt::join(_columns_to_filter, ","));
return Status::OK();
}
@@ -1806,7 +1847,7 @@ Status
SegmentIterator::_init_return_columns(vectorized::Block* block, uint32_t
void SegmentIterator::_output_non_pred_columns(vectorized::Block* block) {
SCOPED_RAW_TIMER(&_opts.stats->output_col_ns);
- for (auto cid : _non_predicate_columns) {
+ for (auto cid : _cols_not_included_by_any_predicates) {
auto loc = _schema_block_id_map[cid];
// if loc > block->columns() means the column is delete column and
should
// not output by block, so just skip the column.
@@ -1841,31 +1882,46 @@ Status SegmentIterator::_read_columns_by_index(uint32_t
nrows_read_limit, uint32
SCOPED_RAW_TIMER(&_opts.stats->predicate_column_read_ns);
nrows_read = _range_iter->read_batch_rowids(_block_rowids.data(),
nrows_read_limit);
+ LOG_INFO("nrows_read from range iterator: {}", nrows_read);
bool is_continuous = (nrows_read > 1) &&
(_block_rowids[nrows_read - 1] - _block_rowids[0] ==
nrows_read - 1);
+ std::vector<ColumnId> predicate_column_ids_and_virtual_columns;
+
predicate_column_ids_and_virtual_columns.reserve(_cols_read_by_column_predicate.size()
+
+
_virtual_column_exprs.size());
+
predicate_column_ids_and_virtual_columns.insert(predicate_column_ids_and_virtual_columns.end(),
+
_cols_read_by_column_predicate.begin(),
+
_cols_read_by_column_predicate.end());
- for (auto cid : _predicate_column_ids) {
- auto& column = _current_return_columns[cid];
- if (_no_need_read_key_data(cid, column, nrows_read)) {
- continue;
- }
- if (_prune_column(cid, column, true, nrows_read)) {
- continue;
- }
+ for (const auto& entry : _virtual_column_exprs) {
+ // virtual column id is not in _predicate_column_ids
+ predicate_column_ids_and_virtual_columns.push_back(entry.first);
+ }
- DBUG_EXECUTE_IF("segment_iterator._read_columns_by_index", {
- auto col_name = _opts.tablet_schema->column(cid).name();
- auto debug_col_name =
DebugPoints::instance()->get_debug_param_or_default<std::string>(
- "segment_iterator._read_columns_by_index", "column_name",
"");
- if (debug_col_name.empty() && col_name != "__DORIS_DELETE_SIGN__")
{
- return Status::Error<ErrorCode::INTERNAL_ERROR>("does not need
to read data, {}",
- col_name);
+ for (auto cid : predicate_column_ids_and_virtual_columns) {
+ auto& column = _current_return_columns[cid];
+ if (!_virtual_column_exprs.contains(cid)) {
+ if (_no_need_read_key_data(cid, column, nrows_read)) {
+ continue;
}
- if (debug_col_name.find(col_name) != std::string::npos) {
- return Status::Error<ErrorCode::INTERNAL_ERROR>("does not need
to read data, {}",
- col_name);
+ if (_prune_column(cid, column, true, nrows_read)) {
+ continue;
}
- })
+
+ DBUG_EXECUTE_IF("segment_iterator._read_columns_by_index", {
+ auto col_name = _opts.tablet_schema->column(cid).name();
+ auto debug_col_name =
+
DebugPoints::instance()->get_debug_param_or_default<std::string>(
+ "segment_iterator._read_columns_by_index",
"column_name", "");
+ if (debug_col_name.empty() && col_name !=
"__DORIS_DELETE_SIGN__") {
+ return Status::Error<ErrorCode::INTERNAL_ERROR>(
+ "does not need to read data, {}", col_name);
+ }
+ if (debug_col_name.find(col_name) != std::string::npos) {
+ return Status::Error<ErrorCode::INTERNAL_ERROR>(
+ "does not need to read data, {}", col_name);
+ }
+ })
+ }
if (is_continuous) {
size_t rows_read = nrows_read;
@@ -2321,8 +2377,26 @@ Status
SegmentIterator::_next_batch_internal(vectorized::Block* block) {
_current_batch_rows_read = 0;
RETURN_IF_ERROR(_read_columns_by_index(nrows_read_limit,
_current_batch_rows_read));
- if (std::find(_predicate_column_ids.begin(), _predicate_column_ids.end(),
- _schema->version_col_idx()) != _predicate_column_ids.end()) {
+
+ // 把从索引物化得到的虚拟列放到 block 中
+ for (const auto pair : _vir_cid_to_idx_in_block) {
+ ColumnId cid = pair.first;
+ size_t position = pair.second;
+ block->replace_by_position(position,
std::move(_current_return_columns[cid]));
+ bool is_nothing = false;
+ if (vectorized::check_and_get_column<vectorized::ColumnNothing>(
+ block->get_by_position(position).column.get())) {
+ is_nothing = true;
+ }
+
+ LOG_INFO(
+ "SegmentIterator next block replace virtual column, cid {},
position {}, still "
+ "nothing {}",
+ cid, position, is_nothing);
+ }
+
+ if (std::find(_cols_read_by_column_predicate.begin(),
_cols_read_by_column_predicate.end(),
+ _schema->version_col_idx()) !=
_cols_read_by_column_predicate.end()) {
_replace_version_col(_current_batch_rows_read);
}
@@ -2333,18 +2407,19 @@ Status
SegmentIterator::_next_batch_internal(vectorized::Block* block) {
// Convert all columns in _current_return_columns to schema column
RETURN_IF_ERROR(_convert_to_expected_type(_schema->column_ids()));
for (int i = 0; i < block->columns() -
_vir_cid_to_idx_in_block.size(); i++) {
- // TODO: 虚拟列是否需要处理
auto cid = _schema->column_id(i);
// todo(wb) abstract make column where
if (!_is_pred_column[cid]) {
block->replace_by_position(i,
std::move(_current_return_columns[cid]));
}
}
+
for (auto& pair : _vir_cid_to_idx_in_block) {
auto cid = pair.first;
auto loc = pair.second;
block->replace_by_position(loc,
std::move(_current_return_columns[cid]));
}
+
block->clear_column_data();
// clear and release iterators memory footprint in advance
_clear_iterators();
@@ -2352,11 +2427,15 @@ Status
SegmentIterator::_next_batch_internal(vectorized::Block* block) {
}
if (!_is_need_vec_eval && !_is_need_short_eval && !_is_need_expr_eval) {
- if (_non_predicate_columns.empty()) {
+ if (_cols_not_included_by_any_predicates.empty()) {
return Status::InternalError("_non_predicate_columns is empty");
}
- RETURN_IF_ERROR(_convert_to_expected_type(_predicate_column_ids));
- RETURN_IF_ERROR(_convert_to_expected_type(_non_predicate_columns));
+
RETURN_IF_ERROR(_convert_to_expected_type(_cols_read_by_column_predicate));
+
RETURN_IF_ERROR(_convert_to_expected_type(_cols_not_included_by_any_predicates));
+ LOG_INFO(
+ "No need to evaluate any predicates or filter, output
non-predicate columns, "
+ "block rows {}, selected size {}",
+ block->rows(), _current_batch_rows_read);
_output_non_pred_columns(block);
} else {
uint16_t selected_size = _current_batch_rows_read;
@@ -2379,33 +2458,41 @@ Status
SegmentIterator::_next_batch_internal(vectorized::Block* block) {
// when lazy materialization enables, _predicate_column_ids =
distinct(_short_cir_pred_column_ids + _vec_pred_column_ids)
// see _vec_init_lazy_materialization
// todo(wb) need to tell input columnids from output columnids
- RETURN_IF_ERROR(_output_column_by_sel_idx(block,
_predicate_column_ids,
+ RETURN_IF_ERROR(_output_column_by_sel_idx(block,
_cols_read_by_column_predicate,
_sel_rowid_idx.data(), selected_size));
// step 3.2: read remaining expr column and evaluate it.
if (_is_need_expr_eval) {
// The predicate column contains the remaining expr
column, no need second read.
- if (!_non_predicate_column_ids.empty()) {
+ if (_cols_read_by_common_expr.size() > 0) {
SCOPED_RAW_TIMER(&_opts.stats->non_predicate_read_ns);
RETURN_IF_ERROR(_read_columns_by_rowids(
- _non_predicate_column_ids, _block_rowids,
_sel_rowid_idx.data(),
+ _cols_read_by_common_expr, _block_rowids,
_sel_rowid_idx.data(),
selected_size, &_current_return_columns));
- if (std::find(_non_predicate_column_ids.begin(),
- _non_predicate_column_ids.end(),
+ if (std::find(_cols_read_by_common_expr.begin(),
+ _cols_read_by_common_expr.end(),
_schema->version_col_idx()) !=
- _non_predicate_column_ids.end()) {
+ _cols_read_by_common_expr.end()) {
_replace_version_col(selected_size);
}
-
RETURN_IF_ERROR(_convert_to_expected_type(_non_predicate_column_ids));
- for (auto cid : _non_predicate_column_ids) {
+
RETURN_IF_ERROR(_convert_to_expected_type(_cols_read_by_common_expr));
+ for (auto cid : _cols_read_by_common_expr) {
auto loc = _schema_block_id_map[cid];
block->replace_by_position(loc,
std::move(_current_return_columns[cid]));
}
+
+ for (const auto pair : _vir_cid_to_idx_in_block) {
+ auto cid = pair.first;
+ auto loc = pair.second;
+ block->replace_by_position(loc,
+
std::move(_current_return_columns[cid]));
+ }
}
DCHECK(block->columns() >
_schema_block_id_map[*_common_expr_columns.begin()]);
- // block->rows() takes the size of the first column by
default. If the first column is no predicate column,
+ // block->rows() takes the size of the first column by
default.
+ // If the first column is not predicate column,
// it has not been read yet. add a const column that has
been read to calculate rows().
if (block->rows() == 0) {
vectorized::MutableColumnPtr col0 =
@@ -2430,17 +2517,23 @@ Status
SegmentIterator::_next_batch_internal(vectorized::Block* block) {
}
}
} else if (_is_need_expr_eval) {
-
RETURN_IF_ERROR(_convert_to_expected_type(_non_predicate_column_ids));
- for (auto cid : _non_predicate_column_ids) {
+
RETURN_IF_ERROR(_convert_to_expected_type(_cols_read_by_common_expr));
+ for (auto cid : _cols_read_by_common_expr) {
auto loc = _schema_block_id_map[cid];
block->replace_by_position(loc,
std::move(_current_return_columns[cid]));
}
+
+ for (const auto pair : _vir_cid_to_idx_in_block) {
+ auto cid = pair.first;
+ auto loc = pair.second;
+ block->replace_by_position(loc,
std::move(_current_return_columns[cid]));
+ }
}
} else if (_is_need_expr_eval) {
- DCHECK(!_predicate_column_ids.empty());
- RETURN_IF_ERROR(_convert_to_expected_type(_predicate_column_ids));
+ DCHECK(!_cols_read_by_column_predicate.empty());
+
RETURN_IF_ERROR(_convert_to_expected_type(_cols_read_by_column_predicate));
// first read all rows are insert block, initialize sel_rowid_idx
to all rows.
- for (auto cid : _predicate_column_ids) {
+ for (auto cid : _cols_read_by_column_predicate) {
auto loc = _schema_block_id_map[cid];
block->replace_by_position(loc,
std::move(_current_return_columns[cid]));
}
@@ -2482,7 +2575,7 @@ Status
SegmentIterator::_next_batch_internal(vectorized::Block* block) {
_selected_size = selected_size;
}
- if (_non_predicate_columns.empty()) {
+ if (_cols_not_included_by_any_predicates.empty()) {
// shrink char_type suffix zero data
block->shrink_char_type_column_suffix_zero(_char_type_idx);
@@ -2490,16 +2583,17 @@ Status
SegmentIterator::_next_batch_internal(vectorized::Block* block) {
}
// step4: read non_predicate column
if (selected_size > 0) {
- RETURN_IF_ERROR(_read_columns_by_rowids(_non_predicate_columns,
_block_rowids,
- _sel_rowid_idx.data(),
selected_size,
- &_current_return_columns));
- if (std::find(_non_predicate_columns.begin(),
_non_predicate_columns.end(),
- _schema->version_col_idx()) !=
_non_predicate_columns.end()) {
+
RETURN_IF_ERROR(_read_columns_by_rowids(_cols_not_included_by_any_predicates,
+ _block_rowids,
_sel_rowid_idx.data(),
+ selected_size,
&_current_return_columns));
+ if (std::find(_cols_not_included_by_any_predicates.begin(),
+ _cols_not_included_by_any_predicates.end(),
_schema->version_col_idx()) !=
+ _cols_not_included_by_any_predicates.end()) {
_replace_version_col(selected_size);
}
}
- RETURN_IF_ERROR(_convert_to_expected_type(_non_predicate_columns));
+
RETURN_IF_ERROR(_convert_to_expected_type(_cols_not_included_by_any_predicates));
// step5: output columns
_output_non_pred_columns(block);
}
@@ -2836,6 +2930,8 @@ Status
SegmentIterator::_materialization_of_virtual_column(vectorized::Block* bl
if (vectorized::check_and_get_column<const vectorized::ColumnNothing>(
block->get_by_position(idx_in_block).column.get())) {
+ LOG_INFO("Virtual column is doing materialization, cid {},
column_expr {}", cid,
+ column_expr->root()->debug_string());
int result_cid = -1;
RETURN_IF_ERROR(column_expr->execute(block, &result_cid));
diff --git a/be/src/olap/rowset/segment_v2/segment_iterator.h
b/be/src/olap/rowset/segment_v2/segment_iterator.h
index 378e6e90630..bb868d6278a 100644
--- a/be/src/olap/rowset/segment_v2/segment_iterator.h
+++ b/be/src/olap/rowset/segment_v2/segment_iterator.h
@@ -396,7 +396,7 @@ private:
// whether lazy materialization read should be used.
bool _lazy_materialization_read;
// columns to read after predicate evaluation and remaining expr execute
- std::vector<ColumnId> _non_predicate_columns;
+ std::vector<ColumnId> _cols_not_included_by_any_predicates;
std::set<ColumnId> _common_expr_columns;
// remember the rowids we've read for the current row block.
// could be a local variable of next_batch(), kept here to reuse vector
memory
@@ -410,7 +410,7 @@ private:
_vec_pred_column_ids; // keep columnId of columns for vectorized
predicate evaluation
std::vector<ColumnId>
_short_cir_pred_column_ids; // keep columnId of columns for short
circuit predicate evaluation
- std::vector<bool> _is_pred_column; // columns hold _init segmentIter
+
std::map<uint32_t, bool> _need_read_data_indices;
std::vector<bool> _is_common_expr_column;
vectorized::MutableColumns _current_return_columns;
@@ -422,8 +422,9 @@ private:
// first, read predicate columns by various index
// second, read non-predicate columns
// so we need a field to stand for columns first time to read
- std::vector<ColumnId> _predicate_column_ids;
- std::vector<ColumnId> _non_predicate_column_ids;
+ std::vector<ColumnId> _cols_read_by_column_predicate;
+ std::vector<bool> _is_pred_column;
+ std::vector<ColumnId> _cols_read_by_common_expr;
std::vector<ColumnId> _columns_to_filter;
std::vector<ColumnId> _converted_column_ids;
std::vector<int> _schema_block_id_map; // map from schema column id to
column idx in Block
diff --git a/be/src/olap/rowset/segment_v2/virtual_column_iterator.cpp
b/be/src/olap/rowset/segment_v2/virtual_column_iterator.cpp
index 767e03f89c3..82cb4631cd3 100644
--- a/be/src/olap/rowset/segment_v2/virtual_column_iterator.cpp
+++ b/be/src/olap/rowset/segment_v2/virtual_column_iterator.cpp
@@ -49,13 +49,52 @@ void
VirtualColumnIterator::prepare_materialization(vectorized::IColumn::Ptr col
_filter = doris::vectorized::IColumn::Filter(_size, 0);
}
+Status VirtualColumnIterator::seek_to_first() {
+ if (_size < 0) {
+ // _materialized_column_ptr is not set; nothing to do.
+ return Status::OK();
+ }
+ _current_ordinal = 0;
+
+ return Status::OK();
+}
+
+Status VirtualColumnIterator::seek_to_ordinal(ordinal_t ord_idx) {
+ if (_size < 0 ||
+
vectorized::check_and_get_column<vectorized::ColumnNothing>(*_materialized_column_ptr))
{
+ // _materialized_column_ptr is not set; nothing to do.
+ return Status::OK();
+ }
+
+ if (ord_idx >= _size) {
+ return Status::InternalError("Seek to ordinal out of range: {} out of
{}", ord_idx, _size);
+ }
+
+ _current_ordinal = ord_idx;
+
+ return Status::OK();
+}
+
// Next batch implementation
Status VirtualColumnIterator::next_batch(size_t* n,
vectorized::MutableColumnPtr& dst,
bool* has_null) {
if
(vectorized::check_and_get_column<vectorized::ColumnNothing>(*_materialized_column_ptr))
{
return Status::OK();
}
+ size_t rows_num_to_read = *n;
+ if (_row_id_to_idx.find(_current_ordinal) == _row_id_to_idx.end()) {
+ return Status::InternalError("Current ordinal {} not found in
row_id_to_idx map",
+ _current_ordinal);
+ }
+ size_t start = _row_id_to_idx[_current_ordinal];
+ // Update dst column
+ dst = _materialized_column_ptr->clone_empty();
+ dst->insert_range_from(*_materialized_column_ptr, start, rows_num_to_read);
+
+ LOG_INFO("Virtual column iterators, next_batch, rows reads: {}, dst size:
{}", rows_num_to_read,
+ dst->size());
+ _current_ordinal += rows_num_to_read;
return Status::OK();
}
@@ -75,9 +114,10 @@ Status VirtualColumnIterator::read_by_rowids(const rowid_t*
rowids, const size_t
// Apply filter to materialized column
doris::vectorized::IColumn::Ptr res_col =
_materialized_column_ptr->filter(_filter, 0);
// Update dst column
- dst->clear();
- dst->insert_range_from(*res_col, 0, res_col->size());
+ dst = res_col->assume_mutable();
+ LOG_INFO("Virtual column iterators, read_by_rowids, rowids size: {}, dst
size: {}", count,
+ dst->size());
return Status::OK();
}
diff --git a/be/src/olap/rowset/segment_v2/virtual_column_iterator.h
b/be/src/olap/rowset/segment_v2/virtual_column_iterator.h
index 17f60ec0e7a..f8c5f360716 100644
--- a/be/src/olap/rowset/segment_v2/virtual_column_iterator.h
+++ b/be/src/olap/rowset/segment_v2/virtual_column_iterator.h
@@ -38,9 +38,9 @@ public:
Status init(const ColumnIteratorOptions& opts) override;
- Status seek_to_first() override { return Status::OK(); }
+ Status seek_to_first() override;
- Status seek_to_ordinal(ordinal_t ord_idx) override { return Status::OK(); }
+ Status seek_to_ordinal(ordinal_t ord_idx) override;
Status next_batch(size_t* n, vectorized::MutableColumnPtr& dst, bool*
has_null) override;
@@ -57,9 +57,10 @@ private:
vectorized::IColumn::Ptr _materialized_column_ptr;
// segment rowid to index in column.
std::map<uint64_t, uint64_t> _row_id_to_idx;
-
doris::vectorized::IColumn::Filter _filter;
size_t _size = 0;
+
+ ordinal_t _current_ordinal = 0;
};
} // namespace doris::segment_v2
\ No newline at end of file
diff --git a/be/src/pipeline/exec/olap_scan_operator.cpp
b/be/src/pipeline/exec/olap_scan_operator.cpp
index cfc05f36dd2..7861d95ec9d 100644
--- a/be/src/pipeline/exec/olap_scan_operator.cpp
+++ b/be/src/pipeline/exec/olap_scan_operator.cpp
@@ -569,11 +569,13 @@ Status OlapScanLocalState::init(RuntimeState* state,
LocalStateInfo& info) {
// The ORDER BY expression must be a slot_ref, and that slot must be a virtual column.
DCHECK(ordering_expr.nodes[0].__isset.slot_ref);
DCHECK(ordering_expr.nodes[0].slot_ref.is_virtual_slot);
- size_t limit = olap_scan_node.ann_sort_limit;
+ DCHECK(olap_scan_node.ann_sort_info.is_asc_order.size() == 1);
+ const bool asc = olap_scan_node.ann_sort_info.is_asc_order[0];
+ const size_t limit = olap_scan_node.ann_sort_limit;
std::shared_ptr<vectorized::VExprContext> ordering_expr_ctx;
RETURN_IF_ERROR(vectorized::VExpr::create_expr_tree(ordering_expr,
ordering_expr_ctx));
_ann_topn_descriptor =
- vectorized::AnnTopNDescriptor::create_shared(limit,
ordering_expr_ctx);
+ vectorized::AnnTopNDescriptor::create_shared(asc, limit,
ordering_expr_ctx);
}
return ScanLocalState<OlapScanLocalState>::init(state, info);
diff --git a/be/src/pipeline/exec/operator.cpp
b/be/src/pipeline/exec/operator.cpp
index e42f581166c..4059de9c2a9 100644
--- a/be/src/pipeline/exec/operator.cpp
+++ b/be/src/pipeline/exec/operator.cpp
@@ -198,15 +198,15 @@ Status OperatorXBase::init(const TPlanNode& tnode,
RuntimeState* /*state*/) {
if (tnode.__isset.vconjunct) {
vectorized::VExprContextSPtr context;
RETURN_IF_ERROR(vectorized::VExpr::create_expr_tree(tnode.vconjunct,
context));
- LOG_INFO("Conjunct of {} is\n{}", _op_name,
- apache::thrift::ThriftDebugString(tnode.vconjunct));
+ // LOG_INFO("Conjunct of {} is\n{}", _op_name,
+ // apache::thrift::ThriftDebugString(tnode.vconjunct));
_conjuncts.emplace_back(context);
} else if (tnode.__isset.conjuncts) {
for (auto& conjunct : tnode.conjuncts) {
vectorized::VExprContextSPtr context;
RETURN_IF_ERROR(vectorized::VExpr::create_expr_tree(conjunct,
context));
- LOG_INFO("Conjunct of {} is\n{}", _op_name,
- apache::thrift::ThriftDebugString(conjunct));
+ // LOG_INFO("Conjunct of {} is\n{}", _op_name,
+ // apache::thrift::ThriftDebugString(conjunct));
// // Write the conjunct to a file for debugging
// doris::vectorized::write_to_json(
//
"/mnt/disk4/hezhiqiang/workspace/doris/cmaster/RELEASE/be1", "conjunct.json",
diff --git a/be/src/runtime/descriptors.cpp b/be/src/runtime/descriptors.cpp
index 315829295ae..3d0d4c03727 100644
--- a/be/src/runtime/descriptors.cpp
+++ b/be/src/runtime/descriptors.cpp
@@ -66,6 +66,26 @@ SlotDescriptor::SlotDescriptor(const TSlotDescriptor& tdesc)
_is_auto_increment(tdesc.__isset.is_auto_increment ?
tdesc.is_auto_increment : false),
_col_default_value(tdesc.__isset.col_default_value ?
tdesc.col_default_value : "") {
if (tdesc.__isset.virtual_column_expr) {
+ // Make sure virtual column is valid.
+ if (tdesc.virtual_column_expr.nodes.empty()) {
+ LOG_ERROR("Virtual column expr node is empty, col_name={},
col_unique_id={}",
+ tdesc.colName, tdesc.col_unique_id);
+
+ throw doris::Exception(doris::ErrorCode::FATAL_ERROR,
+ "Virtual column expr node is empty,
col_name: {}, "
+ "col_unique_id: {}",
+ tdesc.colName, tdesc.col_unique_id);
+ }
+ const auto& node = tdesc.virtual_column_expr.nodes[0];
+ if (node.node_type == TExprNodeType::SLOT_REF) {
+ LOG_ERROR(
+ "Virtual column expr node is slot ref, col_name={},
col_unique_id={}, expr: {}",
+ tdesc.colName, tdesc.col_unique_id,
apache::thrift::ThriftDebugString(tdesc));
+ throw doris::Exception(doris::ErrorCode::FATAL_ERROR,
+ "Virtual column expr node is slot ref,
col_name: {}, "
+ "col_unique_id: {}",
+ tdesc.colName, tdesc.col_unique_id);
+ }
this->virtual_column_expr =
std::make_shared<doris::TExpr>(tdesc.virtual_column_expr);
}
}
diff --git a/be/src/runtime/runtime_state.h b/be/src/runtime/runtime_state.h
index e4ecf59563c..dd9c25144f5 100644
--- a/be/src/runtime/runtime_state.h
+++ b/be/src/runtime/runtime_state.h
@@ -49,7 +49,7 @@
#include "runtime/workload_group/workload_group.h"
#include "util/debug_util.h"
#include "util/runtime_profile.h"
-#include "vec/columns/columns_number.h"
+#include "vec/runtime/vector_search_user_params.h"
namespace doris {
class RuntimeFilter;
@@ -657,6 +657,12 @@ public:
int profile_level() const { return _profile_level; }
+ VectorSearchUserParams get_vector_search_params() const {
+ return VectorSearchUserParams(_query_options.hnsw_ef_search,
+
_query_options.hnsw_check_relative_distance,
+ _query_options.hnsw_bounded_queue);
+ }
+
private:
Status create_error_log_file();
diff --git a/be/src/vec/core/block.cpp b/be/src/vec/core/block.cpp
index 4263a0659b7..fb7e218e062 100644
--- a/be/src/vec/core/block.cpp
+++ b/be/src/vec/core/block.cpp
@@ -777,6 +777,7 @@ void Block::update_hash(SipHash& hash) const {
}
}
+// columns_to_filter actually holds the positions of the columns that need to be filtered.
void Block::filter_block_internal(Block* block, const std::vector<uint32_t>&
columns_to_filter,
const IColumn::Filter& filter) {
size_t count = filter.size() -
simd::count_zero_num((int8_t*)filter.data(), filter.size());
diff --git a/be/src/vec/exec/scan/olap_scanner.cpp
b/be/src/vec/exec/scan/olap_scanner.cpp
index 63e5adff4a0..9c6197cce0e 100644
--- a/be/src/vec/exec/scan/olap_scanner.cpp
+++ b/be/src/vec/exec/scan/olap_scanner.cpp
@@ -104,6 +104,7 @@ OlapScanner::OlapScanner(pipeline::ScanLocalStateBase*
parent, OlapScanner::Para
}) {
_tablet_reader_params.set_read_source(std::move(params.read_source));
_is_init = false;
+ _vector_search_params = params.state->get_vector_search_params();
}
static std::string read_columns_to_string(TabletSchemaSPtr tablet_schema,
@@ -143,11 +144,12 @@ Status OlapScanner::init() {
auto* local_state =
static_cast<pipeline::OlapScanLocalState*>(_local_state);
auto& tablet = _tablet_reader_params.tablet;
auto& tablet_schema = _tablet_reader_params.tablet_schema;
+
for (auto ctx : local_state->_common_expr_ctxs_push_down) {
VExprContextSPtr context;
RETURN_IF_ERROR(ctx->clone(_state, context));
_common_expr_ctxs_push_down.emplace_back(context);
- RETURN_IF_ERROR(context->prepare_ann_range_search());
+
RETURN_IF_ERROR(context->prepare_ann_range_search(_vector_search_params));
}
for (auto pair : local_state->_slot_id_to_virtual_column_expr) {
diff --git a/be/src/vec/exec/scan/olap_scanner.h
b/be/src/vec/exec/scan/olap_scanner.h
index 0fbeedb16a1..f6895662c88 100644
--- a/be/src/vec/exec/scan/olap_scanner.h
+++ b/be/src/vec/exec/scan/olap_scanner.h
@@ -36,6 +36,7 @@
#include "olap/tablet.h"
#include "olap/tablet_reader.h"
#include "olap/tablet_schema.h"
+#include "runtime/runtime_state.h"
#include "vec/data_types/data_type.h"
#include "vec/exec/scan/scanner.h"
@@ -118,6 +119,8 @@ public:
std::map<size_t, vectorized::DataTypePtr> _vir_col_idx_to_type;
std::shared_ptr<vectorized::AnnTopNDescriptor> _ann_topn_descriptor;
+
+ VectorSearchUserParams _vector_search_params;
};
} // namespace vectorized
} // namespace doris
diff --git a/be/src/vec/exprs/ann_range_search_params.h
b/be/src/vec/exprs/ann_range_search_params.h
index 5eedd7d334c..410c4dc14c4 100644
--- a/be/src/vec/exprs/ann_range_search_params.h
+++ b/be/src/vec/exprs/ann_range_search_params.h
@@ -22,18 +22,21 @@
#include <string>
#include "olap/rowset/segment_v2/ann_index_iterator.h"
+#include "runtime/runtime_state.h"
+#include "vector/vector_index.h"
namespace doris::vectorized {
-struct AnnRangeSearchParams {
+struct RangeSearchRuntimeInfo {
bool is_ann_range_search = false;
bool is_le_or_lt = true;
size_t src_col_idx = 0;
int64_t dst_col_idx = -1;
double radius = 0.0;
- int ef_search = 0;
+ segment_v2::VectorIndex::Metric metric_type;
+ doris::VectorSearchUserParams user_params;
std::unique_ptr<float[]> query_value;
- segment_v2::RangeSearchParams toRangeSearchParams() {
+ segment_v2::RangeSearchParams to_range_search_params() {
segment_v2::RangeSearchParams params;
params.query_value = query_value.get();
params.radius = static_cast<float>(radius);
@@ -42,17 +45,13 @@ struct AnnRangeSearchParams {
return params;
}
- segment_v2::CustomSearchParams toCustomSearchParams() {
- segment_v2::CustomSearchParams params;
- params.ef_search = ef_search;
- return params;
- }
-
std::string to_string() const {
return fmt::format(
"is_ann_range_search: {}, is_le_or_lt: {}, src_col_idx: {}, "
- "dst_col_idx: {}, radius: {}, ef_search: {}",
- is_ann_range_search, is_le_or_lt, src_col_idx, dst_col_idx,
radius, ef_search);
+ "dst_col_idx: {}, metric_type {}, radius: {}, user params: {}",
+ is_ann_range_search, is_le_or_lt, src_col_idx, dst_col_idx,
+ segment_v2::VectorIndex::metric_to_string(metric_type), radius,
+ user_params.to_string());
}
};
} // namespace doris::vectorized
diff --git a/be/src/vec/exprs/vann_topn_predicate.cpp
b/be/src/vec/exprs/vann_topn_predicate.cpp
index d68086dcd3d..30352b09782 100644
--- a/be/src/vec/exprs/vann_topn_predicate.cpp
+++ b/be/src/vec/exprs/vann_topn_predicate.cpp
@@ -102,7 +102,16 @@ Status AnnTopNDescriptor::prepare(RuntimeState* state,
const RowDescriptor& row_
distance_fn_call->children()[1]->debug_string());
}
_query_array = array_literal->get_column_ptr();
+ _user_params = state->get_vector_search_params();
+ std::set<std::string> distance_func_names = {vectorized::L2Distance::name,
+
vectorized::InnerProduct::name};
+ if (distance_func_names.contains(distance_fn_call->function_name()) ==
false) {
+ return Status::InternalError("Ann topn expr expect distance function,
got {}",
+ distance_fn_call->function_name());
+ }
+
+ _metric_type =
segment_v2::VectorIndex::string_to_metric(distance_fn_call->function_name());
VLOG_DEBUG << "AnnTopNDescriptor: {}" << this->debug_string();
return Status::OK();
}
@@ -112,6 +121,9 @@ Status AnnTopNDescriptor::evaluate_vector_ann_search(
vectorized::IColumn::MutablePtr& result_column,
std::unique_ptr<std::vector<uint64_t>>& row_ids) {
DCHECK(ann_index_iterator != nullptr);
+ segment_v2::AnnIndexIterator* ann_index_iterator_casted =
+ dynamic_cast<segment_v2::AnnIndexIterator*>(ann_index_iterator);
+ DCHECK(ann_index_iterator_casted != nullptr);
DCHECK(_order_by_expr_ctx != nullptr);
DCHECK(_order_by_expr_ctx->root() != nullptr);
@@ -135,6 +147,7 @@ Status AnnTopNDescriptor::evaluate_vector_ann_search(
.query_value = query_value_f32.get(),
.query_value_size = query_value_size,
.limit = _limit,
+ ._user_params = _user_params,
.roaring = &roaring,
.distance = nullptr,
.row_ids = nullptr,
@@ -159,7 +172,9 @@ Status AnnTopNDescriptor::evaluate_vector_ann_search(
std::string AnnTopNDescriptor::debug_string() const {
return "AnnTopNDescriptor: limit=" + std::to_string(_limit) +
", src_col_idx=" + std::to_string(_src_column_idx) +
- ", dest_col_idx=" + std::to_string(_dest_column_idx) +
+ ", dest_col_idx=" + std::to_string(_dest_column_idx) + ", asc=" +
std::to_string(_asc) +
+ ", user_params=" + _user_params.to_string() +
+ ", metric_type=" +
segment_v2::VectorIndex::metric_to_string(_metric_type) +
", order_by_expr=" + _order_by_expr_ctx->root()->debug_string();
}
} // namespace doris::vectorized
\ No newline at end of file
diff --git a/be/src/vec/exprs/vann_topn_predicate.h
b/be/src/vec/exprs/vann_topn_predicate.h
index fb92dfa0a8b..842d054cfcd 100644
--- a/be/src/vec/exprs/vann_topn_predicate.h
+++ b/be/src/vec/exprs/vann_topn_predicate.h
@@ -17,6 +17,7 @@
#pragma once
+#include "runtime/runtime_state.h"
#include "vec/columns/column.h"
#include "vec/exprs/varray_literal.h"
#include "vec/exprs/vcast_expr.h"
@@ -32,8 +33,8 @@ class AnnTopNDescriptor {
ENABLE_FACTORY_CREATOR(AnnTopNDescriptor);
public:
- AnnTopNDescriptor(size_t limit, VExprContextSPtr order_by_expr_ctx)
- : _limit(limit), _order_by_expr_ctx(order_by_expr_ctx) {};
+ AnnTopNDescriptor(bool asc, size_t limit, VExprContextSPtr
order_by_expr_ctx)
+ : _asc(asc), _limit(limit), _order_by_expr_ctx(order_by_expr_ctx)
{};
Status prepare(RuntimeState* state, const RowDescriptor& row_desc);
@@ -43,14 +44,16 @@ public:
roaring::Roaring& row_bitmap,
vectorized::IColumn::MutablePtr&
result_column,
std::unique_ptr<std::vector<uint64_t>>&
row_ids);
-
+ segment_v2::VectorIndex::Metric get_metric_type() const { return
_metric_type; }
std::string debug_string() const;
size_t get_src_column_idx() const { return _src_column_idx; }
size_t get_dest_column_idx() const { return _dest_column_idx; }
+ bool is_asc() const { return _asc; }
private:
+ const bool _asc;
// limit N
const size_t _limit;
// order by distance(xxx, [1,2])
@@ -59,7 +62,9 @@ private:
std::string _name = "AnnTopNDescriptor";
size_t _src_column_idx = -1;
size_t _dest_column_idx = -1;
+ segment_v2::VectorIndex::Metric _metric_type;
IColumn::Ptr _query_array;
+ doris::VectorSearchUserParams _user_params;
};
} // namespace doris::vectorized
\ No newline at end of file
diff --git a/be/src/vec/exprs/vectorized_fn_call.cpp
b/be/src/vec/exprs/vectorized_fn_call.cpp
index e3408aac960..9ed006195c3 100644
--- a/be/src/vec/exprs/vectorized_fn_call.cpp
+++ b/be/src/vec/exprs/vectorized_fn_call.cpp
@@ -57,6 +57,7 @@
#include "vec/functions/function_rpc.h"
#include "vec/functions/simple_function_factory.h"
#include "vec/utils/util.hpp"
+#include "vector/vector_index.h"
namespace doris {
class RowDescriptor;
@@ -258,6 +259,10 @@ const std::string& VectorizedFnCall::expr_name() const {
return _expr_name;
}
+std::string VectorizedFnCall::function_name() const {
+ return _function_name;
+}
+
std::string VectorizedFnCall::debug_string() const {
std::stringstream out;
out << "VectorizedFn[";
@@ -327,7 +332,8 @@ bool VectorizedFnCall::equals(const VExpr& other) {
SlotRef
*/
-Status VectorizedFnCall::prepare_ann_range_search() {
+Status VectorizedFnCall::prepare_ann_range_search(
+ const doris::VectorSearchUserParams& user_params) {
std::set<TExprOpcode::type> ops = {TExprOpcode::GE, TExprOpcode::LE,
TExprOpcode::LE,
TExprOpcode::GT, TExprOpcode::LT};
if (ops.find(this->op()) == ops.end()) {
@@ -376,9 +382,13 @@ Status VectorizedFnCall::prepare_ann_range_search() {
}
// Now left child is a function call, we need to check if it is a distance
function
- if (function_call->_function_name != L2Distance::name) {
+ std::set<std::string> distance_functions = {L2Distance::name,
InnerProduct::name};
+ if (distance_functions.find(function_call->_function_name) ==
distance_functions.end()) {
LOG_INFO("Left child is not a distance function. Got {}",
function_call->_function_name);
return Status::OK();
+ } else {
+ _ann_range_search_params.metric_type =
+
segment_v2::VectorIndex::string_to_metric(function_call->_function_name);
}
if (function_call->get_num_children() != 2) {
@@ -430,6 +440,7 @@ Status VectorizedFnCall::prepare_ann_range_search() {
_ann_range_search_params.query_value[i] =
static_cast<Float32>(cf64->get_data()[i]);
}
_ann_range_search_params.is_ann_range_search = true;
+ _ann_range_search_params.user_params = user_params;
LOG_INFO("Ann range search params: {}",
_ann_range_search_params.to_string());
return Status::OK();
}
@@ -452,26 +463,37 @@ Status VectorizedFnCall::evaluate_ann_range_search(
ColumnId src_col_cid = idx_to_cid[idx_in_block];
DCHECK(src_col_cid < cid_to_index_iterators.size());
- segment_v2::IndexIterator* index_iterators =
cid_to_index_iterators[src_col_cid].get();
- if (index_iterators == nullptr) {
+ segment_v2::IndexIterator* index_iterator =
cid_to_index_iterators[src_col_cid].get();
+ if (index_iterator == nullptr) {
LOG_INFO("No index iterator for column cid {}", src_col_cid);
return Status::OK();
}
- segment_v2::AnnIndexIterator* ann_index_iterators =
- dynamic_cast<segment_v2::AnnIndexIterator*>(index_iterators);
- if (ann_index_iterators == nullptr) {
+ segment_v2::AnnIndexIterator* ann_index_iterator =
+ dynamic_cast<segment_v2::AnnIndexIterator*>(index_iterator);
+ if (ann_index_iterator == nullptr) {
LOG_INFO("No index iterator for column cid {}", src_col_cid);
return Status::OK();
}
+ DCHECK(ann_index_iterator->get_reader() != nullptr)
+ << "Ann index iterator should have reader. Column cid: " <<
src_col_cid;
+ std::shared_ptr<AnnIndexReader> ann_index_reader =
+
std::dynamic_pointer_cast<AnnIndexReader>(ann_index_iterator->get_reader());
+ DCHECK(ann_index_reader != nullptr)
+ << "Ann index reader should not be null. Column cid: " <<
src_col_cid;
+ // Check whether the metric type of the query matches that of the index.
+ if (ann_index_reader->get_metric_type() !=
_ann_range_search_params.metric_type) {
+ LOG_INFO("Metric type not match, can not execute range search by
index.");
+ return Status::OK();
+ }
- RangeSearchParams params = _ann_range_search_params.toRangeSearchParams();
- CustomSearchParams custom_params =
_ann_range_search_params.toCustomSearchParams();
+ RangeSearchParams params =
_ann_range_search_params.to_range_search_params();
params.roaring = &row_bitmap;
DCHECK(params.roaring != nullptr);
RangeSearchResult result;
- RETURN_IF_ERROR(ann_index_iterators->range_search(params, custom_params,
&result));
+ RETURN_IF_ERROR(ann_index_iterator->range_search(params,
_ann_range_search_params.user_params,
+ &result));
#ifndef NDEBUG
if (this->_ann_range_search_params.is_le_or_lt == false) {
diff --git a/be/src/vec/exprs/vectorized_fn_call.h
b/be/src/vec/exprs/vectorized_fn_call.h
index 4f3f76b5436..14d86964ac9 100644
--- a/be/src/vec/exprs/vectorized_fn_call.h
+++ b/be/src/vec/exprs/vectorized_fn_call.h
@@ -22,6 +22,7 @@
#include <vector>
#include "common/status.h"
+#include "runtime/runtime_state.h"
#include "udf/udf.h"
#include "vec/core/column_numbers.h"
#include "vec/exprs/ann_range_search_params.h"
@@ -60,6 +61,7 @@ public:
FunctionContext::FunctionStateScope scope) override;
void close(VExprContext* context, FunctionContext::FunctionStateScope
scope) override;
const std::string& expr_name() const override;
+ std::string function_name() const;
std::string debug_string() const override;
bool is_constant() const override {
if (!_function->is_use_default_implementation_for_constants() ||
@@ -82,13 +84,13 @@ public:
const std::vector<std::unique_ptr<segment_v2::ColumnIterator>>&
column_iterators,
roaring::Roaring& row_bitmap) override;
- Status prepare_ann_range_search() override;
+ Status prepare_ann_range_search(const doris::VectorSearchUserParams&
params) override;
protected:
FunctionBasePtr _function;
std::string _expr_name;
std::string _function_name;
- AnnRangeSearchParams _ann_range_search_params;
+ RangeSearchRuntimeInfo _ann_range_search_params;
private:
Status _do_execute(doris::vectorized::VExprContext* context,
doris::vectorized::Block* block,
diff --git a/be/src/vec/exprs/vexpr.cpp b/be/src/vec/exprs/vexpr.cpp
index 794bd66fac5..f3269dfca36 100644
--- a/be/src/vec/exprs/vexpr.cpp
+++ b/be/src/vec/exprs/vexpr.cpp
@@ -807,9 +807,9 @@ Status VExpr::evaluate_ann_range_search(
return Status::OK();
}
-Status VExpr::prepare_ann_range_search() {
+Status VExpr::prepare_ann_range_search(const doris::VectorSearchUserParams&
params) {
for (auto& child : _children) {
- RETURN_IF_ERROR(child->prepare_ann_range_search());
+ RETURN_IF_ERROR(child->prepare_ann_range_search(params));
}
return Status::OK();
}
diff --git a/be/src/vec/exprs/vexpr.h b/be/src/vec/exprs/vexpr.h
index c78cadb803a..09b91c35114 100644
--- a/be/src/vec/exprs/vexpr.h
+++ b/be/src/vec/exprs/vexpr.h
@@ -283,7 +283,7 @@ public:
const std::vector<std::unique_ptr<segment_v2::ColumnIterator>>&
column_iterators,
roaring::Roaring& row_bitmap);
- virtual Status prepare_ann_range_search();
+ virtual Status prepare_ann_range_search(const
doris::VectorSearchUserParams& params);
bool has_been_executed();
diff --git a/be/src/vec/exprs/vexpr_context.cpp
b/be/src/vec/exprs/vexpr_context.cpp
index 66c5c1bfc58..886dea256c5 100644
--- a/be/src/vec/exprs/vexpr_context.cpp
+++ b/be/src/vec/exprs/vexpr_context.cpp
@@ -432,11 +432,11 @@ void VExprContext::_reset_memory_usage(const
VExprContextSPtrs& contexts) {
[](auto&& context) { context->_memory_usage = 0; });
}
-Status VExprContext::prepare_ann_range_search() {
+Status VExprContext::prepare_ann_range_search(const
doris::VectorSearchUserParams& params) {
if (_root == nullptr) {
return Status::OK();
}
- return _root->prepare_ann_range_search();
+ return _root->prepare_ann_range_search(params);
}
#include "common/compile_check_end.h"
diff --git a/be/src/vec/exprs/vexpr_context.h b/be/src/vec/exprs/vexpr_context.h
index 60d3e1a31e0..d43012353e5 100644
--- a/be/src/vec/exprs/vexpr_context.h
+++ b/be/src/vec/exprs/vexpr_context.h
@@ -28,6 +28,7 @@
#include "common/factory_creator.h"
#include "common/status.h"
#include "olap/rowset/segment_v2/inverted_index_reader.h"
+#include "runtime/runtime_state.h"
#include "runtime/types.h"
#include "udf/udf.h"
#include "vec/core/block.h"
@@ -279,7 +280,7 @@ public:
[[nodiscard]] size_t get_memory_usage() const { return _memory_usage; }
- Status prepare_ann_range_search();
+ Status prepare_ann_range_search(const doris::VectorSearchUserParams&
params);
private:
// Close method is called in the VExprContext destructor; there is no need to call it explicitly
diff --git a/be/src/vec/exprs/virtual_slot_ref.cpp
b/be/src/vec/exprs/virtual_slot_ref.cpp
index 2455ef40bfa..844da622f31 100644
--- a/be/src/vec/exprs/virtual_slot_ref.cpp
+++ b/be/src/vec/exprs/virtual_slot_ref.cpp
@@ -91,7 +91,7 @@ Status VirtualSlotRef::prepare(doris::RuntimeState* state,
const doris::RowDescr
state->desc_tbl().debug_string());
}
const TExpr& expr = *slot_desc->get_virtual_column_expr();
- LOG_INFO("Virtual column expr is {}",
apache::thrift::ThriftDebugString(expr));
+ // LOG_INFO("Virtual column expr is {}",
apache::thrift::ThriftDebugString(expr));
// Create a temp_ctx only for create_expr_tree.
VExprContextSPtr temp_ctx;
RETURN_IF_ERROR(VExpr::create_expr_tree(expr, temp_ctx));
diff --git a/be/src/vec/functions/array/function_array_distance.h
b/be/src/vec/functions/array/function_array_distance.h
index 28b0df28d7f..fcb7a067a07 100644
--- a/be/src/vec/functions/array/function_array_distance.h
+++ b/be/src/vec/functions/array/function_array_distance.h
@@ -96,6 +96,9 @@ public:
Status execute_impl(FunctionContext* context, Block& block, const
ColumnNumbers& arguments,
uint32_t result, size_t input_rows_count) const
override {
+ LOG_INFO("Function {} is executed with {} rows, stack {}", get_name(),
input_rows_count,
+ doris::get_stack_trace());
+
const auto& arg1 = block.get_by_position(arguments[0]);
const auto& arg2 = block.get_by_position(arguments[1]);
if (!_check_input_type(arg1.type) || !_check_input_type(arg2.type)) {
diff --git a/be/src/vec/runtime/vector_search_user_params.cpp
b/be/src/vec/runtime/vector_search_user_params.cpp
new file mode 100644
index 00000000000..04b8c8b91c4
--- /dev/null
+++ b/be/src/vec/runtime/vector_search_user_params.cpp
@@ -0,0 +1,35 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "vec/runtime/vector_search_user_params.h"
+
+#include <fmt/format.h>
+
+namespace doris {
+bool VectorSearchUserParams::operator==(const VectorSearchUserParams& other)
const { // Field-wise equality: true only when all three HNSW settings match.
+ return hnsw_ef_search == other.hnsw_ef_search &&
+ hnsw_check_relative_distance == other.hnsw_check_relative_distance
&&
+ hnsw_bounded_queue == other.hnsw_bounded_queue;
+}
+
+std::string VectorSearchUserParams::to_string() const { // Renders the three settings as "name: value" pairs separated by ", ".
+ return fmt::format(
+ "hnsw_ef_search: {}, hnsw_check_relative_distance: {}, "
+ "hnsw_bounded_queue: {}",
+ hnsw_ef_search, hnsw_check_relative_distance, hnsw_bounded_queue);
+}
+} // namespace doris
\ No newline at end of file
diff --git a/be/src/vec/runtime/vector_search_user_params.h
b/be/src/vec/runtime/vector_search_user_params.h
new file mode 100644
index 00000000000..5f886405e06
--- /dev/null
+++ b/be/src/vec/runtime/vector_search_user_params.h
@@ -0,0 +1,31 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <string>
+
+namespace doris {
+// HNSW search settings constructed from session variables. NOTE(review): this new header has no #pragma once or include guard -- confirm and add one.
+struct VectorSearchUserParams {
+ int hnsw_ef_search = 16; // default 16, the same default as HNSWSearchParameters::ef_search
+ bool hnsw_check_relative_distance = true; // default true
+ bool hnsw_bounded_queue = true; // default true
+
+ bool operator==(const VectorSearchUserParams& other) const; // compares all three fields
+
+ std::string to_string() const; // human-readable dump of the three fields
+};
+} // namespace doris
\ No newline at end of file
diff --git a/be/src/vector/faiss_vector_index.cpp
b/be/src/vector/faiss_vector_index.cpp
index e5916973ece..f48c71334c5 100644
--- a/be/src/vector/faiss_vector_index.cpp
+++ b/be/src/vector/faiss_vector_index.cpp
@@ -129,9 +129,48 @@ doris::Status FaissVectorIndex::add(int n, const float*
vec) {
void FaissVectorIndex::set_build_params(const FaissBuildParameter& params) {
_dimension = params.d;
if (params.index_type == FaissBuildParameter::IndexType::BruteForce) {
- _index = std::make_unique<faiss::IndexFlatL2>(params.d);
+ if (params.metric_type == FaissBuildParameter::MetricType::L2) {
+ _index = std::make_unique<faiss::IndexFlatL2>(params.d);
+ } else if (params.metric_type == FaissBuildParameter::MetricType::IP) {
+ _index = std::make_unique<faiss::IndexFlatIP>(params.d);
+ } else {
+ throw doris::Exception(doris::ErrorCode::INVALID_ARGUMENT,
+ "Unsupported metric type: {}",
+ static_cast<int>(params.metric_type));
+ }
} else if (params.index_type == FaissBuildParameter::IndexType::HNSW) {
- _index = std::make_unique<faiss::IndexHNSWFlat>(params.d, params.m);
+ if (params.quantilizer == FaissBuildParameter::Quantilizer::FLAT) {
+ if (params.metric_type == FaissBuildParameter::MetricType::L2) {
+ _index = std::make_unique<faiss::IndexHNSWFlat>(params.d,
params.m);
+ } else if (params.metric_type ==
FaissBuildParameter::MetricType::IP) {
+ _index = std::make_unique<faiss::IndexHNSWFlat>(params.d,
params.m,
+
faiss::METRIC_INNER_PRODUCT);
+ } else {
+ throw doris::Exception(doris::ErrorCode::INVALID_ARGUMENT,
+ "Unsupported metric type: {}",
+ static_cast<int>(params.metric_type));
+ }
+ } else if (params.quantilizer == FaissBuildParameter::Quantilizer::PQ)
{
+ if (params.pq_m <= 0) {
+ throw doris::Exception(doris::ErrorCode::INVALID_ARGUMENT,
+ "pq_m should be greater than 0 for PQ
quantilizer");
+ }
+
+ if (params.metric_type == FaissBuildParameter::MetricType::L2) {
+ _index = std::make_unique<faiss::IndexHNSWPQ>(params.d,
params.m, params.pq_m);
+ } else if (params.metric_type ==
FaissBuildParameter::MetricType::IP) {
+ _index = std::make_unique<faiss::IndexHNSWPQ>(params.d,
params.m, params.pq_m,
+
faiss::METRIC_INNER_PRODUCT);
+ } else {
+ throw doris::Exception(doris::ErrorCode::INVALID_ARGUMENT,
+ "Unsupported metric type: {}",
+ static_cast<int>(params.metric_type));
+ }
+ } else {
+ throw doris::Exception(doris::ErrorCode::INVALID_ARGUMENT,
+ "Unsupported quantilizer type: {}",
+ static_cast<int>(params.quantilizer));
+ }
} else {
throw doris::Exception(doris::ErrorCode::INVALID_ARGUMENT,
"Unsupported index type: {}",
static_cast<int>(params.index_type));
@@ -158,7 +197,16 @@ doris::Status FaissVectorIndex::ann_topn_search(const
float* query_vec, int k,
std::unique_ptr<faiss::IDSelector> id_sel = nullptr;
id_sel = roaring_to_faiss_selector(*params.roaring);
faiss::SearchParametersHNSW param;
+ const HNSWSearchParameters* hnsw_params =
+ dynamic_cast<const HNSWSearchParameters*>(¶ms);
+ if (hnsw_params == nullptr) {
+ return doris::Status::InvalidArgument(
+ "HNSW search parameters should not be null for HNSW
index");
+ }
param.sel = id_sel.get();
+ param.efSearch = hnsw_params->ef_search;
+ param.check_relative_distance = hnsw_params->check_relative_distance;
+ param.bounded_queue = hnsw_params->bounded_queue;
_index->search(1, query_vec, k, distances, labels, ¶m);
}
@@ -193,6 +241,8 @@ doris::Status FaissVectorIndex::range_search(const float*
query_vec, const float
if (hnsw_params != nullptr) {
faiss::SearchParametersHNSW param;
param.efSearch = hnsw_params->ef_search;
+ param.check_relative_distance = hnsw_params->check_relative_distance;
+ param.bounded_queue = hnsw_params->bounded_queue;
param.sel = sel ? sel.get() : nullptr;
_index->range_search(1, query_vec, radius * radius,
&native_search_result, ¶m);
} else {
diff --git a/be/src/vector/faiss_vector_index.h
b/be/src/vector/faiss_vector_index.h
index 53637344bc1..129a6b26ccb 100644
--- a/be/src/vector/faiss_vector_index.h
+++ b/be/src/vector/faiss_vector_index.h
@@ -26,13 +26,20 @@
#include <string>
#include "common/status.h"
+#include "util/metrics.h"
#include "vector_index.h"
namespace doris::segment_v2 {
struct FaissBuildParameter {
enum class IndexType { BruteForce, IVF, HNSW };
- enum class Quantilizer { FLAT, SQ, PQ };
+ enum class Quantilizer { FLAT, PQ };
+
+ enum class MetricType {
+ L2, // Euclidean distance
+ IP, // Inner product
+ COSINE // Cosine similarity
+ };
static IndexType string_to_index_type(const std::string& type) {
if (type == "brute_force") {
@@ -48,19 +55,30 @@ struct FaissBuildParameter {
static Quantilizer string_to_quantilizer(const std::string& type) {
if (type == "flat") {
return Quantilizer::FLAT;
- } else if (type == "sq") {
- return Quantilizer::SQ;
} else if (type == "pq") {
return Quantilizer::PQ;
}
return Quantilizer::FLAT; // default
}
+ static MetricType string_to_metric_type(const std::string& type) { // Exact, case-sensitive match on "l2" / "ip" / "cosine".
+ if (type == "l2") {
+ return MetricType::L2;
+ } else if (type == "ip") {
+ return MetricType::IP;
+ } else if (type == "cosine") {
+ return MetricType::COSINE; // parsed here, but FaissVectorIndex::set_build_params() throws INVALID_ARGUMENT for any metric other than L2 / IP
+ }
+ return MetricType::L2; // default: any other string (e.g. "l2_distance") silently maps to L2 -- NOTE(review): confirm callers never pass the function-name spellings here
+ }
+
// HNSW
int d = 0;
int m = 0;
+ int pq_m = -1; // Only used for PQ quantilizer
IndexType index_type;
Quantilizer quantilizer;
+ MetricType metric_type = MetricType::L2;
};
class FaissVectorIndex : public VectorIndex {
diff --git a/be/src/vector/vector_index.h b/be/src/vector/vector_index.h
index 50b0d59b624..e344f1c06c3 100644
--- a/be/src/vector/vector_index.h
+++ b/be/src/vector/vector_index.h
@@ -21,7 +21,7 @@
#include <roaring/roaring.hh>
#include "common/status.h"
-#include "gutil/integral_types.h"
+#include "vec/functions/array/function_array_distance.h"
namespace lucene::store {
class Directory;
@@ -50,11 +50,13 @@ struct IndexSearchParameters {
struct HNSWSearchParameters : public IndexSearchParameters {
int ef_search = 16;
+ bool check_relative_distance = true; // copied into faiss::SearchParametersHNSW::check_relative_distance before searching
+ bool bounded_queue = true; // copied into faiss::SearchParametersHNSW::bounded_queue before searching
};
class VectorIndex {
public:
- enum class Metric { L2, COSINE, INNER_PRODUCT, UNKNOWN };
+ enum class Metric { L2, INNER_PRODUCT, UNKNOWN };
/** Add n vectors of dimension d to the index.
*
@@ -87,21 +89,17 @@ public:
static std::string metric_to_string(Metric metric) {
switch (metric) {
case Metric::L2:
- return "L2";
- case Metric::COSINE:
- return "COSINE";
+ return vectorized::L2Distance::name;
case Metric::INNER_PRODUCT:
- return "INNER_PRODUCT";
+ return vectorized::InnerProduct::name;
default:
return "UNKNOWN";
}
}
static Metric string_to_metric(const std::string& metric) {
- if (metric == "l2") {
+ if (metric == vectorized::L2Distance::name) {
return Metric::L2;
- } else if (metric == "cosine") {
- return Metric::COSINE;
- } else if (metric == "inner_product") {
+ } else if (metric == vectorized::InnerProduct::name) {
return Metric::INNER_PRODUCT;
} else {
return Metric::UNKNOWN;
@@ -112,6 +110,7 @@ public:
size_t get_dimension() const { return _dimension; }
protected:
+ // When adding vectors to the index, use this variable to check the
dimension of the vectors.
size_t _dimension = 0;
};
diff --git a/be/test/olap/vector_search/ann_index_reader_test.cpp
b/be/test/olap/vector_search/ann_index_reader_test.cpp
index e951a48c743..5b29f21e90c 100644
--- a/be/test/olap/vector_search/ann_index_reader_test.cpp
+++ b/be/test/olap/vector_search/ann_index_reader_test.cpp
@@ -15,33 +15,93 @@
// specific language governing permissions and limitations
// under the License.
-#include "olap/rowset/segment_v2/ann_index_reader.h"
-
+#include <gen_cpp/olap_file.pb.h>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
+#include <iostream>
#include <memory>
+#include <string>
+#include "faiss_vector_index.h"
+#include "olap/rowset/segment_v2/ann_index_iterator.h"
+#include "olap/tablet_schema.h"
+#include "runtime/runtime_state.h"
#include "vector_search_utils.h"
+using namespace doris::vector_search_utils;
+
+namespace doris::vectorized {
+
+TEST_F(VectorSearchTest, AnnIndexReaderRangeSearch) {
+ size_t iterato = 25;
+ for (size_t i = 0; i < iterato; ++i) {
+ std::map<std::string, std::string> index_properties;
+ index_properties["index_type"] = "hnsw";
+ index_properties["metric_type"] = "l2";
+ std::unique_ptr<doris::TabletIndex> index_meta =
std::make_unique<doris::TabletIndex>();
+ index_meta->_properties = index_properties;
+ auto mock_index_file_reader = std::make_shared<MockIndexFileReader>();
+ auto ann_index_reader = std::make_unique<segment_v2::AnnIndexReader>(
+ index_meta.get(), mock_index_file_reader);
+ doris::vector_search_utils::IndexType index_type =
+ doris::vector_search_utils::IndexType::HNSW;
+ const size_t dim = 128;
+ const size_t m = 16;
+ auto doris_faiss_index =
doris::vector_search_utils::create_doris_index(index_type, dim, m);
+ auto native_faiss_index =
+ doris::vector_search_utils::create_native_index(index_type,
dim, m);
+ const size_t num_vectors = 1000;
+ auto vectors =
doris::vector_search_utils::generate_test_vectors_matrix(num_vectors, dim);
+ doris::vector_search_utils::add_vectors_to_indexes_serial_mode(
+ doris_faiss_index.get(), native_faiss_index.get(), vectors);
+ std::ignore = doris_faiss_index->save(this->_ram_dir.get());
+ std::vector<float> query_value = vectors[0];
+ const float radius =
doris::vector_search_utils::get_radius_from_matrix(query_value.data(),
+
dim, vectors, 0.3);
+
+ // Make sure all rows are in the roaring
+ auto roaring = std::make_unique<roaring::Roaring>();
+ for (size_t i = 0; i < num_vectors; ++i) {
+ roaring->add(i);
+ }
+
+ doris::segment_v2::RangeSearchParams params;
+ params.radius = radius;
+ params.query_value = query_value.data();
+ params.roaring = roaring.get();
+ doris::VectorSearchUserParams custom_params;
+ custom_params.hnsw_ef_search = 16;
+ doris::segment_v2::RangeSearchResult result;
+ auto doris_faiss_vector_index =
std::make_unique<doris::segment_v2::FaissVectorIndex>();
+ std::ignore = doris_faiss_vector_index->load(this->_ram_dir.get());
+ ann_index_reader->_vector_index = std::move(doris_faiss_vector_index);
+ std::ignore = ann_index_reader->range_search(params, custom_params,
&result, nullptr);
-namespace doris::segment_v2 {
+ ASSERT_TRUE(result.roaring != nullptr);
+ ASSERT_TRUE(result.distance != nullptr);
+ ASSERT_TRUE(result.row_ids != nullptr);
+ std::vector<std::pair<int, float>> doris_search_result_order_by_lables;
+ for (size_t i = 0; i < result.roaring->cardinality(); ++i) {
+ doris_search_result_order_by_lables.push_back(
+ {result.row_ids->at(i), result.distance[i]});
+ }
-using namespace vector_search_utils;
-class AnnIndexReaderTest : public testing::Test {};
+ std::sort(doris_search_result_order_by_lables.begin(),
+ doris_search_result_order_by_lables.end(),
+ [](const auto& a, const auto& b) { return a.first < b.first;
});
-TEST_F(AnnIndexReaderTest, TestLoadIndex) {
- MockTabletSchema tablet_schema;
- std::shared_ptr<MockIndexFileReader> index_file_reader =
- std::make_shared<MockIndexFileReader>();
- auto ann_index_reader = std::make_unique<AnnIndexReader>(&tablet_schema,
index_file_reader);
+ std::vector<std::pair<int, float>>
native_search_result_order_by_lables =
+ doris::vector_search_utils::perform_native_index_range_search(
+ native_faiss_index.get(), query_value.data(), radius);
- EXPECT_TRUE(ann_index_reader->load_index(nullptr).ok());
-}
+ ASSERT_EQ(result.roaring->cardinality(),
native_search_result_order_by_lables.size());
-TEST_F(AnnIndexReaderTest, TestQuery) {
- MockTabletSchema tablet_schema;
- std::shared_ptr<MockIndexFileReader> index_file_reader =
- std::make_shared<MockIndexFileReader>();
- auto ann_index_reader = std::make_unique<AnnIndexReader>(&tablet_schema,
index_file_reader);
-}
-} // namespace doris::segment_v2
\ No newline at end of file
+ for (size_t i = 0; i < native_search_result_order_by_lables.size();
++i) {
+ ASSERT_EQ(doris_search_result_order_by_lables[i].first,
+ native_search_result_order_by_lables[i].first);
+ ASSERT_FLOAT_EQ(doris_search_result_order_by_lables[i].second,
+ native_search_result_order_by_lables[i].second);
+ }
+ }
+};
+} // namespace doris::vectorized
\ No newline at end of file
diff --git a/be/test/olap/vector_search/ann_range_search_test.cpp
b/be/test/olap/vector_search/ann_range_search_test.cpp
index 8b9ad847f46..f5599526b38 100644
--- a/be/test/olap/vector_search/ann_range_search_test.cpp
+++ b/be/test/olap/vector_search/ann_range_search_test.cpp
@@ -28,7 +28,9 @@
#include "common/object_pool.h"
#include "olap/rowset/segment_v2/ann_index_iterator.h"
+#include "olap/rowset/segment_v2/ann_index_reader.h"
#include "olap/rowset/segment_v2/column_reader.h"
+#include "olap/rowset/segment_v2/index_file_reader.h"
#include "olap/rowset/segment_v2/virtual_column_iterator.h"
#include "olap/vector_search/vector_search_utils.h"
#include "runtime/descriptors.h"
@@ -826,10 +828,11 @@ TEST_F(VectorSearchTest, TestPrepareAnnRangeSearch) {
state->set_desc_tbl(desc_tbl_ptr);
VExprContextSPtr range_search_ctx;
+ doris::VectorSearchUserParams user_params;
ASSERT_TRUE(vectorized::VExpr::create_expr_tree(texpr,
range_search_ctx).ok());
ASSERT_TRUE(range_search_ctx->prepare(state.get(), row_desc).ok());
ASSERT_TRUE(range_search_ctx->open(state.get()).ok());
- ASSERT_TRUE(range_search_ctx->prepare_ann_range_search().ok());
+ ASSERT_TRUE(range_search_ctx->prepare_ann_range_search(user_params).ok());
std::shared_ptr<VectorizedFnCall> fn_call =
std::dynamic_pointer_cast<VectorizedFnCall>(range_search_ctx->root());
@@ -840,7 +843,7 @@ TEST_F(VectorSearchTest, TestPrepareAnnRangeSearch) {
ASSERT_EQ(fn_call->_ann_range_search_params.radius, 10);
doris::segment_v2::RangeSearchParams range_search_params =
- fn_call->_ann_range_search_params.toRangeSearchParams();
+ fn_call->_ann_range_search_params.to_range_search_params();
EXPECT_EQ(range_search_params.radius, 10.0f);
std::vector<int> query_array_groud_truth = {1, 2, 3, 4, 5, 6, 7, 20};
std::vector<int> query_array_f32;
@@ -867,11 +870,11 @@ TEST_F(VectorSearchTest, TestEvaluateAnnRangeSearch) {
ASSERT_TRUE(vectorized::VExpr::create_expr_tree(texpr,
range_search_ctx).ok());
ASSERT_TRUE(range_search_ctx->prepare(state.get(), row_desc).ok());
ASSERT_TRUE(range_search_ctx->open(state.get()).ok());
- ASSERT_TRUE(range_search_ctx->prepare_ann_range_search().ok());
-
+ doris::VectorSearchUserParams user_params;
+ ASSERT_TRUE(range_search_ctx->prepare_ann_range_search(user_params).ok());
std::shared_ptr<VectorizedFnCall> fn_call =
std::dynamic_pointer_cast<VectorizedFnCall>(range_search_ctx->root());
-
+ ASSERT_EQ(fn_call->_ann_range_search_params.user_params, user_params);
ASSERT_TRUE(fn_call->_ann_range_search_params.is_ann_range_search == true);
ASSERT_EQ(fn_call->_ann_range_search_params.is_le_or_lt, false);
ASSERT_EQ(fn_call->_ann_range_search_params.src_col_idx, 1);
@@ -897,6 +900,12 @@ TEST_F(VectorSearchTest, TestEvaluateAnnRangeSearch) {
dynamic_cast<doris::vector_search_utils::MockAnnIndexIterator*>(
cid_to_index_iterators[1].get());
+ std::map<std::string, std::string> properties;
+ properties["index_type"] = "hnsw";
+ properties["metric_type"] = "l2_distance";
+ auto pair = vector_search_utils::create_tmp_ann_index_reader(properties);
+ mock_ann_index_iter->_ann_reader = pair.second;
+
// Explain:
// 1. predicate is dist >= 10, so it is not a within range search
// 2. return 10 results
@@ -906,18 +915,11 @@ TEST_F(VectorSearchTest, TestEvaluateAnnRangeSearch) {
}),
testing::_, testing::_))
.WillOnce(testing::Invoke([](const
doris::segment_v2::RangeSearchParams& params,
- const
doris::segment_v2::CustomSearchParams& custom_params,
+ const doris::VectorSearchUserParams&
custom_params,
doris::segment_v2::RangeSearchResult*
result) {
- // size_t num_results = 10;
result->roaring = std::make_shared<roaring::Roaring>();
result->row_ids = nullptr;
result->distance = nullptr;
- // result->row_ids = std::make_unique<std::vector<uint64_t>>();
- // for (size_t i = 0; i < num_results; ++i) {
- // result->roaring->add(i * 10);
- // result->row_ids->push_back(i * 10);
- // }
- // result->distance = std::make_unique<float[]>(10);
return Status::OK();
}));
@@ -960,7 +962,8 @@ TEST_F(VectorSearchTest, TestEvaluateAnnRangeSearch2) {
ASSERT_TRUE(vectorized::VExpr::create_expr_tree(texpr,
range_search_ctx).ok());
ASSERT_TRUE(range_search_ctx->prepare(state.get(), row_desc).ok());
ASSERT_TRUE(range_search_ctx->open(state.get()).ok());
- ASSERT_TRUE(range_search_ctx->prepare_ann_range_search().ok());
+ doris::VectorSearchUserParams user_params;
+ ASSERT_TRUE(range_search_ctx->prepare_ann_range_search(user_params).ok());
std::shared_ptr<VectorizedFnCall> fn_call =
std::dynamic_pointer_cast<VectorizedFnCall>(range_search_ctx->root());
@@ -984,11 +987,15 @@ TEST_F(VectorSearchTest, TestEvaluateAnnRangeSearch2) {
std::vector<std::unique_ptr<segment_v2::ColumnIterator>> column_iterators;
column_iterators.resize(4);
column_iterators[3] =
std::make_unique<doris::segment_v2::VirtualColumnIterator>();
-
roaring::Roaring row_bitmap;
doris::vector_search_utils::MockAnnIndexIterator* mock_ann_index_iter =
dynamic_cast<doris::vector_search_utils::MockAnnIndexIterator*>(
cid_to_index_iterators[1].get());
+ std::map<std::string, std::string> properties;
+ properties["index_type"] = "hnsw";
+ properties["metric_type"] = "l2_distance";
+ auto pair = vector_search_utils::create_tmp_ann_index_reader(properties);
+ mock_ann_index_iter->_ann_reader = pair.second;
// Explain:
// 1. predicate is dist >= 10, so it is not a within range search
@@ -999,7 +1006,7 @@ TEST_F(VectorSearchTest, TestEvaluateAnnRangeSearch2) {
}),
testing::_, testing::_))
.WillOnce(testing::Invoke([](const
doris::segment_v2::RangeSearchParams& params,
- const
doris::segment_v2::CustomSearchParams& custom_params,
+ const doris::VectorSearchUserParams&
custom_params,
doris::segment_v2::RangeSearchResult*
result) {
size_t num_results = 10;
result->roaring = std::make_shared<roaring::Roaring>();
diff --git a/be/test/olap/vector_search/ann_topn_descriptor_test.cpp
b/be/test/olap/vector_search/ann_topn_descriptor_test.cpp
index 7437bd80411..e6bc46c5d97 100644
--- a/be/test/olap/vector_search/ann_topn_descriptor_test.cpp
+++ b/be/test/olap/vector_search/ann_topn_descriptor_test.cpp
@@ -65,8 +65,8 @@ TEST_F(VectorSearchTest, AnnTopNDescriptorConstructor) {
v->set_virtual_column_expr(distanc_calcu_fn_call_ctx->root());
std::shared_ptr<AnnTopNDescriptor> predicate;
- predicate = AnnTopNDescriptor::create_shared(limit, virtual_slot_expr_ctx);
- ASSERT_TRUE(predicate != nullptr) << "AnnTopNDescriptor::create_shared()
failed";
+ predicate = AnnTopNDescriptor::create_shared(true, limit,
virtual_slot_expr_ctx);
+ ASSERT_TRUE(predicate != nullptr) <<
"AnnTopNDescriptor::create_shared(true,) failed";
}
TEST_F(VectorSearchTest, AnnTopNDescriptorPrepare) {
@@ -86,7 +86,7 @@ TEST_F(VectorSearchTest, AnnTopNDescriptorPrepare) {
v->set_virtual_column_expr(distanc_calcu_fn_call_ctx->root());
std::shared_ptr<AnnTopNDescriptor> predicate;
- predicate = AnnTopNDescriptor::create_shared(limit, virtual_slot_expr_ctx);
+ predicate = AnnTopNDescriptor::create_shared(true, limit,
virtual_slot_expr_ctx);
st = predicate->prepare(&_runtime_state, _row_desc);
ASSERT_TRUE(st.ok()) << fmt::format("st: {}, expr {}", st.to_string(),
predicate->get_order_by_expr_ctx()->root()->debug_string());
@@ -111,7 +111,7 @@ TEST_F(VectorSearchTest, AnnTopNDescriptorEvaluateTopN) {
v->set_virtual_column_expr(distanc_calcu_fn_call_ctx->root());
std::shared_ptr<AnnTopNDescriptor> predicate;
- predicate = AnnTopNDescriptor::create_shared(limit, virtual_slot_expr_ctx);
+ predicate = AnnTopNDescriptor::create_shared(true, limit,
virtual_slot_expr_ctx);
st = predicate->prepare(&_runtime_state, _row_desc);
ASSERT_TRUE(st.ok()) << fmt::format("st: {}, expr {}", st.to_string(),
predicate->get_order_by_expr_ctx()->root()->debug_string());
diff --git a/be/test/olap/vector_search/faiss_vector_index_test.cpp
b/be/test/olap/vector_search/faiss_vector_index_test.cpp
index bae49627ef2..684c3bf6cf4 100644
--- a/be/test/olap/vector_search/faiss_vector_index_test.cpp
+++ b/be/test/olap/vector_search/faiss_vector_index_test.cpp
@@ -300,13 +300,13 @@ TEST_F(VectorSearchTest, CompRangeSearch) {
size_t random_n =
std::uniform_int_distribution<>(500, 2000)(gen); // Random
number of vectors
// Step 1: Create and build index
- auto index1 = std::make_unique<FaissVectorIndex>();
+ auto doris_index = std::make_unique<FaissVectorIndex>();
FaissBuildParameter params;
params.d = random_d;
params.m = random_m;
params.index_type = FaissBuildParameter::IndexType::HNSW;
- index1->set_build_params(params);
+ doris_index->set_build_params(params);
const int num_vectors = random_n;
std::vector<std::vector<float>> vectors;
@@ -316,31 +316,20 @@ TEST_F(VectorSearchTest, CompRangeSearch) {
}
std::unique_ptr<faiss::Index> native_index =
std::make_unique<faiss::IndexHNSWFlat>(params.d, params.m);
-
doris::vector_search_utils::add_vectors_to_indexes_serial_mode(index1.get(),
+
doris::vector_search_utils::add_vectors_to_indexes_serial_mode(doris_index.get(),
native_index.get(), vectors);
std::vector<float> query_vec = vectors.front();
-
- std::vector<std::pair<size_t, float>> distances(num_vectors);
- for (int i = 0; i < num_vectors; i++) {
- double sum = 0;
- auto& vec = vectors[i];
- for (int j = 0; j < params.d; j++) {
- accumulate(vec[j], query_vec[j], sum);
- }
- distances[i] = std::make_pair(i, finalize(sum));
- }
- std::sort(distances.begin(), distances.end(),
- [](const auto& a, const auto& b) { return a.second <
b.second; });
-
- float radius = distances[num_vectors / 4].second;
+ const float radius =
doris::vector_search_utils::get_radius_from_matrix(
+ query_vec.data(), params.d, vectors, 0.4f);
HNSWSearchParameters hnsw_params;
hnsw_params.ef_search = 16; // Set efSearch for better accuracy
hnsw_params.roaring = nullptr; // No selector for this test
hnsw_params.is_le_or_lt = true;
IndexSearchResult doris_result;
- std::ignore = index1->range_search(query_vec.data(), radius,
hnsw_params, doris_result);
+ std::ignore =
+ doris_index->range_search(query_vec.data(), radius,
hnsw_params, doris_result);
faiss::SearchParametersHNSW search_params_native;
search_params_native.efSearch = hnsw_params.ef_search;
@@ -575,6 +564,7 @@ TEST_F(VectorSearchTest, RangeSearchWithSelector1) {
true);
ASSERT_EQ(native_results.size(),
doris_search_result.roaring->cardinality());
+
ASSERT_EQ(doris_search_result.distances != nullptr, true);
for (size_t i = 0; i < native_results.size(); i++) {
const size_t rowid = native_results[i].first;
diff --git a/be/test/olap/vector_search/vector_search_utils.cpp
b/be/test/olap/vector_search/vector_search_utils.cpp
index 6ee03666bbf..49c60296a8e 100644
--- a/be/test/olap/vector_search/vector_search_utils.cpp
+++ b/be/test/olap/vector_search/vector_search_utils.cpp
@@ -29,6 +29,7 @@ namespace doris::vector_search_utils {
static void accumulate(double x, double y, double& sum) {
sum += (x - y) * (x - y);
}
+
static double finalize(double sum) {
return sqrt(sum);
}
@@ -246,4 +247,14 @@ float get_radius_from_matrix(const float* vector, int dim,
return radius;
}
+
+std::pair<std::unique_ptr<MockTabletIndex>,
std::shared_ptr<segment_v2::AnnIndexReader>>
+create_tmp_ann_index_reader(std::map<std::string, std::string> properties) { // Test helper: builds an AnnIndexReader over a mock tablet index and a mock index file reader.
+ auto mock_tablet_index = std::make_unique<MockTabletIndex>();
+ mock_tablet_index->_properties = properties; // index properties supplied by the test (callers in this change set index_type / metric_type)
+ auto mock_index_file_reader = std::make_shared<MockIndexFileReader>();
+ auto ann_reader =
std::make_shared<segment_v2::AnnIndexReader>(mock_tablet_index.get(),
+
mock_index_file_reader); // the reader is given a non-owning raw pointer to the mock index
+ return std::make_pair(std::move(mock_tablet_index), ann_reader); // ownership of the mock index is returned too; presumably the caller must keep it alive while using the reader -- confirm
+}
} // namespace doris::vector_search_utils
\ No newline at end of file
diff --git a/be/test/olap/vector_search/vector_search_utils.h
b/be/test/olap/vector_search/vector_search_utils.h
index 8fd79997819..bd4b02ad0a7 100644
--- a/be/test/olap/vector_search/vector_search_utils.h
+++ b/be/test/olap/vector_search/vector_search_utils.h
@@ -27,8 +27,9 @@
#include <thrift/protocol/TDebugProtocol.h>
#include <thrift/protocol/TJSONProtocol.h>
-#include <iostream>
#include <memory>
+#include <string>
+#include <utility>
#include "common/object_pool.h"
#include "olap/rowset/segment_v2/ann_index_iterator.h"
@@ -150,42 +151,22 @@ public:
MOCK_METHOD(Status, read_from_index, (const doris::segment_v2::IndexParam&
param), (override));
MOCK_METHOD(Status, range_search,
(const segment_v2::RangeSearchParams& params,
- const segment_v2::CustomSearchParams& custom_params,
+ const VectorSearchUserParams& custom_params,
segment_v2::RangeSearchResult* result),
(override));
private:
io::IOContext _io_ctx_mock;
};
+
+class MockAnnIndexReader : public doris::segment_v2::AnnIndexReader {};
+
+std::pair<std::unique_ptr<MockTabletIndex>,
std::shared_ptr<segment_v2::AnnIndexReader>>
+create_tmp_ann_index_reader(std::map<std::string, std::string> properties);
+
} // namespace doris::vector_search_utils
namespace doris::vectorized {
-template <typename T>
-T read_from_json(const std::string& json_str) {
- auto memBufferIn =
std::make_shared<apache::thrift::transport::TMemoryBuffer>(
- reinterpret_cast<uint8_t*>(const_cast<char*>(json_str.data())),
- static_cast<uint32_t>(json_str.size()));
- auto jsonProtocolIn =
std::make_shared<apache::thrift::protocol::TJSONProtocol>(memBufferIn);
- T params;
- params.read(jsonProtocolIn.get());
- return params;
-}
-
-template <typename T>
-void write_to_json(const std::string& path, std::string name, const T& expr) {
- auto memBuffer =
std::make_shared<apache::thrift::transport::TMemoryBuffer>();
- auto jsonProtocol =
std::make_shared<apache::thrift::protocol::TJSONProtocol>(memBuffer);
-
- expr.write(jsonProtocol.get());
- uint8_t* buf;
- uint32_t size;
- memBuffer->getBuffer(&buf, &size);
- std::string file_path = fmt::format("{}/{}.json", path, name);
- std::ofstream ofs(file_path, std::ios::binary);
- ofs.write(reinterpret_cast<const char*>(buf), size);
- ofs.close();
- std::cout << fmt::format("Serialized JSON written to {}\n", file_path);
-}
class VectorSearchTest : public ::testing::Test {
public:
diff --git a/be/test/olap/vector_search/virtual_column_iterator_test.cpp
b/be/test/olap/vector_search/virtual_column_iterator_test.cpp
index 22f33db3bed..d860a7a7ad8 100644
--- a/be/test/olap/vector_search/virtual_column_iterator_test.cpp
+++ b/be/test/olap/vector_search/virtual_column_iterator_test.cpp
@@ -45,7 +45,7 @@ TEST_F(VectorSearchTest, TestDefaultConstructor) {
}
// Test with a materialized int32_t column
-TEST_F(VectorSearchTest, TestWithint32_tColumn) {
+TEST_F(VectorSearchTest, ReadByRowIdsint32_tColumn) {
VirtualColumnIterator iterator;
// Create a materialized int32_t column with values [10, 20, 30, 40, 50]
@@ -78,7 +78,7 @@ TEST_F(VectorSearchTest, TestWithint32_tColumn) {
}
// Test with a String column
-TEST_F(VectorSearchTest, TestWithStringColumn) {
+TEST_F(VectorSearchTest, ReadByRowIdsStringColumn) {
VirtualColumnIterator iterator;
// Create a materialized String column
@@ -114,7 +114,7 @@ TEST_F(VectorSearchTest, TestWithStringColumn) {
}
// Test with empty rowids array
-TEST_F(VectorSearchTest, TestEmptyRowIds) {
+TEST_F(VectorSearchTest, ReadByRowIdsEmptyRowIds) {
VirtualColumnIterator iterator;
// Create a materialized int32_t column with values [10, 20, 30, 40, 50]
@@ -180,7 +180,7 @@ TEST_F(VectorSearchTest, TestLargeRowset) {
}
}
-TEST_F(VectorSearchTest, TestNoContinueRowIds) {
+TEST_F(VectorSearchTest, ReadByRowIdsNoContinueRowIds) {
// Create a column with 1000 values (0-999)
auto column = ColumnVector<int32_t>::create();
auto labels = std::make_unique<std::vector<uint64_t>>();
@@ -276,4 +276,72 @@ TEST_F(VectorSearchTest, TestNoContinueRowIds) {
}
}
+TEST_F(VectorSearchTest, NextBatchTest1) {
+ VirtualColumnIterator iterator;
+
+ // Build an int32 column with 100 rows holding the values 0..99
+ auto int_column = vectorized::ColumnVector<int32_t>::create();
+ auto labels = std::make_unique<std::vector<uint64_t>>();
+ for (int i = 0; i < 100; ++i) {
+ int_column->insert(i);
+ labels->push_back(i);
+ }
+ iterator.prepare_materialization(std::move(int_column), std::move(labels));
+
+ // 1. Seek to row 10, then read 10 rows with next_batch
+ {
+ vectorized::MutableColumnPtr dst =
vectorized::ColumnVector<int32_t>::create();
+ Status st = iterator.seek_to_ordinal(10);
+ ASSERT_TRUE(st.ok());
+ size_t rows_read = 10;
+ bool has_null = false;
+ st = iterator.next_batch(&rows_read, dst, &has_null);
+ ASSERT_TRUE(st.ok());
+ ASSERT_EQ(rows_read, 10);
+ ASSERT_EQ(dst->size(), 10);
+ for (int i = 0; i < 10; ++i) {
+ ASSERT_EQ(dst->get_int(i), 10 + i);
+ }
+ }
+
+ // 2. Seek to row 85, then read 10 rows with next_batch (the original note said only 5 rows remain, but the assertions below expect all 10 rows, 85..94)
+ {
+ vectorized::MutableColumnPtr dst =
vectorized::ColumnVector<int32_t>::create();
+ Status st = iterator.seek_to_ordinal(85);
+ ASSERT_TRUE(st.ok());
+ size_t rows_read = 10;
+ bool has_null = false;
+ st = iterator.next_batch(&rows_read, dst, &has_null);
+ ASSERT_TRUE(st.ok());
+ ASSERT_EQ(rows_read, 10);
+ ASSERT_EQ(dst->size(), 10);
+ for (int i = 0; i < 10; ++i) {
+ EXPECT_EQ(dst->get_int(i), 85 + i);
+ }
+ }
+
+ // 3. Seek to row 0, then read all 100 rows with next_batch
+ {
+ vectorized::MutableColumnPtr dst =
vectorized::ColumnVector<int32_t>::create();
+ Status st = iterator.seek_to_ordinal(0);
+ ASSERT_TRUE(st.ok());
+ size_t rows_read = 100;
+ bool has_null = false;
+ st = iterator.next_batch(&rows_read, dst, &has_null);
+ ASSERT_TRUE(st.ok());
+ ASSERT_EQ(rows_read, 100);
+ ASSERT_EQ(dst->size(), 100);
+ for (int i = 0; i < 100; ++i) {
+ ASSERT_EQ(dst->get_int(i), i);
+ }
+ }
+
+ // 4. Seeking to an out-of-range position (e.g. 100) should fail
+ {
+ vectorized::MutableColumnPtr dst =
vectorized::ColumnVector<int32_t>::create();
+ Status st = iterator.seek_to_ordinal(100);
+ ASSERT_EQ(st.ok(), false);
+ }
+}
+
} // namespace doris::vectorized
\ No newline at end of file
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownVirtualColumnsIntoOlapScan.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownVirtualColumnsIntoOlapScan.java
index 536f3e79f35..1bfd71e920a 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownVirtualColumnsIntoOlapScan.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownVirtualColumnsIntoOlapScan.java
@@ -69,23 +69,24 @@ public class PushDownVirtualColumnsIntoOlapScan implements
RewriteRuleFactory {
// 3. replace filter
// 4. replace project
Map<Expression, Expression> replaceMap = Maps.newHashMap();
+ ImmutableList.Builder<NamedExpression> virtualColumnsBuilder =
ImmutableList.builder();
for (Expression conjunct : filter.getConjuncts()) {
- Set<Expression> l2Distances =
conjunct.collect(L2Distance.class::isInstance);
- for (Expression l2Distance : l2Distances) {
- if (replaceMap.containsKey(l2Distance)) {
+ // Set<Expression> l2Distances =
conjunct.collect(L2Distance.class::isInstance);
+ // Set<Expression> innerProducts =
conjunct.collect(InnerProduct.class::isInstance);
+ Set<Expression> distanceFunctions = conjunct.collect(
+ e -> e instanceof L2Distance || e instanceof InnerProduct);
+ for (Expression distanceFunction : distanceFunctions) {
+ if (replaceMap.containsKey(distanceFunction)) {
continue;
}
- Alias alias = new Alias(l2Distance);
- replaceMap.put(l2Distance, alias.toSlot());
+ Alias alias = new Alias(distanceFunction);
+ replaceMap.put(distanceFunction, alias.toSlot());
+ virtualColumnsBuilder.add(alias);
}
}
if (replaceMap.isEmpty()) {
return null;
}
- ImmutableList.Builder<NamedExpression> virtualColumnsBuilder =
ImmutableList.builder();
- for (Expression expression : replaceMap.values()) {
- virtualColumnsBuilder.add((NamedExpression) expression);
- }
logicalOlapScan =
logicalOlapScan.withVirtualColumns(virtualColumnsBuilder.build());
Set<Expression> conjuncts =
ExpressionUtils.replace(filter.getConjuncts(), replaceMap);
Plan plan = filter.withConjunctsAndChild(conjuncts, logicalOlapScan);
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/info/IndexDefinition.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/info/IndexDefinition.java
index bc8e533e422..1aeb0770be2 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/info/IndexDefinition.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/info/IndexDefinition.java
@@ -136,9 +136,22 @@ public class IndexDefinition {
public void checkColumn(ColumnDefinition column, KeysType keysType,
boolean enableUniqueKeyMergeOnWrite,
TInvertedIndexFileStorageFormat invertedIndexFileStorageFormat)
throws AnalysisException {
+ if (indexType == IndexType.ANN) {
+ String indexColName = column.getName();
+ caseSensitivityCols.add(indexColName);
+ DataType colType = column.getType();
+ if (!colType.isArrayType()) {
+ throw new AnalysisException("ANN index column must be array
type, invalid index: " + name);
+ }
+ DataType itemType = ((ArrayType) colType).getItemType();
+ if (!itemType.isFloatType()) {
+ throw new AnalysisException("ANN index column item type must
be float type, invalid index: " + name);
+ }
+ return;
+ }
+
if (indexType == IndexType.BITMAP || indexType == IndexType.INVERTED
- || indexType == IndexType.BLOOMFILTER || indexType ==
IndexType.NGRAM_BF
- || indexType == IndexType.ANN) {
+ || indexType == IndexType.BLOOMFILTER || indexType ==
IndexType.NGRAM_BF) {
String indexColName = column.getName();
caseSensitivityCols.add(indexColName);
DataType colType = column.getType();
@@ -148,10 +161,6 @@ public class IndexDefinition {
+ " index. " + "invalid index: " + name);
}
- if (indexType == IndexType.ANN && !colType.isArrayType()) {
- throw new AnalysisException("Ann index column must be array
type, invalid index: " + name);
- }
-
// In inverted index format v1, each subcolumn of a variant has
its own index file, leading to high IOPS.
// when the subcolumn type changes, it may result in missing
files, causing link file failure.
// There are two cases in which the inverted index format v1 is
not supported:
diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java
b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java
index 69a453a3b16..ad1aee3a04c 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java
@@ -739,6 +739,10 @@ public class SessionVariable implements Serializable,
Writable {
public static final String SQL_CONVERTOR_CONFIG = "sql_convertor_config";
+ public static final String HNSW_EF_SEARCH = "hnsw_ef_search";
+ public static final String HNSW_CHECK_RELATIVE_DISTANCE =
"hnsw_check_relative_distance";
+ public static final String HNSW_BOUNDED_QUEUE = "hnsw_bounded_queue";
+
/**
* If set false, user couldn't submit analyze SQL and FE won't allocate
any related resources.
*/
@@ -2611,6 +2615,22 @@ public class SessionVariable implements Serializable,
Writable {
return enableESParallelScroll;
}
+ @VariableMgr.VarAttr(name = HNSW_EF_SEARCH, needForward = true,
+ description = {"HNSW索引的EF搜索参数,控制搜索的精度和速度",
+ "HNSW index EF search parameter, controls the precision
and speed of the search"})
+ public int hnswEFSearch = 16;
+
+ @VariableMgr.VarAttr(name = HNSW_CHECK_RELATIVE_DISTANCE, needForward =
true,
+ description = {"是否启用相对距离检查机制,以提升HNSW搜索的准确性",
+ "Enable relative distance checking to improve HNSW
search accuracy"})
+ public boolean hnswCheckRelativeDistance = true;
+
+ @VariableMgr.VarAttr(name = HNSW_BOUNDED_QUEUE, needForward = true,
+ description = {"是否使用有界优先队列来优化HNSW的搜索性能",
+ "Whether to use a bounded priority queue to optimize
HNSW search performance"})
+ public boolean hnswBoundedQueue = true;
+
+
// If this fe is in fuzzy mode, then will use initFuzzyModeVariables to
generate some variables,
// not the default value set in the code.
@SuppressWarnings("checkstyle:Indentation")
@@ -4218,6 +4238,11 @@ public class SessionVariable implements Serializable,
Writable {
tResult.setMinimumOperatorMemoryRequiredKb(minimumOperatorMemoryRequiredKB);
tResult.setExchangeMultiBlocksByteSize(exchangeMultiBlocksByteSize);
+
+ tResult.setHnswEfSearch(hnswEFSearch);
+ tResult.setHnswCheckRelativeDistance(hnswCheckRelativeDistance);
+ tResult.setHnswBoundedQueue(hnswBoundedQueue);
+
return tResult;
}
diff --git a/gensrc/thrift/PaloInternalService.thrift
b/gensrc/thrift/PaloInternalService.thrift
index 81e4d1f877c..73c268965e8 100644
--- a/gensrc/thrift/PaloInternalService.thrift
+++ b/gensrc/thrift/PaloInternalService.thrift
@@ -395,6 +395,10 @@ struct TQueryOptions {
164: optional bool check_orc_init_sargs_success = false
165: optional i32 exchange_multi_blocks_byte_size = 262144
+ 166: optional i32 hnsw_ef_search = 16;
+ 167: optional bool hnsw_check_relative_distance = true;
+ 168: optional bool hnsw_bounded_queue = true;
+
// For cloud, to control if the content would be written into file cache
// In write path, to control if the content would be written into file cache.
// In read path, read from file cache or remote storage when execute query.
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]