This is an automated email from the ASF dual-hosted git repository. spmallette pushed a commit to branch TINKERPOP-3158 in repository https://gitbox.apache.org/repos/asf/tinkerpop.git
commit 00fb9a57e5fc5920efa9cde0fc15ad74e111d354 Author: Stephen Mallette <[email protected]> AuthorDate: Fri May 2 16:13:47 2025 -0400 return distances with element --- .../services/TinkerTextSearchFactory.java | 2 +- .../services/TinkerVectorSearchFactory.java | 110 +++++++++++ .../tinkergraph/structure/AbstractTinkerGraph.java | 71 ++++++- .../structure/AbstractTinkerVectorIndex.java | 31 ++- .../gremlin/tinkergraph/structure/TinkerGraph.java | 1 + .../tinkergraph/structure/TinkerIndexElement.java | 47 +++++ .../tinkergraph/structure/TinkerIndexHelper.java | 8 +- .../structure/TinkerTransactionVectorIndex.java | 17 +- .../tinkergraph/structure/TinkerVectorIndex.java | 22 ++- .../structure/TinkerGraphVectorIndexTest.java | 208 +++++++++++++++++++-- 10 files changed, 471 insertions(+), 46 deletions(-) diff --git a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/services/TinkerTextSearchFactory.java b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/services/TinkerTextSearchFactory.java index b9675856e6..cdb1582493 100644 --- a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/services/TinkerTextSearchFactory.java +++ b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/services/TinkerTextSearchFactory.java @@ -41,7 +41,7 @@ import static org.apache.tinkerpop.gremlin.util.CollectionUtil.asMap; */ public class TinkerTextSearchFactory<I, R> extends TinkerServiceRegistry.TinkerServiceFactory<I, R> implements Service<I, R> { - public static final String NAME = "tinker.search"; + public static final String NAME = "tinker.search.text"; public interface Params { /** diff --git a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/services/TinkerVectorSearchFactory.java b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/services/TinkerVectorSearchFactory.java new file mode 100644 index 0000000000..fb0b3c76f4 --- /dev/null +++ 
b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/services/TinkerVectorSearchFactory.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.tinkerpop.gremlin.tinkergraph.services; + +import org.apache.tinkerpop.gremlin.process.traversal.Traverser; +import org.apache.tinkerpop.gremlin.structure.Direction; +import org.apache.tinkerpop.gremlin.structure.Edge; +import org.apache.tinkerpop.gremlin.structure.Element; +import org.apache.tinkerpop.gremlin.structure.Property; +import org.apache.tinkerpop.gremlin.structure.Vertex; +import org.apache.tinkerpop.gremlin.structure.VertexProperty; +import org.apache.tinkerpop.gremlin.structure.service.Service; +import org.apache.tinkerpop.gremlin.structure.util.CloseableIterator; +import org.apache.tinkerpop.gremlin.tinkergraph.structure.AbstractTinkerGraph; +import org.apache.tinkerpop.gremlin.tinkergraph.structure.TinkerHelper; +import org.apache.tinkerpop.gremlin.tinkergraph.structure.TinkerIndexElement; +import org.apache.tinkerpop.gremlin.util.iterator.IteratorUtils; + +import java.util.Collections; +import java.util.Iterator; +import java.util.Map; +import java.util.Optional; +import java.util.Set; 
+import java.util.stream.Collectors; +import java.util.stream.LongStream; + +import static org.apache.tinkerpop.gremlin.tinkergraph.services.TinkerVectorSearchFactory.Params.KEY; +import static org.apache.tinkerpop.gremlin.tinkergraph.services.TinkerVectorSearchFactory.Params.TOP_K; +import static org.apache.tinkerpop.gremlin.util.CollectionUtil.asMap; + +/** + * + */ +public class TinkerVectorSearchFactory extends TinkerServiceRegistry.TinkerServiceFactory<Element, Map<String, Object>> implements Service<Element, Map<String, Object>> { + + public static final String NAME = "tinker.search.vector.topKByVertex"; + + public interface Params { + /** + * Specify the key storing the embedding + */ + String KEY = "key"; + /** + * Number of results to return + */ + String TOP_K = "topK"; + + Map DESCRIBE = asMap( + KEY, "Specify they key storing the embedding for the vector search", + TOP_K, "Number of results to return (optional, defaults to 10)" + ); + } + + public TinkerVectorSearchFactory(final AbstractTinkerGraph graph) { + super(graph, NAME); + } + + @Override + public Type getType() { + return Type.Streaming; + } + + @Override + public Map describeParams() { + return Params.DESCRIBE; + } + + @Override + public Set<Type> getSupportedTypes() { + return Collections.singleton(Type.Streaming); + } + + @Override + public Service<Element, Map<String, Object>> createService(final boolean isStart, final Map params) { + if (isStart) { + throw new UnsupportedOperationException(Exceptions.cannotStartTraversal); + } + return this; + } + + @Override + public CloseableIterator<Map<String,Object>> execute(final ServiceCallContext ctx, final Traverser.Admin<Element> in, final Map params) { + final String key = (String) params.get(KEY); + final int k = (int) params.getOrDefault(TOP_K, 10); + final Element e = in.get(); + final float[] embedding = e.value(key); + return CloseableIterator.of(graph.findNearestVertices(key, embedding, 
k).stream().map(TinkerIndexElement::toMap).iterator()); + } + + @Override + public void close() {} + +} + diff --git a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/AbstractTinkerGraph.java b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/AbstractTinkerGraph.java index 5ace0fd93a..0989e172d3 100644 --- a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/AbstractTinkerGraph.java +++ b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/AbstractTinkerGraph.java @@ -45,7 +45,6 @@ import java.util.List; import java.util.Map; import java.util.Set; import java.util.UUID; -import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicLong; import java.util.stream.Stream; @@ -304,10 +303,37 @@ public abstract class AbstractTinkerGraph implements Graph { * @param k the number of nearest neighbors to return * @return a list of vertices sorted by distance */ - public List<Vertex> findNearestVertices(final String key, final float[] vector, final int k) { + public List<TinkerIndexElement<TinkerVertex>> findNearestVertices(final String key, final float[] vector, final int k) { if (null == this.vertexVectorIndex) throw new IllegalStateException("Vector index not created for vertices on key: '" + key + "'"); - return new ArrayList<>(this.vertexVectorIndex.findNearest(key, vector, k)); + return this.vertexVectorIndex.findNearest(key, vector, k); + } + + /** + * Find the nearest vertices to the given vector in the vector index for the specified property key. 
+ * + * @param key the property key + * @param vector the query vector + * @return a list of vertices sorted by distance + */ + public List<TinkerIndexElement<TinkerVertex>> findNearestVertices(final String key, final float[] vector) { + if (null == this.vertexVectorIndex) + throw new IllegalStateException("Vector index not created for vertices on key: '" + key + "'"); + return this.vertexVectorIndex.findNearest(key, vector); + } + + /** + * Find the nearest vertices to the given vector in the vector index for the specified property key. + * + * @param key the property key + * @param vector the query vector + * @param k the number of nearest neighbors to return + * @return a list of vertices sorted by distance + */ + public List<TinkerVertex> findNearestVerticesOnly(final String key, final float[] vector, final int k) { + if (null == this.vertexVectorIndex) + throw new IllegalStateException("Vector index not created for vertices on key: '" + key + "'"); + return this.vertexVectorIndex.findNearestElements(key, vector, k); } /** @@ -318,10 +344,37 @@ public abstract class AbstractTinkerGraph implements Graph { * @param vector the query vector * @return a list of vertices sorted by distance */ - public List<Vertex> findNearestVertices(final String key, final float[] vector) { + public List<TinkerVertex> findNearestVerticesOnly(final String key, final float[] vector) { if (null == this.vertexVectorIndex) throw new IllegalStateException("Vector index not created for vertices on key: '" + key + "'"); - return new ArrayList<>(this.vertexVectorIndex.findNearest(key, vector)); + return this.vertexVectorIndex.findNearestElements(key, vector); + } + + /** + * Find the nearest edges to the given vector in the vector index for the specified property key. 
+ * + * @param key the property key + * @param vector the query vector + * @param k the number of nearest neighbors to return + * @return a list of edges sorted by distance + */ + public List<TinkerIndexElement<TinkerEdge>> findNearestEdges(final String key, final float[] vector, final int k) { + if (null == this.edgeVectorIndex) + throw new IllegalStateException("Vector index not created for edges on key: '" + key + "'"); + return this.edgeVectorIndex.findNearest(key, vector, k); + } + + /** + * Find the nearest edges to the given vector in the vector index for the specified property key. + * + * @param key the property key + * @param vector the query vector + * @return a list of edges sorted by distance + */ + public List<TinkerIndexElement<TinkerEdge>> findNearestEdges(final String key, final float[] vector) { + if (null == this.edgeVectorIndex) + throw new IllegalStateException("Vector index not created for edges on key: '" + key + "'"); + return this.edgeVectorIndex.findNearest(key, vector); + } /** @@ -332,10 +385,10 @@ public abstract class AbstractTinkerGraph implements Graph { * @param k the number of nearest neighbors to return * @return a list of edges sorted by distance */ - public List<Edge> findNearestEdges(final String key, final float[] vector, final int k) { + public List<TinkerEdge> findNearestEdgesOnly(final String key, final float[] vector, final int k) { if (null == this.edgeVectorIndex) throw new IllegalStateException("Vector index not created for edges on key: '" + key + "'"); - return new ArrayList<>(this.edgeVectorIndex.findNearest(key, vector, k)); + return new ArrayList<>(this.edgeVectorIndex.findNearestElements(key, vector, k)); } /** @@ -346,10 +399,10 @@ public abstract class AbstractTinkerGraph implements Graph { * @param vector the query vector * @return a list of edges sorted by distance */ - public List<Edge> findNearestEdges(final String key, final float[] vector) { + public List<TinkerEdge> findNearestEdgesOnly(final String
key, final float[] vector) { if (null == this.edgeVectorIndex) throw new IllegalStateException("Vector index not created for edges on key: '" + key + "'"); - return new ArrayList<>(this.edgeVectorIndex.findNearest(key, vector)); + return new ArrayList<>(this.edgeVectorIndex.findNearestElements(key, vector)); } ///////////// Utility methods /////////////// diff --git a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/AbstractTinkerVectorIndex.java b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/AbstractTinkerVectorIndex.java index 3a4e67cb25..7826d402e8 100644 --- a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/AbstractTinkerVectorIndex.java +++ b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/AbstractTinkerVectorIndex.java @@ -21,12 +21,17 @@ package org.apache.tinkerpop.gremlin.tinkergraph.structure; import org.apache.tinkerpop.gremlin.structure.Element; import java.util.List; + /** * Base class for representing a vector index for performing nearest neighbor searches. 
* * @param <T> the type of elements stored in the vector index */ public abstract class AbstractTinkerVectorIndex<T extends Element> extends AbstractTinkerIndex<T> { + /** + * Default number of nearest neighbors to return + */ + private static final int DEFAULT_K = 10; protected AbstractTinkerVectorIndex(final AbstractTinkerGraph graph, final Class<T> indexClass) { super(graph, indexClass); @@ -40,7 +45,7 @@ public abstract class AbstractTinkerVectorIndex<T extends Element> extends Abstr * @param k the number of nearest neighbors to return * @return a list of elements sorted by distance */ - public abstract List<T> findNearest(final String key, final float[] vector, final int k); + public abstract List<TinkerIndexElement<T>> findNearest(final String key, final float[] vector, final int k); /** * Searches for nearest neighbors in the vector index with the default k. @@ -49,6 +54,28 @@ public abstract class AbstractTinkerVectorIndex<T extends Element> extends Abstr * @param vector the query vector * @return a list of elements sorted by distance */ - public abstract List<T> findNearest(final String key, final float[] vector); + public List<TinkerIndexElement<T>> findNearest(final String key, final float[] vector) { + return findNearest(key, vector, DEFAULT_K); + } + /** + * Searches for nearest neighbors in the vector index. + * + * @param key the property key + * @param vector the query vector + * @param k the number of nearest neighbors to return + * @return a list of elements sorted by distance + */ + public abstract List<T> findNearestElements(final String key, final float[] vector, final int k); + + /** + * Searches for nearest neighbors in the vector index with the default k. 
+ * + * @param key the property key + * @param vector the query vector + * @return a list of elements sorted by distance + */ + public List<T> findNearestElements(final String key, final float[] vector) { + return findNearestElements(key, vector, DEFAULT_K); + } } diff --git a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerGraph.java b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerGraph.java index 177378b620..b92ab58688 100644 --- a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerGraph.java +++ b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerGraph.java @@ -97,6 +97,7 @@ public class TinkerGraph extends AbstractTinkerGraph { if (graphLocation != null) loadGraph(); serviceRegistry = new TinkerServiceRegistry(this); + serviceRegistry.registerService(new TinkerServiceRegistry(this)); configuration.getList(String.class, GREMLIN_TINKERGRAPH_SERVICE, Collections.emptyList()).forEach(serviceClass -> serviceRegistry.registerService(instantiate(serviceClass))); } diff --git a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerIndexElement.java b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerIndexElement.java new file mode 100644 index 0000000000..d8e43afb4c --- /dev/null +++ b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerIndexElement.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.tinkerpop.gremlin.tinkergraph.structure; + +import java.util.HashMap; +import java.util.Map; + +public class TinkerIndexElement<T> { + private final T element; + private final float distance; + + public TinkerIndexElement(final T element, final float distance) { + this.element = element; + this.distance = distance; + } + + public T getElement() { + return element; + } + + public float getDistance() { + return distance; + } + + public Map<String, Object> toMap() { + return new HashMap<String, Object>() {{ + put("element", element); + put("distance", distance); + }}; + } +} \ No newline at end of file diff --git a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerIndexHelper.java b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerIndexHelper.java index 7c37b59757..0a4c2c2758 100644 --- a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerIndexHelper.java +++ b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerIndexHelper.java @@ -35,7 +35,7 @@ public final class TinkerIndexHelper { * @return a list of vertices sorted by distance */ public static List<TinkerVertex> findNearestVertices(final AbstractTinkerGraph graph, final String key, final float[] vector, final int k) { - return null == graph.vertexVectorIndex ? Collections.emptyList() : graph.vertexVectorIndex.findNearest(key, vector, k); + return null == graph.vertexVectorIndex ? 
Collections.emptyList() : graph.vertexVectorIndex.findNearestElements(key, vector, k); } /** @@ -47,7 +47,7 @@ public final class TinkerIndexHelper { * @return a list of vertices sorted by distance */ public static List<TinkerVertex> findNearestVertices(final AbstractTinkerGraph graph, final String key, final float[] vector) { - return null == graph.vertexVectorIndex ? Collections.emptyList() : graph.vertexVectorIndex.findNearest(key, vector); + return null == graph.vertexVectorIndex ? Collections.emptyList() : graph.vertexVectorIndex.findNearestElements(key, vector); } /** @@ -60,7 +60,7 @@ public final class TinkerIndexHelper { * @return a list of edges sorted by distance */ public static List<TinkerEdge> findNearestEdges(final AbstractTinkerGraph graph, final String key, final float[] vector, final int k) { - return null == graph.edgeVectorIndex ? Collections.emptyList() : graph.edgeVectorIndex.findNearest(key, vector, k); + return null == graph.edgeVectorIndex ? Collections.emptyList() : graph.edgeVectorIndex.findNearestElements(key, vector, k); } /** @@ -72,7 +72,7 @@ public final class TinkerIndexHelper { * @return a list of edges sorted by distance */ public static List<TinkerEdge> findNearestEdges(final AbstractTinkerGraph graph, final String key, final float[] vector) { - return null == graph.edgeVectorIndex ? Collections.emptyList() : graph.edgeVectorIndex.findNearest(key, vector); + return null == graph.edgeVectorIndex ? 
Collections.emptyList() : graph.edgeVectorIndex.findNearestElements(key, vector); } public static List<TinkerVertex> queryVertexIndex(final AbstractTinkerGraph graph, final String key, final Object value) { diff --git a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerTransactionVectorIndex.java b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerTransactionVectorIndex.java index 111d4fcb7b..864a8242ba 100644 --- a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerTransactionVectorIndex.java +++ b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerTransactionVectorIndex.java @@ -240,24 +240,31 @@ final class TinkerTransactionVectorIndex<T extends TinkerElement> extends Abstra * @param k the number of nearest neighbors to return * @return a list of elements sorted by distance */ - public List<T> findNearest(final String key, final float[] vector, final int k) { + public List<TinkerIndexElement<T>> findNearest(final String key, final float[] vector, final int k) { if (!this.indexedKeys.contains(key) || !this.vectorIndices.containsKey(key)) throw new IllegalArgumentException("The key '" + key + "' is not indexed"); final Index<Object, float[], ElementItem, Float> index = this.vectorIndices.get(key); final List<SearchResult<ElementItem, Float>> nearest = index.findNearest(vector, k); - return nearest.stream().map(sr -> sr.item().element).collect(Collectors.toList()); + return nearest.stream().map(sr -> + new TinkerIndexElement<>(sr.item().element, sr.distance())).collect(Collectors.toList()); } /** - * Searches for nearest neighbors in the vector index with the default k. + * Searches for nearest neighbors in the vector index. 
* * @param key the property key * @param vector the query vector + * @param k the number of nearest neighbors to return * @return a list of elements sorted by distance */ - public List<T> findNearest(final String key, final float[] vector) { - return findNearest(key, vector, DEFAULT_K); + public List<T> findNearestElements(final String key, final float[] vector, final int k) { + if (!this.indexedKeys.contains(key) || !this.vectorIndices.containsKey(key)) + throw new IllegalArgumentException("The key '" + key + "' is not indexed"); + + final Index<Object, float[], ElementItem, Float> index = this.vectorIndices.get(key); + final List<SearchResult<ElementItem, Float>> nearest = index.findNearest(vector, k); + return nearest.stream().map(sr -> sr.item().element).collect(Collectors.toList()); } /** diff --git a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerVectorIndex.java b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerVectorIndex.java index f280d0a3fb..90317fb70c 100644 --- a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerVectorIndex.java +++ b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerVectorIndex.java @@ -46,11 +46,6 @@ final class TinkerVectorIndex<T extends Element> extends AbstractTinkerVectorInd */ protected Map<String, Index<Object, float[], ElementItem, Float>> vectorIndices = new ConcurrentHashMap<>(); - /** - * Default number of nearest neighbors to return - */ - private static final int DEFAULT_K = 10; - /** * Default M parameter for HNSW index */ @@ -233,24 +228,31 @@ final class TinkerVectorIndex<T extends Element> extends AbstractTinkerVectorInd * @param k the number of nearest neighbors to return * @return a list of elements sorted by distance */ - public List<T> findNearest(final String key, final float[] vector, final int k) { + public List<TinkerIndexElement<T>> 
findNearest(final String key, final float[] vector, final int k) { if (!this.indexedKeys.contains(key) || !this.vectorIndices.containsKey(key)) throw new IllegalArgumentException("The key '" + key + "' is not indexed"); final Index<Object, float[], ElementItem, Float> index = this.vectorIndices.get(key); final List<SearchResult<ElementItem, Float>> nearest = index.findNearest(vector, k); - return nearest.stream().map(sr -> sr.item().element).collect(Collectors.toList()); + return nearest.stream().map(sr -> + new TinkerIndexElement<>(sr.item().element, sr.distance())).collect(Collectors.toList()); } /** - * Searches for nearest neighbors in the vector index with the default k. + * Searches for nearest neighbors in the vector index. * * @param key the property key * @param vector the query vector + * @param k the number of nearest neighbors to return * @return a list of elements sorted by distance */ - public List<T> findNearest(final String key, final float[] vector) { - return findNearest(key, vector, DEFAULT_K); + public List<T> findNearestElements(final String key, final float[] vector, final int k) { + if (!this.indexedKeys.contains(key) || !this.vectorIndices.containsKey(key)) + throw new IllegalArgumentException("The key '" + key + "' is not indexed"); + + final Index<Object, float[], ElementItem, Float> index = this.vectorIndices.get(key); + final List<SearchResult<ElementItem, Float>> nearest = index.findNearest(vector, k); + return nearest.stream().map(sr -> sr.item().element).collect(Collectors.toList()); } /** diff --git a/tinkergraph-gremlin/src/test/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerGraphVectorIndexTest.java b/tinkergraph-gremlin/src/test/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerGraphVectorIndexTest.java index 0b1d63af86..daa4630766 100644 --- a/tinkergraph-gremlin/src/test/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerGraphVectorIndexTest.java +++ 
b/tinkergraph-gremlin/src/test/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerGraphVectorIndexTest.java @@ -36,6 +36,7 @@ import java.util.Map; import static org.apache.tinkerpop.gremlin.process.traversal.AnonymousTraversalSource.traversal; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.number.OrderingComparison.lessThanOrEqualTo; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.fail; @@ -53,7 +54,7 @@ public class TinkerGraphVectorIndexTest { @Parameterized.Parameter public AbstractTinkerGraph graph; - @Parameterized.Parameters + @Parameterized.Parameters(name = "{0}") public static Collection<Object[]> data() { return Arrays.asList(new Object[][]{ {TinkerGraph.open()}, @@ -91,7 +92,7 @@ public class TinkerGraphVectorIndexTest { graph.createIndex(TinkerIndexType.VECTOR,"embedding", Vertex.class, indexConfig); - final List<Vertex> nearest = graph.findNearestVertices("embedding", new float[]{1.0f, 0.0f, 0.0f}, 2); + final List<TinkerVertex> nearest = graph.findNearestVerticesOnly("embedding", new float[]{1.0f, 0.0f, 0.0f}, 2); assertNotNull(nearest); assertEquals(2, nearest.size()); assertEquals("Alice", nearest.get(0).value("name")); @@ -111,7 +112,7 @@ public class TinkerGraphVectorIndexTest { g.V().has("name", "Bob").property("embedding", new float[]{0.9f, 0.1f, 0.0f}).iterate(); tryCommitChanges(graph); - final List<Vertex> nearest = graph.findNearestVertices("embedding", new float[]{1.0f, 0.0f, 0.0f}, 2); + final List<TinkerVertex> nearest = graph.findNearestVerticesOnly("embedding", new float[]{1.0f, 0.0f, 0.0f}, 2); assertNotNull(nearest); assertEquals(2, nearest.size()); assertEquals("Alice", nearest.get(0).value("name")); @@ -132,7 +133,7 @@ public class TinkerGraphVectorIndexTest { g.V().has("name", "Bob").drop().iterate(); tryCommitChanges(graph); - final List<Vertex> nearest = 
graph.findNearestVertices("embedding", new float[]{0.0f, 1.0f, 0.0f}, 2); + final List<TinkerVertex> nearest = graph.findNearestVerticesOnly("embedding", new float[]{0.0f, 1.0f, 0.0f}, 2); assertNotNull(nearest); assertEquals(2, nearest.size()); assertThat(nearest.stream().noneMatch(v -> v.value("name").equals("Bob")), is(true)); @@ -153,7 +154,7 @@ public class TinkerGraphVectorIndexTest { assertThat(graph.getIndexedKeys(Vertex.class).contains("embedding"), is(false)); try { - graph.findNearestVertices("embedding", new float[]{1.0f, 0.0f, 0.0f}, 2); + graph.findNearestVerticesOnly("embedding", new float[]{1.0f, 0.0f, 0.0f}, 2); fail("Should have thrown exception since the index was removed"); } catch (IllegalArgumentException ex) { } } @@ -173,7 +174,7 @@ public class TinkerGraphVectorIndexTest { graph.createIndex(TinkerIndexType.VECTOR,"embedding", Edge.class, indexConfig); - final List<Edge> nearest = graph.findNearestEdges("embedding", new float[]{1.0f, 0.0f, 0.0f}, 2); + final List<TinkerEdge> nearest = graph.findNearestEdgesOnly("embedding", new float[]{1.0f, 0.0f, 0.0f}, 2); assertNotNull(nearest); assertEquals(2, nearest.size()); assertEquals(0.8f, (float) nearest.get(0).value("strength"), 0.0001f); @@ -195,7 +196,7 @@ public class TinkerGraphVectorIndexTest { g.E(edge.id()).property("embedding", new float[]{0.9f, 0.1f, 0.0f}).iterate(); tryCommitChanges(graph); - final List<Edge> nearest = graph.findNearestEdges("embedding", new float[]{1.0f, 0.0f, 0.0f}, 2); + final List<TinkerEdge> nearest = graph.findNearestEdgesOnly("embedding", new float[]{1.0f, 0.0f, 0.0f}, 2); assertNotNull(nearest); assertEquals(2, nearest.size()); assertEquals(0.8f, (float) nearest.get(0).value("strength"), 0.0001f); @@ -219,7 +220,7 @@ public class TinkerGraphVectorIndexTest { g.E(edge.id()).drop().iterate(); tryCommitChanges(graph); - final List<Edge> nearest = graph.findNearestEdges("embedding", new float[]{0.0f, 1.0f, 0.0f}, 2); + final List<TinkerEdge> nearest = 
graph.findNearestEdgesOnly("embedding", new float[]{0.0f, 1.0f, 0.0f}, 2); assertNotNull(nearest); assertEquals(2, nearest.size()); assertThat(nearest.stream().noneMatch(e -> e.value("strength").equals(0.6f)), is(true)); @@ -242,7 +243,7 @@ public class TinkerGraphVectorIndexTest { assertThat(graph.getIndexedKeys(Edge.class).contains("embedding"), is(false)); try { - graph.findNearestEdges("embedding", new float[]{1.0f, 0.0f, 0.0f}, 2); + graph.findNearestEdgesOnly("embedding", new float[]{1.0f, 0.0f, 0.0f}, 2); fail("Should have thrown exception since the index was removed"); } catch (IllegalArgumentException ex) { } } @@ -285,7 +286,7 @@ public class TinkerGraphVectorIndexTest { tryRollbackChanges(graph); // Bob's embedding should still be [0.0f, 1.0f, 0.0f] - final List<Vertex> nearest = graph.findNearestVertices("embedding", new float[]{0.0f, 1.0f, 0.0f}, 1); + final List<TinkerVertex> nearest = graph.findNearestVerticesOnly("embedding", new float[]{0.0f, 1.0f, 0.0f}, 1); assertNotNull(nearest); assertEquals(1, nearest.size()); assertEquals("Bob", nearest.get(0).value("name")); @@ -294,7 +295,7 @@ public class TinkerGraphVectorIndexTest { @Test public void shouldHandleEmptyGraphForNearestVertices() { graph.createIndex(TinkerIndexType.VECTOR, "embedding", Vertex.class, indexConfig); - final List<Vertex> nearest = graph.findNearestVertices("embedding", new float[]{1.0f, 0.0f, 0.0f}, 2); + final List<TinkerVertex> nearest = graph.findNearestVerticesOnly("embedding", new float[]{1.0f, 0.0f, 0.0f}, 2); assertNotNull(nearest); assertEquals(0, nearest.size()); } @@ -302,31 +303,208 @@ public class TinkerGraphVectorIndexTest { @Test public void shouldHandleEmptyGraphForNearestEdges() { graph.createIndex(TinkerIndexType.VECTOR, "embedding", Edge.class, indexConfig); - final List<Edge> nearest = graph.findNearestEdges("embedding", new float[]{1.0f, 0.0f, 0.0f}, 2); + final List<TinkerEdge> nearest = graph.findNearestEdgesOnly("embedding", new float[]{1.0f, 0.0f, 0.0f}, 
2); assertNotNull(nearest); assertEquals(0, nearest.size()); } @Test(expected = IllegalStateException.class) public void shouldThrowExceptionWhenIndexNotCreatedForNearestVertices() { - graph.findNearestVertices("embedding", new float[]{1.0f, 0.0f, 0.0f}, 2); + graph.findNearestVerticesOnly("embedding", new float[]{1.0f, 0.0f, 0.0f}, 2); } @Test(expected = IllegalStateException.class) public void shouldThrowExceptionWhenIndexNotCreatedForNearestEdges() { - graph.findNearestEdges("embedding", new float[]{1.0f, 0.0f, 0.0f}, 2); + graph.findNearestEdgesOnly("embedding", new float[]{1.0f, 0.0f, 0.0f}, 2); } @Test(expected = IllegalStateException.class) public void shouldThrowExceptionWhenIndexNotCreatedForNearestVerticesNoDefaultCount() { - graph.findNearestVertices("embedding", new float[]{1.0f, 0.0f, 0.0f}); + graph.findNearestVerticesOnly("embedding", new float[]{1.0f, 0.0f, 0.0f}); } @Test(expected = IllegalStateException.class) public void shouldThrowExceptionWhenIndexNotCreatedForNearestEdgesNoDefaultCount() { + graph.findNearestEdgesOnly("embedding", new float[]{1.0f, 0.0f, 0.0f}); + } + + @Test + public void shouldFindNearestVerticesWithDefaultK() { + final GraphTraversalSource g = traversal().with(graph); + g.addV("person").property("name", "Alice").property("embedding", new float[]{1.0f, 0.0f, 0.0f}).iterate(); + g.addV("person").property("name", "Bob").property("embedding", new float[]{0.0f, 1.0f, 0.0f}).iterate(); + g.addV("person").property("name", "Charlie").property("embedding", new float[]{0.0f, 0.0f, 1.0f}).iterate(); + g.addV("person").property("name", "Dave").property("embedding", new float[]{0.9f, 0.1f, 0.0f}).iterate(); + + tryCommitChanges(graph); + + graph.createIndex(TinkerIndexType.VECTOR, "embedding", Vertex.class, indexConfig); + + final List<TinkerIndexElement<TinkerVertex>> nearest = graph.findNearestVertices("embedding", new float[]{1.0f, 0.0f, 0.0f}); + assertNotNull(nearest); + assertEquals(4, nearest.size()); + + // Sort by distance 
first, then by "name" to ensure deterministic order + nearest.sort((e1, e2) -> { + int distanceComparison = Float.compare(e1.getDistance(), e2.getDistance()); + if (distanceComparison != 0) return distanceComparison; + return e1.getElement().value("name").toString().compareTo(e2.getElement().value("name")); + }); + + assertEquals("Alice", nearest.get(0).getElement().value("name")); + assertEquals("Dave", nearest.get(1).getElement().value("name")); + assertEquals("Bob", nearest.get(2).getElement().value("name")); + assertEquals("Charlie", nearest.get(3).getElement().value("name")); + + // Validate distances are in non-decreasing order + for (int i = 0; i < nearest.size() - 1; i++) { + assertThat(nearest.get(i).getDistance(), is(lessThanOrEqualTo(nearest.get(i + 1).getDistance()))); + } + } + + @Test + public void shouldFindNearestVerticesWithSpecifiedK() { + final GraphTraversalSource g = traversal().with(graph); + g.addV("person").property("name", "Alice").property("embedding", new float[]{1.0f, 0.0f, 0.0f}).iterate(); + g.addV("person").property("name", "Bob").property("embedding", new float[]{0.0f, 1.0f, 0.0f}).iterate(); + g.addV("person").property("name", "Charlie").property("embedding", new float[]{0.0f, 0.0f, 1.0f}).iterate(); + g.addV("person").property("name", "Dave").property("embedding", new float[]{0.9f, 0.1f, 0.0f}).iterate(); + + tryCommitChanges(graph); + + graph.createIndex(TinkerIndexType.VECTOR, "embedding", Vertex.class, indexConfig); + + final List<TinkerIndexElement<TinkerVertex>> nearest = graph.findNearestVertices("embedding", new float[]{1.0f, 0.0f, 0.0f}, 2); + assertNotNull(nearest); + assertEquals(2, nearest.size()); + assertEquals("Alice", nearest.get(0).getElement().value("name")); + assertEquals("Dave", nearest.get(1).getElement().value("name")); + + // Validate distances are in non-decreasing order + for (int i = 0; i < nearest.size() - 1; i++) { + assertThat(nearest.get(i).getDistance(),
is(lessThanOrEqualTo(nearest.get(i + 1).getDistance()))); + } + } + + @Test(expected = IllegalStateException.class) + public void shouldThrowExceptionWhenIndexNotCreatedForFindNearestVertices() { + graph.findNearestVertices("embedding", new float[]{1.0f, 0.0f, 0.0f}); + } + + @Test(expected = IllegalStateException.class) + public void shouldThrowExceptionWhenIndexNotCreatedForFindNearestVerticesWithK() { + graph.findNearestVertices("embedding", new float[]{1.0f, 0.0f, 0.0f}, 2); + } + + @Test + public void shouldHandleEmptyGraphForFindNearestVertices() { + graph.createIndex(TinkerIndexType.VECTOR, "embedding", Vertex.class, indexConfig); + final List<TinkerIndexElement<TinkerVertex>> nearest = graph.findNearestVertices("embedding", new float[]{1.0f, 0.0f, 0.0f}); + assertNotNull(nearest); + assertEquals(0, nearest.size()); + } + + @Test + public void shouldHandleEmptyGraphForFindNearestVerticesWithK() { + graph.createIndex(TinkerIndexType.VECTOR, "embedding", Vertex.class, indexConfig); + final List<TinkerIndexElement<TinkerVertex>> nearest = graph.findNearestVertices("embedding", new float[]{1.0f, 0.0f, 0.0f}, 2); + assertNotNull(nearest); + assertEquals(0, nearest.size()); + } + + @Test + public void shouldFindNearestEdgesWithDefaultK() { + final GraphTraversalSource g = traversal().with(graph); + final Vertex alice = g.addV("person").property("name", "Alice").next(); + final Vertex bob = g.addV("person").property("name", "Bob").next(); + final Vertex charlie = g.addV("person").property("name", "Charlie").next(); + final Vertex dave = g.addV("person").property("name", "Dave").next(); + g.addE("knows").from(alice).to(bob).property("embedding", new float[]{1.0f, 0.0f, 0.0f}).property("strength", 8).iterate(); + g.addE("knows").from(bob).to(charlie).property("embedding", new float[]{0.0f, 1.0f, 0.0f}).property("strength", 6).iterate(); + g.addE("knows").from(charlie).to(dave).property("embedding", new float[]{0.0f, 0.0f, 1.0f}).property("strength", 7).iterate(); + 
g.addE("knows").from(alice).to(dave).property("embedding", new float[]{0.9f, 0.1f, 0.0f}).property("strength", 9).iterate(); + + tryCommitChanges(graph); + + graph.createIndex(TinkerIndexType.VECTOR, "embedding", Edge.class, indexConfig); + + final List<TinkerIndexElement<TinkerEdge>> nearest = graph.findNearestEdges("embedding", new float[]{1.0f, 0.0f, 0.0f}); + assertNotNull(nearest); + assertEquals(4, nearest.size()); + + // Sort by distance first, then by "strength" to ensure deterministic order + nearest.sort((e1, e2) -> { + int distanceComparison = Float.compare(e1.getDistance(), e2.getDistance()); + if (distanceComparison != 0) return distanceComparison; + return Integer.compare((int) e1.getElement().value("strength"), (int) e2.getElement().value("strength")); + }); + + // Assert the sorted results + assertEquals(8, (int) nearest.get(0).getElement().value("strength")); + assertEquals(9, (int) nearest.get(1).getElement().value("strength")); + assertEquals(6, (int) nearest.get(2).getElement().value("strength")); + assertEquals(7, (int) nearest.get(3).getElement().value("strength")); + + // Validate distances are in non-decreasing order + for (int i = 0; i < nearest.size() - 1; i++) { + assertThat(nearest.get(i).getDistance(), is(lessThanOrEqualTo(nearest.get(i + 1).getDistance()))); + } + } + + @Test + public void shouldFindNearestEdgesWithSpecifiedK() { + final GraphTraversalSource g = traversal().with(graph); + final Vertex alice = g.addV("person").property("name", "Alice").next(); + final Vertex bob = g.addV("person").property("name", "Bob").next(); + final Vertex charlie = g.addV("person").property("name", "Charlie").next(); + final Vertex dave = g.addV("person").property("name", "Dave").next(); + g.addE("knows").from(alice).to(bob).property("embedding", new float[]{1.0f, 0.0f, 0.0f}).property("strength", 8).iterate(); + g.addE("knows").from(bob).to(charlie).property("embedding", new float[]{0.0f, 1.0f, 0.0f}).property("strength", 6).iterate(); + 
g.addE("knows").from(charlie).to(dave).property("embedding", new float[]{0.0f, 0.0f, 1.0f}).property("strength", 7).iterate(); + g.addE("knows").from(alice).to(dave).property("embedding", new float[]{0.9f, 0.1f, 0.0f}).property("strength", 9).iterate(); + + tryCommitChanges(graph); + + graph.createIndex(TinkerIndexType.VECTOR, "embedding", Edge.class, indexConfig); + + final List<TinkerIndexElement<TinkerEdge>> nearest = graph.findNearestEdges("embedding", new float[]{1.0f, 0.0f, 0.0f}, 2); + assertNotNull(nearest); + assertEquals(2, nearest.size()); + assertEquals(8, (int) nearest.get(0).getElement().value("strength")); + assertEquals(9, (int) nearest.get(1).getElement().value("strength")); + + // Validate distances are in non-decreasing order + for (int i = 0; i < nearest.size() - 1; i++) { + assertThat(nearest.get(i).getDistance(), is(lessThanOrEqualTo(nearest.get(i + 1).getDistance()))); + } + } + + @Test(expected = IllegalStateException.class) + public void shouldThrowExceptionWhenIndexNotCreatedForFindNearestEdges() { graph.findNearestEdges("embedding", new float[]{1.0f, 0.0f, 0.0f}); } + @Test(expected = IllegalStateException.class) + public void shouldThrowExceptionWhenIndexNotCreatedForFindNearestEdgesWithK() { + graph.findNearestEdges("embedding", new float[]{1.0f, 0.0f, 0.0f}, 2); + } + + @Test + public void shouldHandleEmptyGraphForFindNearestEdges() { + graph.createIndex(TinkerIndexType.VECTOR, "embedding", Edge.class, indexConfig); + final List<TinkerIndexElement<TinkerEdge>> nearest = graph.findNearestEdges("embedding", new float[]{1.0f, 0.0f, 0.0f}); + assertNotNull(nearest); + assertEquals(0, nearest.size()); + } + + @Test + public void shouldHandleEmptyGraphForFindNearestEdgesWithK() { + graph.createIndex(TinkerIndexType.VECTOR, "embedding", Edge.class, indexConfig); + final List<TinkerIndexElement<TinkerEdge>> nearest = graph.findNearestEdges("embedding", new float[]{1.0f, 0.0f, 0.0f}, 2); + assertNotNull(nearest); + assertEquals(0, 
nearest.size()); + } + private void tryCommitChanges(final Graph graph) { if (graph.features().graph().supportsTransactions()) graph.tx().commit();
