linrrzqqq commented on code in PR #60412: URL: https://github.com/apache/doris/pull/60412#discussion_r2935852440
########## be/src/exprs/function/function_levenshtein.cpp: ########## @@ -0,0 +1,209 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include <algorithm> +#include <cstring> +#include <vector> + +#include "common/status.h" +#include "core/data_type/data_type_number.h" +#include "core/string_ref.h" +#include "exprs/function/function_totype.h" +#include "exprs/function/simple_function_factory.h" +#include "util/simd/vstring_function.h" + +namespace doris { +#include "common/compile_check_begin.h" + +struct NameLevenshtein { + static constexpr auto name = "levenshtein"; +}; + +template <typename LeftDataType, typename RightDataType> +struct LevenshteinImpl { + using ResultDataType = DataTypeInt32; + using ResultPaddedPODArray = PaddedPODArray<Int32>; + + static Status vector_vector(const ColumnString::Chars& ldata, + const ColumnString::Offsets& loffsets, + const ColumnString::Chars& rdata, + const ColumnString::Offsets& roffsets, ResultPaddedPODArray& res) { + DCHECK_EQ(loffsets.size(), roffsets.size()); + + const size_t size = loffsets.size(); + res.resize(size); + for (size_t i = 0; i < size; ++i) { + res[i] = levenshtein_distance(string_ref_at(ldata, loffsets, i), + string_ref_at(rdata, roffsets, i)); + } + 
return Status::OK(); + } + + static Status vector_scalar(const ColumnString::Chars& ldata, + const ColumnString::Offsets& loffsets, const StringRef& rdata, + ResultPaddedPODArray& res) { + const size_t size = loffsets.size(); + res.resize(size); + for (size_t i = 0; i < size; ++i) { + res[i] = levenshtein_distance(string_ref_at(ldata, loffsets, i), rdata); + } + return Status::OK(); + } + + static Status scalar_vector(const StringRef& ldata, const ColumnString::Chars& rdata, + const ColumnString::Offsets& roffsets, ResultPaddedPODArray& res) { + const size_t size = roffsets.size(); + res.resize(size); + for (size_t i = 0; i < size; ++i) { + res[i] = levenshtein_distance(ldata, string_ref_at(rdata, roffsets, i)); + } + return Status::OK(); + } + +private: + static StringRef string_ref_at(const ColumnString::Chars& data, + const ColumnString::Offsets& offsets, size_t i) { + DCHECK_LT(i, offsets.size()); + const size_t begin = (i == 0) ? 0 : offsets[i - 1]; + const size_t end = offsets[i]; + if (end <= begin || end > data.size()) { + return StringRef("", 0); + } + + size_t str_size = end - begin; + if (str_size > 0 && data[end - 1] == '\0') { + --str_size; + } + return StringRef(reinterpret_cast<const char*>(data.data() + begin), str_size); + } + + static void utf8_char_offsets(const StringRef& ref, std::vector<size_t>& offsets) { + offsets.clear(); + offsets.reserve(ref.size); + simd::VStringFunctions::get_char_len(ref.data, ref.size, offsets); + } + + static bool utf8_char_equal(const StringRef& left, size_t left_off, size_t left_next, + const StringRef& right, size_t right_off, size_t right_next) { + const size_t left_len = left_next - left_off; + const size_t right_len = right_next - right_off; + return left_len == right_len && + std::memcmp(left.data + left_off, right.data + right_off, left_len) == 0; + } + + static Int32 levenshtein_distance_ascii(const StringRef& left, const StringRef& right) { + const StringRef* left_ref = &left; + const StringRef* right_ref = 
&right; + size_t m = left.size; + size_t n = right.size; + + if (n > m) { + std::swap(left_ref, right_ref); + std::swap(m, n); + } + + std::vector<Int32> prev(n + 1); + std::vector<Int32> curr(n + 1); + for (size_t j = 0; j <= n; ++j) { + prev[j] = static_cast<Int32>(j); + } + + for (size_t i = 1; i <= m; ++i) { + curr[0] = static_cast<Int32>(i); + const char left_char = left_ref->data[i - 1]; + + for (size_t j = 1; j <= n; ++j) { + const Int32 cost = left_char == right_ref->data[j - 1] ? 0 : 1; + const Int32 insert_cost = curr[j - 1] + 1; + const Int32 delete_cost = prev[j] + 1; + const Int32 replace_cost = prev[j - 1] + cost; + curr[j] = std::min({insert_cost, delete_cost, replace_cost}); Review Comment: Use `std::min(std::min(a, b), c)` instead of `std::min({a, b, c})` to avoid constructing a temporary `std::initializer_list`. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
