This is an automated email from the ASF dual-hosted git repository. hope pushed a commit to branch release-1.4 in repository https://gitbox.apache.org/repos/asf/paimon.git
commit 53e674ed481e7154c98d23a1ecda79e9a53e78b0 Author: Faiz <[email protected]> AuthorDate: Mon Mar 30 20:00:01 2026 +0800 [flink] support vector search procedure for flink (#7550) --- docs/content/append-table/global-index.md | 26 +++ docs/content/flink/procedures.md | 34 ++++ .../flink/procedure/VectorSearchProcedure.java | 216 +++++++++++++++++++++ .../services/org.apache.paimon.factories.Factory | 3 +- .../procedure/VectorSearchProcedureITCase.java | 216 +++++++++++++++++++++ 5 files changed, 494 insertions(+), 1 deletion(-) diff --git a/docs/content/append-table/global-index.md b/docs/content/append-table/global-index.md index dc90a4664f..f00462492f 100644 --- a/docs/content/append-table/global-index.md +++ b/docs/content/append-table/global-index.md @@ -109,6 +109,32 @@ SELECT * FROM vector_search('my_table', 'embedding', array(1.0f, 2.0f, 3.0f), 5) ``` {{< /tab >}} +{{< tab "Flink SQL (Procedure)" >}} + +Unlike Spark's table-valued function, Flink uses a `CALL` procedure to perform vector search. +The procedure returns JSON-serialized rows as strings. + +```sql +-- Search for top-5 nearest neighbors +CALL sys.vector_search( + `table` => 'db.my_table', + vector_column => 'embedding', + query_vector => '1.0,2.0,3.0', + top_k => 5 +); + +-- With projection (only return specific columns) +CALL sys.vector_search( + `table` => 'db.my_table', + vector_column => 'embedding', + query_vector => '1.0,2.0,3.0', + top_k => 5, + projection => 'id,name' +); +``` + +{{< /tab >}} + {{< tab "Java API" >}} ```java Table table = catalog.getTable(identifier); diff --git a/docs/content/flink/procedures.md b/docs/content/flink/procedures.md index 49e65d705b..1c982c3474 100644 --- a/docs/content/flink/procedures.md +++ b/docs/content/flink/procedures.md @@ -951,5 +951,39 @@ All available procedures are listed below. CALL sys.drop_function(`function` => 'function_identifier')<br/> </td> </tr> + <tr> + <td>vector_search</td> + <td> + CALL [catalog.]sys.vector_search(<br/> + `table` => 'identifier',<br/> + vector_column => 'columnName',<br/> + query_vector => 'v1,v2,...',<br/> + top_k => topK,<br/> + projection => 'col1,col2',<br/> + options => 'key1=value1;key2=value2')<br/> + </td> + <td> + To perform vector similarity search on a table with a global vector index. Returns JSON-serialized rows. Arguments: + <li>table(required): the target table identifier.</li> + <li>vector_column(required): the name of the vector column to search.</li> + <li>query_vector(required): comma-separated float values representing the query vector, e.g. '1.0,2.0,3.0'.</li> + <li>top_k(required): the number of nearest neighbors to return.</li> + <li>projection(optional): comma-separated column names to include in the result. If omitted, all columns are returned.</li> + <li>options(optional): additional dynamic options of the table.</li> + </td> + <td> + CALL sys.vector_search(<br/> + `table` => 'default.T',<br/> + vector_column => 'embedding',<br/> + query_vector => '1.0,2.0,3.0',<br/> + top_k => 5)<br/><br/> + CALL sys.vector_search(<br/> + `table` => 'default.T',<br/> + vector_column => 'embedding',<br/> + query_vector => '1.0,2.0,3.0',<br/> + top_k => 5,<br/> + projection => 'id,name') + </td> + </tr> </tbody> </table> diff --git a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/procedure/VectorSearchProcedure.java b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/procedure/VectorSearchProcedure.java new file mode 100644 index 0000000000..2b5df413f8 --- /dev/null +++ b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/procedure/VectorSearchProcedure.java @@ -0,0 +1,216 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.flink.procedure; + +import org.apache.paimon.data.InternalRow; +import org.apache.paimon.format.json.JsonFormatWriter; +import org.apache.paimon.format.json.JsonOptions; +import org.apache.paimon.fs.PositionOutputStream; +import org.apache.paimon.globalindex.GlobalIndexResult; +import org.apache.paimon.options.Options; +import org.apache.paimon.reader.RecordReader; +import org.apache.paimon.table.Table; +import org.apache.paimon.table.source.ReadBuilder; +import org.apache.paimon.table.source.TableScan; +import org.apache.paimon.types.RowType; +import org.apache.paimon.utils.StringUtils; + +import org.apache.flink.table.annotation.ArgumentHint; +import org.apache.flink.table.annotation.DataTypeHint; +import org.apache.flink.table.annotation.ProcedureHint; +import org.apache.flink.table.procedure.ProcedureContext; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * Vector search procedure. This procedure takes one vector and searches for topK nearest vectors. + * Usage: + * + * <pre><code> + * CALL sys.vector_search( + * `table` => 'tableId', + * vector_column => 'v', + * query_vector => '1.0,2.0,3.0', + * top_k => 5 + * ) + * + * -- with projection and options + * CALL sys.vector_search( + * `table` => 'tableId', + * vector_column => 'v', + * query_vector => '1.0,2.0,3.0', + * top_k => 5, + * projection => 'id,name', + * options => 'k1=v1;k2=v2' + * ) + * </code></pre> + */ +public class VectorSearchProcedure extends ProcedureBase { + + public static final String IDENTIFIER = "vector_search"; + + @ProcedureHint( + argument = { + @ArgumentHint(name = "table", type = @DataTypeHint("STRING")), + @ArgumentHint(name = "vector_column", type = @DataTypeHint("STRING")), + @ArgumentHint(name = "query_vector", type = @DataTypeHint("STRING")), + @ArgumentHint(name = "top_k", type = @DataTypeHint("INT")), + @ArgumentHint( + name = "projection", + type = @DataTypeHint("STRING"), + isOptional = true), + @ArgumentHint(name = "options", type = @DataTypeHint("STRING"), isOptional = true) + }) + public String[] call( + ProcedureContext procedureContext, + String tableId, + String vectorColumn, + String queryVectorStr, + Integer topK, + String projection, + String options) + throws Exception { + Table table = table(tableId); + + Map<String, String> optionsMap = optionalConfigMap(options); + if (!optionsMap.isEmpty()) { + table = table.copy(optionsMap); + } + + float[] queryVector = parseVector(queryVectorStr); + + GlobalIndexResult result = + table.newVectorSearchBuilder() + .withVector(queryVector) + .withVectorColumn(vectorColumn) + .withLimit(topK) + .executeLocal(); + + RowType tableRowType = table.rowType(); + int[] projectionIndices = parseProjection(projection, tableRowType); + + ReadBuilder readBuilder = table.newReadBuilder(); + if (projectionIndices != null) { + readBuilder.withProjection(projectionIndices); + } + + TableScan.Plan plan = readBuilder.newScan().withGlobalIndexResult(result).plan(); + + RowType readType = + projectionIndices != null ? tableRowType.project(projectionIndices) : tableRowType; + + ByteArrayOutputStream byteOut = new ByteArrayOutputStream(1024); + JsonOptions jsonOptions = new JsonOptions(new Options()); + try (JsonFormatWriter jsonWriter = + new JsonFormatWriter( + new ByteArrayPositionOutputStream(byteOut), + readType, + jsonOptions, + "none"); + RecordReader<InternalRow> reader = readBuilder.newRead().createReader(plan)) { + reader.forEachRemaining( + row -> { + try { + jsonWriter.addElement(row); + } catch (Exception e) { + throw new RuntimeException("Failed to convert row to JSON string", e); + } + }); + } + + String[] lines = + StringUtils.split(byteOut.toString("UTF-8"), jsonOptions.getLineDelimiter()); + List<String> rows = new ArrayList<>(lines.length); + for (String line : lines) { + String trimmed = line.trim(); + if (!trimmed.isEmpty()) { + rows.add(trimmed); + } + } + return rows.toArray(new String[0]); + } + + private static float[] parseVector(String vectorStr) { + String[] parts = StringUtils.split(vectorStr, ","); + float[] vector = new float[parts.length]; + for (int i = 0; i < parts.length; i++) { + vector[i] = Float.parseFloat(parts[i].trim()); + } + return vector; + } + + private static int[] parseProjection(String projection, RowType rowType) { + if (StringUtils.isNullOrWhitespaceOnly(projection)) { + return null; + } + String[] projectionNames = StringUtils.split(projection, ","); + return rowType.getFieldIndices(Arrays.stream(projectionNames).collect(Collectors.toList())); + } + + @Override + public String identifier() { + return IDENTIFIER; + } + + /** A {@link PositionOutputStream} wrapping a {@link ByteArrayOutputStream}. */ + private static class ByteArrayPositionOutputStream extends PositionOutputStream { + + private final ByteArrayOutputStream out; + + private ByteArrayPositionOutputStream(ByteArrayOutputStream out) { + this.out = out; + } + + @Override + public long getPos() { + return out.size(); + } + + @Override + public void write(int b) { + out.write(b); + } + + @Override + public void write(byte[] b) throws IOException { + out.write(b); + } + + @Override + public void write(byte[] b, int off, int len) { + out.write(b, off, len); + } + + @Override + public void flush() throws IOException { + out.flush(); + } + + @Override + public void close() throws IOException { + out.close(); + } + } +} diff --git a/paimon-flink/paimon-flink-common/src/main/resources/META-INF/services/org.apache.paimon.factories.Factory b/paimon-flink/paimon-flink-common/src/main/resources/META-INF/services/org.apache.paimon.factories.Factory index f87a250ec2..d687fe2244 100644 --- a/paimon-flink/paimon-flink-common/src/main/resources/META-INF/services/org.apache.paimon.factories.Factory +++ b/paimon-flink/paimon-flink-common/src/main/resources/META-INF/services/org.apache.paimon.factories.Factory @@ -100,4 +100,5 @@ org.apache.paimon.flink.procedure.AlterColumnDefaultValueProcedure org.apache.paimon.flink.procedure.TriggerTagAutomaticCreationProcedure org.apache.paimon.flink.procedure.RemoveUnexistingManifestsProcedure org.apache.paimon.flink.procedure.DataEvolutionMergeIntoProcedure -org.apache.paimon.flink.procedure.CreateGlobalIndexProcedure \ No newline at end of file +org.apache.paimon.flink.procedure.CreateGlobalIndexProcedure +org.apache.paimon.flink.procedure.VectorSearchProcedure \ No newline at end of file diff --git a/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/procedure/VectorSearchProcedureITCase.java b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/procedure/VectorSearchProcedureITCase.java new file mode 100644 index 0000000000..935da25330 --- /dev/null +++ b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/procedure/VectorSearchProcedureITCase.java @@ -0,0 +1,216 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.flink.procedure; + +import org.apache.paimon.data.BinaryRow; +import org.apache.paimon.data.GenericArray; +import org.apache.paimon.data.GenericRow; +import org.apache.paimon.flink.CatalogITCaseBase; +import org.apache.paimon.globalindex.GlobalIndexBuilderUtils; +import org.apache.paimon.globalindex.GlobalIndexSingletonWriter; +import org.apache.paimon.globalindex.ResultEntry; +import org.apache.paimon.globalindex.testvector.TestVectorGlobalIndexerFactory; +import org.apache.paimon.index.IndexFileMeta; +import org.apache.paimon.io.CompactIncrement; +import org.apache.paimon.io.DataIncrement; +import org.apache.paimon.options.Options; +import org.apache.paimon.table.FileStoreTable; +import org.apache.paimon.table.sink.BatchTableCommit; +import org.apache.paimon.table.sink.BatchTableWrite; +import org.apache.paimon.table.sink.BatchWriteBuilder; +import org.apache.paimon.table.sink.CommitMessage; +import org.apache.paimon.table.sink.CommitMessageImpl; +import org.apache.paimon.types.DataField; +import org.apache.paimon.utils.Range; + +import org.apache.flink.types.Row; +import org.junit.jupiter.api.Test; + +import java.util.Collections; +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; + +/** IT cases for {@link VectorSearchProcedure}. */ +public class VectorSearchProcedureITCase extends CatalogITCaseBase { + + private static final String VECTOR_FIELD = "vec"; + private static final int DIMENSION = 2; + + @Test + public void testVectorSearchBasic() throws Exception { + createVectorTable("T"); + FileStoreTable table = paimonTable("T"); + + float[][] vectors = { + {1.0f, 0.0f}, // row 0 + {0.95f, 0.1f}, // row 1 + {0.1f, 0.95f}, // row 2 + {0.98f, 0.05f}, // row 3 + {0.0f, 1.0f}, // row 4 + {0.05f, 0.98f} // row 5 + }; + + writeVectors(table, vectors); + buildAndCommitVectorIndex(table, vectors); + + // Search for vectors close to (1.0, 0.0) + List<Row> result = + sql( + "CALL sys.vector_search(" + + "`table` => 'default.T', " + + "vector_column => 'vec', " + + "query_vector => '1.0 ,0.0', " + + "top_k => 3)"); + + assertThat(result).isNotEmpty(); + assertThat(result.size()).isLessThanOrEqualTo(3); + + // Verify results contain JSON strings + for (Row row : result) { + String json = row.getField(0).toString(); + assertThat(json).contains("\"id\""); + assertThat(json).contains("\"vec\""); + } + } + + @Test + public void testVectorSearchWithProjection() throws Exception { + createVectorTable("T2"); + FileStoreTable table = paimonTable("T2"); + + float[][] vectors = { + {1.0f, 0.0f}, // row 0 + {0.0f, 1.0f}, // row 1 + }; + + writeVectors(table, vectors); + buildAndCommitVectorIndex(table, vectors); + + List<Row> result = + sql( + "CALL sys.vector_search(" + + "`table` => 'default.T2', " + + "vector_column => 'vec', " + + "query_vector => '1.0,0.0', " + + "top_k => 2, " + + "projection => 'id')"); + + assertThat(result).isNotEmpty(); + assertThat(result.size()).isLessThanOrEqualTo(2); + + for (Row row : result) { + String json = row.getField(0).toString(); + assertThat(json).contains("\"id\""); + // projection only selects 'id', so 'vec' should not appear + assertThat(json).doesNotContain("\"vec\""); + } + } + + @Test + public void testVectorSearchTopK() throws Exception { + createVectorTable("T3"); + FileStoreTable table = paimonTable("T3"); + + float[][] vectors = new float[10][]; + for (int i = 0; i < 10; i++) { + vectors[i] = new float[] {(float) Math.cos(i * 0.3), (float) Math.sin(i * 0.3)}; + } + + writeVectors(table, vectors); + buildAndCommitVectorIndex(table, vectors); + + List<Row> result = + sql( + "CALL sys.vector_search(" + + "`table` => 'default.T3', " + + "vector_column => 'vec', " + + "query_vector => '1.0,0.0', " + + "top_k => 3)"); + + assertThat(result.size()).isLessThanOrEqualTo(3); + } + + private void createVectorTable(String tableName) { + sql( + "CREATE TABLE %s (" + + "id INT, " + + "vec ARRAY<FLOAT>" + + ") WITH (" + + "'bucket' = '-1', " + + "'row-tracking.enabled' = 'true', " + + "'data-evolution.enabled' = 'true', " + + "'test.vector.dimension' = '%d', " + + "'test.vector.metric' = 'l2'" + + ")", + tableName, DIMENSION); + } + + private void writeVectors(FileStoreTable table, float[][] vectors) throws Exception { + BatchWriteBuilder writeBuilder = table.newBatchWriteBuilder(); + try (BatchTableWrite write = writeBuilder.newWrite(); + BatchTableCommit commit = writeBuilder.newCommit()) { + for (int i = 0; i < vectors.length; i++) { + write.write(GenericRow.of(i, new GenericArray(vectors[i]))); + } + commit.commit(write.prepareCommit()); + } + } + + private void buildAndCommitVectorIndex(FileStoreTable table, float[][] vectors) + throws Exception { + Options options = table.coreOptions().toConfiguration(); + DataField vectorField = table.rowType().getField(VECTOR_FIELD); + + GlobalIndexSingletonWriter writer = + (GlobalIndexSingletonWriter) + GlobalIndexBuilderUtils.createIndexWriter( + table, + TestVectorGlobalIndexerFactory.IDENTIFIER, + vectorField, + options); + for (float[] vec : vectors) { + writer.write(vec); + } + List<ResultEntry> entries = writer.finish(); + + Range rowRange = new Range(0, vectors.length - 1); + List<IndexFileMeta> indexFiles = + GlobalIndexBuilderUtils.toIndexFileMetas( + table.fileIO(), + table.store().pathFactory().globalIndexFileFactory(), + table.coreOptions(), + rowRange, + vectorField.id(), + TestVectorGlobalIndexerFactory.IDENTIFIER, + entries); + + DataIncrement dataIncrement = DataIncrement.indexIncrement(indexFiles); + CommitMessage message = + new CommitMessageImpl( + BinaryRow.EMPTY_ROW, + 0, + null, + dataIncrement, + CompactIncrement.emptyIncrement()); + try (BatchTableCommit commit = table.newBatchWriteBuilder().newCommit()) { + commit.commit(Collections.singletonList(message)); + } + } +}
