xishuaidelin commented on code in PR #137:
URL:
https://github.com/apache/flink-connector-elasticsearch/pull/137#discussion_r2903468997
##########
flink-connector-elasticsearch8/src/main/java/org/apache/flink/connector/elasticsearch/table/search/ElasticsearchRowDataVectorSearchFunction.java:
##########
@@ -0,0 +1,165 @@
+package org.apache.flink.connector.elasticsearch.table.search;
+
+import org.apache.flink.api.common.serialization.DeserializationSchema;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.connector.elasticsearch.sink.NetworkConfig;
+import org.apache.flink.table.data.GenericRowData;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.data.utils.JoinedRowData;
+import org.apache.flink.table.functions.FunctionContext;
+import org.apache.flink.table.functions.VectorSearchFunction;
+import org.apache.flink.util.FlinkRuntimeException;
+
+import co.elastic.clients.elasticsearch.ElasticsearchClient;
+import co.elastic.clients.elasticsearch.core.SearchRequest;
+import co.elastic.clients.elasticsearch.core.SearchResponse;
+import co.elastic.clients.elasticsearch.core.search.Hit;
+import co.elastic.clients.json.JsonData;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+
+import static org.apache.flink.util.Preconditions.checkNotNull;
+
+/** The {@link VectorSearchFunction} implementation for Elasticsearch. */
+public class ElasticsearchRowDataVectorSearchFunction extends
VectorSearchFunction {
+ private static final Logger LOG =
+
LoggerFactory.getLogger(ElasticsearchRowDataVectorSearchFunction.class);
+ private static final long serialVersionUID = 1L;
+
+ private final DeserializationSchema<RowData> deserializationSchema;
+
+ private final String index;
+
+ private final String[] producedNames;
+ private final int maxRetryTimes;
+ private final int numCandidates;
+ private final String searchColumn;
+
+ private final NetworkConfig networkConfig;
+
+ private transient ElasticsearchClient client;
+
+ public ElasticsearchRowDataVectorSearchFunction(
+ DeserializationSchema<RowData> deserializationSchema,
+ int maxRetryTimes,
+ int numCandidates,
+ String index,
+ String searchColumn,
+ String[] producedNames,
+ NetworkConfig networkConfig) {
+
+ checkNotNull(deserializationSchema, "No DeserializationSchema
supplied.");
+ checkNotNull(maxRetryTimes, "No maxRetryTimes supplied.");
+ checkNotNull(producedNames, "No fieldNames supplied.");
+ checkNotNull(networkConfig, "No networkConfig supplied.");
+
+ this.deserializationSchema = deserializationSchema;
+ this.maxRetryTimes = maxRetryTimes;
+ this.numCandidates = numCandidates;
+ this.index = index;
+ this.searchColumn = searchColumn;
+ this.producedNames = producedNames;
+ this.networkConfig = networkConfig;
+ }
+
+ @Override
+ public void open(FunctionContext context) throws Exception {
+ this.client = networkConfig.createEsSyncClient();
+
+ deserializationSchema.open(null);
+ }
+
+ @Override
+ public Collection<RowData> vectorSearch(int topK, RowData features) throws
IOException {
+ List<Float> queryVector = new ArrayList<>();
+ for (float feature : features.getArray(0).toFloatArray()) {
+ queryVector.add(feature);
+ }
+ SearchRequest.Builder builder =
+ new SearchRequest.Builder()
+ .index(index)
+ .knn(kb ->
kb.field(searchColumn).numCandidates(numCandidates).queryVector(queryVector).k(topK))
+ .source(src -> src.filter(f ->
f.includes(Arrays.asList(producedNames))));
+ SearchRequest request = builder.build();
+
+ for (int retry = 0; retry <= maxRetryTimes; retry++) {
+ try {
+ ArrayList<RowData> rows = new ArrayList<>();
+ Tuple2<String, SearchResult[]> searchResponse = search(client,
request);
+
+ if (searchResponse.f1.length > 0) {
+ for (SearchResult result : searchResponse.f1) {
+ String source = result.source;
+ RowData row = parseSearchResult(source);
+ GenericRowData scoreData = new GenericRowData(1);
+ scoreData.setField(0, result.score);
+ if (row != null) {
+ rows.add(new JoinedRowData(row, scoreData));
+ }
+ }
+ rows.trimToSize();
+ return rows;
+ }
+ } catch (IOException e) {
+ LOG.error(String.format("Elasticsearch search error, retry
times = %d", retry), e);
+ if (retry >= maxRetryTimes) {
+ throw new FlinkRuntimeException("Execution of
Elasticsearch search failed.", e);
+ }
+ try {
+ Thread.sleep(1000L * retry);
+ } catch (InterruptedException e1) {
+ LOG.warn(
+ "Interrupted while waiting to retry failed
elasticsearch search, aborting");
+ throw new FlinkRuntimeException(e1);
+ }
+ }
+ }
+ return Collections.emptyList();
+ }
+
+ private RowData parseSearchResult(String result) {
+ RowData row = null;
+ try {
+ row = deserializationSchema.deserialize(result.getBytes());
+ } catch (IOException e) {
+ LOG.error("Deserialize search hit failed: " + e.getMessage());
+ }
+
+ return row;
+ }
+
+ private Tuple2<String, SearchResult[]> search(
+ ElasticsearchClient client, SearchRequest searchRequest) throws
IOException {
+ SearchResponse<JsonData> searchResponse = client.search(searchRequest,
JsonData.class);
+ List<Hit<JsonData>> searchHits = searchResponse.hits().hits();
+
+ return new Tuple2<>(
+ searchResponse.scrollId(),
+ searchHits.stream()
+ .map(hit -> {
+ if (hit.source() != null) {
+ return new
SearchResult(hit.source().toJson().toString(), hit.score());
+ } else {
+ return new SearchResult(null, hit.score());
Review Comment:
Maybe we should filter out hits whose source is null, to avoid an NPE in the
parseSearchResult method:
`searchHits.stream()
.filter(hit -> hit.source() != null)
.map(hit -> new SearchResult(hit.source().toJson().toString(),
hit.score()))
.toArray(SearchResult[]::new);`
##########
flink-connector-elasticsearch8/src/main/java/org/apache/flink/connector/elasticsearch/table/search/ElasticsearchRowDataVectorSearchFunction.java:
##########
@@ -0,0 +1,165 @@
+package org.apache.flink.connector.elasticsearch.table.search;
+
+import org.apache.flink.api.common.serialization.DeserializationSchema;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.connector.elasticsearch.sink.NetworkConfig;
+import org.apache.flink.table.data.GenericRowData;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.data.utils.JoinedRowData;
+import org.apache.flink.table.functions.FunctionContext;
+import org.apache.flink.table.functions.VectorSearchFunction;
+import org.apache.flink.util.FlinkRuntimeException;
+
+import co.elastic.clients.elasticsearch.ElasticsearchClient;
+import co.elastic.clients.elasticsearch.core.SearchRequest;
+import co.elastic.clients.elasticsearch.core.SearchResponse;
+import co.elastic.clients.elasticsearch.core.search.Hit;
+import co.elastic.clients.json.JsonData;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+
+import static org.apache.flink.util.Preconditions.checkNotNull;
+
+/** The {@link VectorSearchFunction} implementation for Elasticsearch. */
+public class ElasticsearchRowDataVectorSearchFunction extends
VectorSearchFunction {
Review Comment:
Could we introduce an AbstractElasticsearchVectorSearchFunction base class
for both the ES7 and ES8 implementations? This would allow us to extract and
reuse the shared logic.
##########
flink-connector-elasticsearch8/src/main/java/org/apache/flink/connector/elasticsearch/table/search/ElasticsearchRowDataVectorSearchFunction.java:
##########
@@ -0,0 +1,165 @@
+package org.apache.flink.connector.elasticsearch.table.search;
+
+import org.apache.flink.api.common.serialization.DeserializationSchema;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.connector.elasticsearch.sink.NetworkConfig;
+import org.apache.flink.table.data.GenericRowData;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.data.utils.JoinedRowData;
+import org.apache.flink.table.functions.FunctionContext;
+import org.apache.flink.table.functions.VectorSearchFunction;
+import org.apache.flink.util.FlinkRuntimeException;
+
+import co.elastic.clients.elasticsearch.ElasticsearchClient;
+import co.elastic.clients.elasticsearch.core.SearchRequest;
+import co.elastic.clients.elasticsearch.core.SearchResponse;
+import co.elastic.clients.elasticsearch.core.search.Hit;
+import co.elastic.clients.json.JsonData;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+
+import static org.apache.flink.util.Preconditions.checkNotNull;
+
+/** The {@link VectorSearchFunction} implementation for Elasticsearch. */
+public class ElasticsearchRowDataVectorSearchFunction extends
VectorSearchFunction {
+ private static final Logger LOG =
+
LoggerFactory.getLogger(ElasticsearchRowDataVectorSearchFunction.class);
+ private static final long serialVersionUID = 1L;
+
+ private final DeserializationSchema<RowData> deserializationSchema;
+
+ private final String index;
+
+ private final String[] producedNames;
+ private final int maxRetryTimes;
+ private final int numCandidates;
+ private final String searchColumn;
+
+ private final NetworkConfig networkConfig;
+
+ private transient ElasticsearchClient client;
+
+ public ElasticsearchRowDataVectorSearchFunction(
+ DeserializationSchema<RowData> deserializationSchema,
+ int maxRetryTimes,
+ int numCandidates,
+ String index,
+ String searchColumn,
+ String[] producedNames,
+ NetworkConfig networkConfig) {
+
+ checkNotNull(deserializationSchema, "No DeserializationSchema
supplied.");
+ checkNotNull(maxRetryTimes, "No maxRetryTimes supplied.");
Review Comment:
No need to null-check a primitive `int` — `checkNotNull` is redundant here.
##########
flink-connector-elasticsearch7/src/main/java/org/apache/flink/connector/elasticsearch/table/search/ElasticsearchRowDataVectorSearchFunction.java:
##########
@@ -0,0 +1,189 @@
+package org.apache.flink.connector.elasticsearch.table.search;
+
+import org.apache.flink.api.common.serialization.DeserializationSchema;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.connector.elasticsearch.ElasticsearchApiCallBridge;
+import org.apache.flink.connector.elasticsearch.NetworkClientConfig;
+import org.apache.flink.table.data.GenericRowData;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.data.utils.JoinedRowData;
+import org.apache.flink.table.functions.FunctionContext;
+import org.apache.flink.table.functions.VectorSearchFunction;
+import org.apache.flink.util.FlinkRuntimeException;
+
+import org.apache.http.HttpHost;
+import org.elasticsearch.action.search.SearchRequest;
+import org.elasticsearch.action.search.SearchResponse;
+import org.elasticsearch.client.RequestOptions;
+import org.elasticsearch.client.RestHighLevelClient;
+import org.elasticsearch.index.query.MatchAllQueryBuilder;
+import org.elasticsearch.index.query.functionscore.ScriptScoreQueryBuilder;
+import org.elasticsearch.script.Script;
+import org.elasticsearch.script.ScriptType;
+import org.elasticsearch.search.SearchHit;
+import org.elasticsearch.search.builder.SearchSourceBuilder;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Stream;
+
+import static org.apache.flink.util.Preconditions.checkNotNull;
+
+/** The {@link VectorSearchFunction} implementation for Elasticsearch. */
+public class ElasticsearchRowDataVectorSearchFunction extends
VectorSearchFunction {
+ private static final Logger LOG =
+
LoggerFactory.getLogger(ElasticsearchRowDataVectorSearchFunction.class);
+ private static final long serialVersionUID = 1L;
+ private static final String QUERY_VECTOR = "query_vector";
+
+ private final DeserializationSchema<RowData> deserializationSchema;
+
+ private final String index;
+
+ private final String[] producedNames;
+ private final int maxRetryTimes;
+ private final SearchMetric searchMetric;
+ private SearchRequest searchRequest;
+ private SearchSourceBuilder searchSourceBuilder;
+
+ private final ElasticsearchApiCallBridge<RestHighLevelClient> callBridge;
+ private final NetworkClientConfig networkClientConfig;
+ private final List<HttpHost> hosts;
+ private final String scriptScore;
+
+ private transient RestHighLevelClient client;
+
+ public ElasticsearchRowDataVectorSearchFunction(
+ DeserializationSchema<RowData> deserializationSchema,
+ int maxRetryTimes,
+ SearchMetric searchMetric,
+ String index,
+ String searchColumn,
+ String[] producedNames,
+ List<HttpHost> hosts,
+ NetworkClientConfig networkClientConfig,
+ ElasticsearchApiCallBridge<RestHighLevelClient> callBridge) {
+
+ checkNotNull(deserializationSchema, "No DeserializationSchema
supplied.");
+ checkNotNull(maxRetryTimes, "No maxRetryTimes supplied.");
Review Comment:
No need to null-check a primitive `int` — `checkNotNull` is redundant here.
##########
flink-connector-elasticsearch7/src/main/java/org/apache/flink/connector/elasticsearch/table/search/ElasticsearchRowDataVectorSearchFunction.java:
##########
@@ -0,0 +1,189 @@
+package org.apache.flink.connector.elasticsearch.table.search;
+
+import org.apache.flink.api.common.serialization.DeserializationSchema;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.connector.elasticsearch.ElasticsearchApiCallBridge;
+import org.apache.flink.connector.elasticsearch.NetworkClientConfig;
+import org.apache.flink.table.data.GenericRowData;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.data.utils.JoinedRowData;
+import org.apache.flink.table.functions.FunctionContext;
+import org.apache.flink.table.functions.VectorSearchFunction;
+import org.apache.flink.util.FlinkRuntimeException;
+
+import org.apache.http.HttpHost;
+import org.elasticsearch.action.search.SearchRequest;
+import org.elasticsearch.action.search.SearchResponse;
+import org.elasticsearch.client.RequestOptions;
+import org.elasticsearch.client.RestHighLevelClient;
+import org.elasticsearch.index.query.MatchAllQueryBuilder;
+import org.elasticsearch.index.query.functionscore.ScriptScoreQueryBuilder;
+import org.elasticsearch.script.Script;
+import org.elasticsearch.script.ScriptType;
+import org.elasticsearch.search.SearchHit;
+import org.elasticsearch.search.builder.SearchSourceBuilder;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Stream;
+
+import static org.apache.flink.util.Preconditions.checkNotNull;
+
+/** The {@link VectorSearchFunction} implementation for Elasticsearch. */
+public class ElasticsearchRowDataVectorSearchFunction extends
VectorSearchFunction {
+ private static final Logger LOG =
+
LoggerFactory.getLogger(ElasticsearchRowDataVectorSearchFunction.class);
+ private static final long serialVersionUID = 1L;
+ private static final String QUERY_VECTOR = "query_vector";
+
+ private final DeserializationSchema<RowData> deserializationSchema;
+
+ private final String index;
+
+ private final String[] producedNames;
+ private final int maxRetryTimes;
+ private final SearchMetric searchMetric;
+ private SearchRequest searchRequest;
+ private SearchSourceBuilder searchSourceBuilder;
+
+ private final ElasticsearchApiCallBridge<RestHighLevelClient> callBridge;
+ private final NetworkClientConfig networkClientConfig;
Review Comment:
This function in ES8 uses NetworkConfig here. Are there any differences that
require us to use a different config?
##########
flink-connector-elasticsearch7/src/main/java/org/apache/flink/connector/elasticsearch/table/Elasticsearch7DynamicSource.java:
##########
@@ -0,0 +1,111 @@
+package org.apache.flink.connector.elasticsearch.table;
+
+import org.apache.flink.api.common.serialization.DeserializationSchema;
+import org.apache.flink.connector.elasticsearch.ElasticsearchApiCallBridge;
+import org.apache.flink.connector.elasticsearch.NetworkClientConfig;
+import
org.apache.flink.connector.elasticsearch.table.search.ElasticsearchRowDataVectorSearchFunction;
+import org.apache.flink.table.api.ValidationException;
+import org.apache.flink.table.connector.format.DecodingFormat;
+import org.apache.flink.table.connector.source.DynamicTableSource;
+import org.apache.flink.table.connector.source.VectorSearchTableSource;
+import org.apache.flink.table.connector.source.lookup.cache.LookupCache;
+import
org.apache.flink.table.connector.source.search.VectorSearchFunctionProvider;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.logical.ArrayType;
+import org.apache.flink.table.types.logical.LogicalTypeRoot;
+import org.apache.flink.table.types.logical.RowType;
+
+import org.elasticsearch.client.RestHighLevelClient;
+
+import javax.annotation.Nullable;
+
+/**
+ * A {@link DynamicTableSource} that describes how to create a {@link
Elasticsearch7DynamicSource}
+ * from a logical description.
+ */
+public class Elasticsearch7DynamicSource extends ElasticsearchDynamicSource
+ implements VectorSearchTableSource {
+
+ public Elasticsearch7DynamicSource(
+ DecodingFormat<DeserializationSchema<RowData>> format,
+ ElasticsearchConfiguration config,
+ DataType physicalRowDataType,
+ int maxRetryTimes,
+ String summaryString,
+ ElasticsearchApiCallBridge<RestHighLevelClient> apiCallBridge,
+ @Nullable LookupCache lookupCache,
+ @Nullable String docType) {
+ super(
+ format,
+ config,
+ physicalRowDataType,
+ maxRetryTimes,
+ summaryString,
+ apiCallBridge,
+ lookupCache,
+ docType);
+ }
+
+ @SuppressWarnings("unchecked")
+ @Override
+ public VectorSearchRuntimeProvider getSearchRuntimeProvider(
+ VectorSearchContext vectorSearchContext) {
+
+ NetworkClientConfig networkClientConfig = buildNetworkClientConfig();
+
+ ElasticsearchRowDataVectorSearchFunction vectorSearchFunction =
+ new ElasticsearchRowDataVectorSearchFunction(
+ this.format.createRuntimeDecoder(vectorSearchContext,
physicalRowDataType),
+ this.maxRetryTimes,
+ ((Elasticsearch7Configuration)
config).getVectorSearchMetric(),
+ config.getIndex(),
+ getSearchColumn(vectorSearchContext),
+
DataType.getFieldNames(physicalRowDataType).toArray(new String[0]),
+ config.getHosts(),
+ networkClientConfig,
+ (ElasticsearchApiCallBridge<RestHighLevelClient>)
apiCallBridge);
+
+ return VectorSearchFunctionProvider.of(vectorSearchFunction);
+ }
+
+ private String getSearchColumn(VectorSearchContext vectorSearchContext) {
Review Comment:
This function appears to be identical to the one in ES8. Would it make sense
to extract a base class to avoid duplication?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]