This is an automated email from the ASF dual-hosted git repository.

spmallette pushed a commit to branch TINKERPOP-3158
in repository https://gitbox.apache.org/repos/asf/tinkerpop.git

commit 6b2713b776206c3b5c6eaf43ad2e33c320288092
Author: Stephen Mallette <[email protected]>
AuthorDate: Fri May 2 16:13:47 2025 -0400

    return distances with element
---
 .../services/TinkerTextSearchFactory.java          |   2 +-
 .../services/TinkerVectorSearchFactory.java        | 110 +++++++++++
 .../tinkergraph/structure/AbstractTinkerGraph.java |  71 ++++++-
 .../structure/AbstractTinkerVectorIndex.java       |  31 ++-
 .../gremlin/tinkergraph/structure/TinkerGraph.java |   1 +
 .../tinkergraph/structure/TinkerIndexElement.java  |  47 +++++
 .../tinkergraph/structure/TinkerIndexHelper.java   |   8 +-
 .../structure/TinkerTransactionVectorIndex.java    |  17 +-
 .../tinkergraph/structure/TinkerVectorIndex.java   |  22 ++-
 .../structure/TinkerGraphVectorIndexTest.java      | 208 +++++++++++++++++++--
 10 files changed, 471 insertions(+), 46 deletions(-)

diff --git 
a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/services/TinkerTextSearchFactory.java
 
b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/services/TinkerTextSearchFactory.java
index b9675856e6..cdb1582493 100644
--- 
a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/services/TinkerTextSearchFactory.java
+++ 
b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/services/TinkerTextSearchFactory.java
@@ -41,7 +41,7 @@ import static 
org.apache.tinkerpop.gremlin.util.CollectionUtil.asMap;
  */
 public class TinkerTextSearchFactory<I, R> extends 
TinkerServiceRegistry.TinkerServiceFactory<I, R> implements Service<I, R> {
 
-    public static final String NAME = "tinker.search";
+    public static final String NAME = "tinker.search.text";
 
     public interface Params {
         /**
diff --git 
a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/services/TinkerVectorSearchFactory.java
 
b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/services/TinkerVectorSearchFactory.java
new file mode 100644
index 0000000000..fb0b3c76f4
--- /dev/null
+++ 
b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/services/TinkerVectorSearchFactory.java
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.tinkerpop.gremlin.tinkergraph.services;
+
+import org.apache.tinkerpop.gremlin.process.traversal.Traverser;
+import org.apache.tinkerpop.gremlin.structure.Direction;
+import org.apache.tinkerpop.gremlin.structure.Edge;
+import org.apache.tinkerpop.gremlin.structure.Element;
+import org.apache.tinkerpop.gremlin.structure.Property;
+import org.apache.tinkerpop.gremlin.structure.Vertex;
+import org.apache.tinkerpop.gremlin.structure.VertexProperty;
+import org.apache.tinkerpop.gremlin.structure.service.Service;
+import org.apache.tinkerpop.gremlin.structure.util.CloseableIterator;
+import org.apache.tinkerpop.gremlin.tinkergraph.structure.AbstractTinkerGraph;
+import org.apache.tinkerpop.gremlin.tinkergraph.structure.TinkerHelper;
+import org.apache.tinkerpop.gremlin.tinkergraph.structure.TinkerIndexElement;
+import org.apache.tinkerpop.gremlin.util.iterator.IteratorUtils;
+
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+import java.util.stream.Collectors;
+import java.util.stream.LongStream;
+
+import static 
org.apache.tinkerpop.gremlin.tinkergraph.services.TinkerVectorSearchFactory.Params.KEY;
+import static 
org.apache.tinkerpop.gremlin.tinkergraph.services.TinkerVectorSearchFactory.Params.TOP_K;
+import static org.apache.tinkerpop.gremlin.util.CollectionUtil.asMap;
+
+/**
+ * Streaming service returning the top-K nearest vertices (with distances) to an element's embedding.
+ */
+public class TinkerVectorSearchFactory extends 
TinkerServiceRegistry.TinkerServiceFactory<Element, Map<String, Object>> 
implements Service<Element, Map<String, Object>> {
+
+    public static final String NAME = "tinker.search.vector.topKByVertex";
+
+    public interface Params {
+        /**
+         * Specify the key storing the embedding
+         */
+        String KEY = "key";
+        /**
+         * Number of results to return
+         */
+        String TOP_K = "topK";
+
+        Map DESCRIBE = asMap(
+                KEY, "Specify they key storing the embedding for the vector 
search",
+                TOP_K, "Number of results to return (optional, defaults to 10)"
+        );
+    }
+
+    public TinkerVectorSearchFactory(final AbstractTinkerGraph graph) {
+        super(graph, NAME);
+    }
+
+    @Override
+    public Type getType() {
+        return Type.Streaming;
+    }
+
+    @Override
+    public Map describeParams() {
+        return Params.DESCRIBE;
+    }
+
+    @Override
+    public Set<Type> getSupportedTypes() {
+        return Collections.singleton(Type.Streaming);
+    }
+
+    @Override
+    public Service<Element, Map<String, Object>> createService(final boolean 
isStart, final Map params) {
+        if (isStart) {
+            throw new 
UnsupportedOperationException(Exceptions.cannotStartTraversal);
+        }
+        return this;
+    }
+
+    @Override
+    public CloseableIterator<Map<String,Object>> execute(final 
ServiceCallContext ctx, final Traverser.Admin<Element> in, final Map params) {
+        final String key = (String) params.get(KEY);
+        final int k = (int) params.getOrDefault(TOP_K, 10);
+        final Element e = in.get();
+        final float[] embedding = e.value(key);
+        return CloseableIterator.of(graph.findNearestVertices(key, embedding, 
k).stream().map(TinkerIndexElement::toMap).iterator());
+    }
+
+    @Override
+    public void close() {}
+
+}
+
diff --git 
a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/AbstractTinkerGraph.java
 
b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/AbstractTinkerGraph.java
index 5ace0fd93a..0989e172d3 100644
--- 
a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/AbstractTinkerGraph.java
+++ 
b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/AbstractTinkerGraph.java
@@ -45,7 +45,6 @@ import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.UUID;
-import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.atomic.AtomicLong;
 import java.util.stream.Stream;
 
@@ -304,10 +303,37 @@ public abstract class AbstractTinkerGraph implements 
Graph {
      * @param k      the number of nearest neighbors to return
      * @return a list of vertices sorted by distance
      */
-    public List<Vertex> findNearestVertices(final String key, final float[] 
vector, final int k) {
+    public List<TinkerIndexElement<TinkerVertex>> findNearestVertices(final 
String key, final float[] vector, final int k) {
         if (null == this.vertexVectorIndex)
             throw new IllegalStateException("Vector index not created for 
vertices on key: '" + key + "'");
-        return new ArrayList<>(this.vertexVectorIndex.findNearest(key, vector, 
k));
+        return this.vertexVectorIndex.findNearest(key, vector, k);
+    }
+
+    /**
+     * Find the nearest vertices to the given vector in the vector index for 
the specified property key.
+     *
+     * @param key    the property key
+     * @param vector the query vector
+     * @return a list of vertices, each paired with its distance, sorted by distance
+     */
+    public List<TinkerIndexElement<TinkerVertex>> findNearestVertices(final 
String key, final float[] vector) {
+        if (null == this.vertexVectorIndex)
+            throw new IllegalStateException("Vector index not created for 
vertices on key: '" + key + "'");
+        return this.vertexVectorIndex.findNearest(key, vector);
+    }
+
+    /**
+     * Find the nearest vertices to the given vector in the vector index for 
the specified property key.
+     *
+     * @param key    the property key
+     * @param vector the query vector
+     * @param k      the number of nearest neighbors to return
+     * @return a list of vertices sorted by distance
+     */
+    public List<TinkerVertex> findNearestVerticesOnly(final String key, final 
float[] vector, final int k) {
+        if (null == this.vertexVectorIndex)
+            throw new IllegalStateException("Vector index not created for 
vertices on key: '" + key + "'");
+        return this.vertexVectorIndex.findNearestElements(key, vector, k);
     }
 
     /**
@@ -318,10 +344,37 @@ public abstract class AbstractTinkerGraph implements 
Graph {
      * @param vector the query vector
      * @return a list of vertices sorted by distance
      */
-    public List<Vertex> findNearestVertices(final String key, final float[] 
vector) {
+    public List<TinkerVertex> findNearestVerticesOnly(final String key, final 
float[] vector) {
         if (null == this.vertexVectorIndex)
             throw new IllegalStateException("Vector index not created for 
vertices on key: '" + key + "'");
-        return new ArrayList<>(this.vertexVectorIndex.findNearest(key, 
vector));
+        return this.vertexVectorIndex.findNearestElements(key, vector);
+    }
+
+    /**
+     * Find the nearest edges to the given vector in the vector index for the 
specified property key.
+     *
+     * @param key    the property key
+     * @param vector the query vector
+     * @param k      the number of nearest neighbors to return
+     * @return a list of edges, each paired with its distance, sorted by distance
+     */
+    public List<TinkerIndexElement<TinkerEdge>> findNearestEdges(final String 
key, final float[] vector, final int k) {
+        if (null == this.edgeVectorIndex)
+            throw new IllegalStateException("Vector index not created for 
edges on key: '" + key + "'");
+        return this.edgeVectorIndex.findNearest(key, vector, k);
+    }
+
+    /**
+     * Find the nearest edges to the given vector in the vector index for the 
specified property key.
+     *
+     * @param key    the property key
+     * @param vector the query vector
+     * @return a list of edges, each paired with its distance, sorted by distance
+     */
+    public List<TinkerIndexElement<TinkerEdge>> findNearestEdges(final String 
key, final float[] vector) {
+        if (null == this.edgeVectorIndex)
+            throw new IllegalStateException("Vector index not created for 
edges on key: '" + key + "'");
+        return this.edgeVectorIndex.findNearest(key, vector);
     }
 
     /**
@@ -332,10 +385,10 @@ public abstract class AbstractTinkerGraph implements 
Graph {
      * @param k      the number of nearest neighbors to return
      * @return a list of edges sorted by distance
      */
-    public List<Edge> findNearestEdges(final String key, final float[] vector, 
final int k) {
+    public List<TinkerEdge> findNearestEdgesOnly(final String key, final 
float[] vector, final int k) {
         if (null == this.edgeVectorIndex)
             throw new IllegalStateException("Vector index not created for 
edges on key: '" + key + "'");
-        return new ArrayList<>(this.edgeVectorIndex.findNearest(key, vector, 
k));
+        return new ArrayList<>(this.edgeVectorIndex.findNearestElements(key, 
vector, k));
     }
 
     /**
@@ -346,10 +399,10 @@ public abstract class AbstractTinkerGraph implements 
Graph {
      * @param vector the query vector
      * @return a list of edges sorted by distance
      */
-    public List<Edge> findNearestEdges(final String key, final float[] vector) 
{
+    public List<TinkerEdge> findNearestEdgesOnly(final String key, final 
float[] vector) {
         if (null == this.edgeVectorIndex)
             throw new IllegalStateException("Vector index not created for 
edges on key: '" + key + "'");
-        return new ArrayList<>(this.edgeVectorIndex.findNearest(key, vector));
+        return new ArrayList<>(this.edgeVectorIndex.findNearestElements(key, 
vector));
     }
 
     ///////////// Utility methods ///////////////
diff --git 
a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/AbstractTinkerVectorIndex.java
 
b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/AbstractTinkerVectorIndex.java
index 3a4e67cb25..7826d402e8 100644
--- 
a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/AbstractTinkerVectorIndex.java
+++ 
b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/AbstractTinkerVectorIndex.java
@@ -21,12 +21,17 @@ package org.apache.tinkerpop.gremlin.tinkergraph.structure;
 import org.apache.tinkerpop.gremlin.structure.Element;
 
 import java.util.List;
+
 /**
  * Base class for representing a vector index for performing nearest neighbor 
searches.
  *
  * @param <T> the type of elements stored in the vector index
  */
 public abstract class AbstractTinkerVectorIndex<T extends Element> extends 
AbstractTinkerIndex<T> {
+    /**
+     * Default number of nearest neighbors to return
+     */
+    private static final int DEFAULT_K = 10;
 
     protected AbstractTinkerVectorIndex(final AbstractTinkerGraph graph, final 
Class<T> indexClass) {
         super(graph, indexClass);
@@ -40,7 +45,7 @@ public abstract class AbstractTinkerVectorIndex<T extends 
Element> extends Abstr
      * @param k      the number of nearest neighbors to return
      * @return a list of elements sorted by distance
      */
-    public abstract List<T> findNearest(final String key, final float[] 
vector, final int k);
+    public abstract List<TinkerIndexElement<T>> findNearest(final String key, 
final float[] vector, final int k);
 
     /**
      * Searches for nearest neighbors in the vector index with the default k.
@@ -49,6 +54,28 @@ public abstract class AbstractTinkerVectorIndex<T extends 
Element> extends Abstr
      * @param vector the query vector
      * @return a list of elements sorted by distance
      */
-    public abstract List<T> findNearest(final String key, final float[] 
vector);
+    public List<TinkerIndexElement<T>> findNearest(final String key, final 
float[] vector) {
+        return findNearest(key, vector, DEFAULT_K);
+    }
 
+    /**
+     * Searches for nearest neighbors in the vector index, returning only the elements (no distances).
+     *
+     * @param key    the property key
+     * @param vector the query vector
+     * @param k      the number of nearest neighbors to return
+     * @return a list of elements sorted by distance
+     */
+    public abstract List<T> findNearestElements(final String key, final 
float[] vector, final int k);
+
+    /**
+     * Searches for nearest neighbors in the vector index with the default k.
+     *
+     * @param key    the property key
+     * @param vector the query vector
+     * @return a list of elements sorted by distance
+     */
+    public List<T> findNearestElements(final String key, final float[] vector) 
{
+        return findNearestElements(key, vector, DEFAULT_K);
+    }
 }
diff --git 
a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerGraph.java
 
b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerGraph.java
index 177378b620..b92ab58688 100644
--- 
a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerGraph.java
+++ 
b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerGraph.java
@@ -97,6 +97,7 @@ public class TinkerGraph extends AbstractTinkerGraph {
         if (graphLocation != null) loadGraph();
 
         serviceRegistry = new TinkerServiceRegistry(this);
+        serviceRegistry.registerService(new TinkerServiceRegistry(this));
         configuration.getList(String.class, GREMLIN_TINKERGRAPH_SERVICE, 
Collections.emptyList()).forEach(serviceClass ->
                 serviceRegistry.registerService(instantiate(serviceClass)));
     }
diff --git 
a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerIndexElement.java
 
b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerIndexElement.java
new file mode 100644
index 0000000000..d8e43afb4c
--- /dev/null
+++ 
b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerIndexElement.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.tinkerpop.gremlin.tinkergraph.structure;
+
+import java.util.HashMap;
+import java.util.Map;
+
+public class TinkerIndexElement<T> {
+    private final T element;
+    private final float distance;
+
+    public TinkerIndexElement(final T element, final float distance) {
+        this.element = element;
+        this.distance = distance;
+    }
+
+    public T getElement() {
+        return element;
+    }
+
+    public float getDistance() {
+        return distance;
+    }
+
+    public Map<String, Object> toMap() {
+        return new HashMap<String, Object>() {{
+            put("element", element);
+            put("distance", distance);
+        }};
+    }
+}
\ No newline at end of file
diff --git 
a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerIndexHelper.java
 
b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerIndexHelper.java
index 7c37b59757..0a4c2c2758 100644
--- 
a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerIndexHelper.java
+++ 
b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerIndexHelper.java
@@ -35,7 +35,7 @@ public final class TinkerIndexHelper {
      * @return a list of vertices sorted by distance
      */
     public static List<TinkerVertex> findNearestVertices(final 
AbstractTinkerGraph graph, final String key, final float[] vector, final int k) 
{
-        return null == graph.vertexVectorIndex ? Collections.emptyList() : 
graph.vertexVectorIndex.findNearest(key, vector, k);
+        return null == graph.vertexVectorIndex ? Collections.emptyList() : 
graph.vertexVectorIndex.findNearestElements(key, vector, k);
     }
 
     /**
@@ -47,7 +47,7 @@ public final class TinkerIndexHelper {
      * @return a list of vertices sorted by distance
      */
     public static List<TinkerVertex> findNearestVertices(final 
AbstractTinkerGraph graph, final String key, final float[] vector) {
-        return null == graph.vertexVectorIndex ? Collections.emptyList() : 
graph.vertexVectorIndex.findNearest(key, vector);
+        return null == graph.vertexVectorIndex ? Collections.emptyList() : 
graph.vertexVectorIndex.findNearestElements(key, vector);
     }
 
     /**
@@ -60,7 +60,7 @@ public final class TinkerIndexHelper {
      * @return a list of edges sorted by distance
      */
     public static List<TinkerEdge> findNearestEdges(final AbstractTinkerGraph 
graph, final String key, final float[] vector, final int k) {
-        return null == graph.edgeVectorIndex ? Collections.emptyList() : 
graph.edgeVectorIndex.findNearest(key, vector, k);
+        return null == graph.edgeVectorIndex ? Collections.emptyList() : 
graph.edgeVectorIndex.findNearestElements(key, vector, k);
     }
 
     /**
@@ -72,7 +72,7 @@ public final class TinkerIndexHelper {
      * @return a list of edges sorted by distance
      */
     public static List<TinkerEdge> findNearestEdges(final AbstractTinkerGraph 
graph, final String key, final float[] vector) {
-        return null == graph.edgeVectorIndex ? Collections.emptyList() : 
graph.edgeVectorIndex.findNearest(key, vector);
+        return null == graph.edgeVectorIndex ? Collections.emptyList() : 
graph.edgeVectorIndex.findNearestElements(key, vector);
     }
 
     public static List<TinkerVertex> queryVertexIndex(final 
AbstractTinkerGraph graph, final String key, final Object value) {
diff --git 
a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerTransactionVectorIndex.java
 
b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerTransactionVectorIndex.java
index 111d4fcb7b..864a8242ba 100644
--- 
a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerTransactionVectorIndex.java
+++ 
b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerTransactionVectorIndex.java
@@ -240,24 +240,31 @@ final class TinkerTransactionVectorIndex<T extends 
TinkerElement> extends Abstra
      * @param k      the number of nearest neighbors to return
      * @return a list of elements sorted by distance
      */
-    public List<T> findNearest(final String key, final float[] vector, final 
int k) {
+    public List<TinkerIndexElement<T>> findNearest(final String key, final 
float[] vector, final int k) {
         if (!this.indexedKeys.contains(key) || 
!this.vectorIndices.containsKey(key))
             throw new IllegalArgumentException("The key '" + key + "' is not 
indexed");
 
         final Index<Object, float[], ElementItem, Float> index = 
this.vectorIndices.get(key);
         final List<SearchResult<ElementItem, Float>> nearest = 
index.findNearest(vector, k);
-        return nearest.stream().map(sr -> 
sr.item().element).collect(Collectors.toList());
+        return nearest.stream().map(sr ->
+                new TinkerIndexElement<>(sr.item().element, 
sr.distance())).collect(Collectors.toList());
     }
 
     /**
-     * Searches for nearest neighbors in the vector index with the default k.
+     * Searches for nearest neighbors in the vector index.
      *
      * @param key    the property key
      * @param vector the query vector
+     * @param k      the number of nearest neighbors to return
      * @return a list of elements sorted by distance
      */
-    public List<T> findNearest(final String key, final float[] vector) {
-        return findNearest(key, vector, DEFAULT_K);
+    public List<T> findNearestElements(final String key, final float[] vector, 
final int k) {
+        if (!this.indexedKeys.contains(key) || 
!this.vectorIndices.containsKey(key))
+            throw new IllegalArgumentException("The key '" + key + "' is not 
indexed");
+
+        final Index<Object, float[], ElementItem, Float> index = 
this.vectorIndices.get(key);
+        final List<SearchResult<ElementItem, Float>> nearest = 
index.findNearest(vector, k);
+        return nearest.stream().map(sr -> 
sr.item().element).collect(Collectors.toList());
     }
 
     /**
diff --git 
a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerVectorIndex.java
 
b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerVectorIndex.java
index f280d0a3fb..90317fb70c 100644
--- 
a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerVectorIndex.java
+++ 
b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerVectorIndex.java
@@ -46,11 +46,6 @@ final class TinkerVectorIndex<T extends Element> extends 
AbstractTinkerVectorInd
      */
     protected Map<String, Index<Object, float[], ElementItem, Float>> 
vectorIndices = new ConcurrentHashMap<>();
 
-    /**
-     * Default number of nearest neighbors to return
-     */
-    private static final int DEFAULT_K = 10;
-
     /**
      * Default M parameter for HNSW index
      */
@@ -233,24 +228,31 @@ final class TinkerVectorIndex<T extends Element> extends 
AbstractTinkerVectorInd
      * @param k      the number of nearest neighbors to return
      * @return a list of elements sorted by distance
      */
-    public List<T> findNearest(final String key, final float[] vector, final 
int k) {
+    public List<TinkerIndexElement<T>> findNearest(final String key, final 
float[] vector, final int k) {
         if (!this.indexedKeys.contains(key) || 
!this.vectorIndices.containsKey(key))
             throw new IllegalArgumentException("The key '" + key + "' is not 
indexed");
 
         final Index<Object, float[], ElementItem, Float> index = 
this.vectorIndices.get(key);
         final List<SearchResult<ElementItem, Float>> nearest = 
index.findNearest(vector, k);
-        return nearest.stream().map(sr -> 
sr.item().element).collect(Collectors.toList());
+        return nearest.stream().map(sr ->
+                new TinkerIndexElement<>(sr.item().element, 
sr.distance())).collect(Collectors.toList());
     }
 
     /**
-     * Searches for nearest neighbors in the vector index with the default k.
+     * Searches for nearest neighbors in the vector index.
      *
      * @param key    the property key
      * @param vector the query vector
+     * @param k      the number of nearest neighbors to return
      * @return a list of elements sorted by distance
      */
-    public List<T> findNearest(final String key, final float[] vector) {
-        return findNearest(key, vector, DEFAULT_K);
+    public List<T> findNearestElements(final String key, final float[] vector, 
final int k) {
+        if (!this.indexedKeys.contains(key) || 
!this.vectorIndices.containsKey(key))
+            throw new IllegalArgumentException("The key '" + key + "' is not 
indexed");
+
+        final Index<Object, float[], ElementItem, Float> index = 
this.vectorIndices.get(key);
+        final List<SearchResult<ElementItem, Float>> nearest = 
index.findNearest(vector, k);
+        return nearest.stream().map(sr -> 
sr.item().element).collect(Collectors.toList());
     }
 
     /**
diff --git 
a/tinkergraph-gremlin/src/test/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerGraphVectorIndexTest.java
 
b/tinkergraph-gremlin/src/test/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerGraphVectorIndexTest.java
index 0b1d63af86..daa4630766 100644
--- 
a/tinkergraph-gremlin/src/test/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerGraphVectorIndexTest.java
+++ 
b/tinkergraph-gremlin/src/test/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerGraphVectorIndexTest.java
@@ -36,6 +36,7 @@ import java.util.Map;
 import static 
org.apache.tinkerpop.gremlin.process.traversal.AnonymousTraversalSource.traversal;
 import static org.hamcrest.CoreMatchers.is;
 import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.number.OrderingComparison.lessThanOrEqualTo;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.fail;
@@ -53,7 +54,7 @@ public class TinkerGraphVectorIndexTest {
     @Parameterized.Parameter
     public AbstractTinkerGraph graph;
 
-    @Parameterized.Parameters
+    @Parameterized.Parameters(name = "{0}")
     public static Collection<Object[]> data() {
         return Arrays.asList(new Object[][]{
                 {TinkerGraph.open()},
@@ -91,7 +92,7 @@ public class TinkerGraphVectorIndexTest {
 
         graph.createIndex(TinkerIndexType.VECTOR,"embedding", Vertex.class, 
indexConfig);
 
-        final List<Vertex> nearest = graph.findNearestVertices("embedding", 
new float[]{1.0f, 0.0f, 0.0f}, 2);
+        final List<TinkerVertex> nearest = 
graph.findNearestVerticesOnly("embedding", new float[]{1.0f, 0.0f, 0.0f}, 2);
         assertNotNull(nearest);
         assertEquals(2, nearest.size());
         assertEquals("Alice", nearest.get(0).value("name"));
@@ -111,7 +112,7 @@ public class TinkerGraphVectorIndexTest {
         g.V().has("name", "Bob").property("embedding", new float[]{0.9f, 0.1f, 
0.0f}).iterate();
         tryCommitChanges(graph);
 
-        final List<Vertex> nearest = graph.findNearestVertices("embedding", 
new float[]{1.0f, 0.0f, 0.0f}, 2);
+        final List<TinkerVertex> nearest = 
graph.findNearestVerticesOnly("embedding", new float[]{1.0f, 0.0f, 0.0f}, 2);
         assertNotNull(nearest);
         assertEquals(2, nearest.size());
         assertEquals("Alice", nearest.get(0).value("name"));
@@ -132,7 +133,7 @@ public class TinkerGraphVectorIndexTest {
         g.V().has("name", "Bob").drop().iterate();
         tryCommitChanges(graph);
 
-        final List<Vertex> nearest = graph.findNearestVertices("embedding", 
new float[]{0.0f, 1.0f, 0.0f}, 2);
+        final List<TinkerVertex> nearest = 
graph.findNearestVerticesOnly("embedding", new float[]{0.0f, 1.0f, 0.0f}, 2);
         assertNotNull(nearest);
         assertEquals(2, nearest.size());
         assertThat(nearest.stream().noneMatch(v -> 
v.value("name").equals("Bob")), is(true));
@@ -153,7 +154,7 @@ public class TinkerGraphVectorIndexTest {
         assertThat(graph.getIndexedKeys(Vertex.class).contains("embedding"), 
is(false));
 
         try {
-            graph.findNearestVertices("embedding", new float[]{1.0f, 0.0f, 
0.0f}, 2);
+            graph.findNearestVerticesOnly("embedding", new float[]{1.0f, 0.0f, 
0.0f}, 2);
             fail("Should have thrown exception since the index was removed");
         } catch (IllegalArgumentException ex) { }
     }
@@ -173,7 +174,7 @@ public class TinkerGraphVectorIndexTest {
 
         graph.createIndex(TinkerIndexType.VECTOR,"embedding", Edge.class, 
indexConfig);
 
-        final List<Edge> nearest = graph.findNearestEdges("embedding", new 
float[]{1.0f, 0.0f, 0.0f}, 2);
+        final List<TinkerEdge> nearest = 
graph.findNearestEdgesOnly("embedding", new float[]{1.0f, 0.0f, 0.0f}, 2);
         assertNotNull(nearest);
         assertEquals(2, nearest.size());
         assertEquals(0.8f, (float) nearest.get(0).value("strength"), 0.0001f);
@@ -195,7 +196,7 @@ public class TinkerGraphVectorIndexTest {
         g.E(edge.id()).property("embedding", new float[]{0.9f, 0.1f, 
0.0f}).iterate();
         tryCommitChanges(graph);
 
-        final List<Edge> nearest = graph.findNearestEdges("embedding", new 
float[]{1.0f, 0.0f, 0.0f}, 2);
+        final List<TinkerEdge> nearest = 
graph.findNearestEdgesOnly("embedding", new float[]{1.0f, 0.0f, 0.0f}, 2);
         assertNotNull(nearest);
         assertEquals(2, nearest.size());
         assertEquals(0.8f, (float) nearest.get(0).value("strength"), 0.0001f);
@@ -219,7 +220,7 @@ public class TinkerGraphVectorIndexTest {
         g.E(edge.id()).drop().iterate();
         tryCommitChanges(graph);
 
-        final List<Edge> nearest = graph.findNearestEdges("embedding", new 
float[]{0.0f, 1.0f, 0.0f}, 2);
+        final List<TinkerEdge> nearest = 
graph.findNearestEdgesOnly("embedding", new float[]{0.0f, 1.0f, 0.0f}, 2);
         assertNotNull(nearest);
         assertEquals(2, nearest.size());
         assertThat(nearest.stream().noneMatch(e -> 
e.value("strength").equals(0.6f)), is(true));
@@ -242,7 +243,7 @@ public class TinkerGraphVectorIndexTest {
         assertThat(graph.getIndexedKeys(Edge.class).contains("embedding"), 
is(false));
 
         try {
-            graph.findNearestEdges("embedding", new float[]{1.0f, 0.0f, 0.0f}, 
2);
+            graph.findNearestEdgesOnly("embedding", new float[]{1.0f, 0.0f, 
0.0f}, 2);
             fail("Should have thrown exception since the index was removed");
         } catch (IllegalArgumentException ex) { }
     }
@@ -285,7 +286,7 @@ public class TinkerGraphVectorIndexTest {
         tryRollbackChanges(graph);
 
         // Bob's embedding should still be [0.0f, 1.0f, 0.0f]
-        final List<Vertex> nearest = graph.findNearestVertices("embedding", 
new float[]{0.0f, 1.0f, 0.0f}, 1);
+        final List<TinkerVertex> nearest = 
graph.findNearestVerticesOnly("embedding", new float[]{0.0f, 1.0f, 0.0f}, 1);
         assertNotNull(nearest);
         assertEquals(1, nearest.size());
         assertEquals("Bob", nearest.get(0).value("name"));
@@ -294,7 +295,7 @@ public class TinkerGraphVectorIndexTest {
     @Test
     public void shouldHandleEmptyGraphForNearestVertices() {
         graph.createIndex(TinkerIndexType.VECTOR, "embedding", Vertex.class, 
indexConfig);
-        final List<Vertex> nearest = graph.findNearestVertices("embedding", 
new float[]{1.0f, 0.0f, 0.0f}, 2);
+        final List<TinkerVertex> nearest = 
graph.findNearestVerticesOnly("embedding", new float[]{1.0f, 0.0f, 0.0f}, 2);
         assertNotNull(nearest);
         assertEquals(0, nearest.size());
     }
@@ -302,31 +303,208 @@ public class TinkerGraphVectorIndexTest {
     @Test
     public void shouldHandleEmptyGraphForNearestEdges() {
         graph.createIndex(TinkerIndexType.VECTOR, "embedding", Edge.class, 
indexConfig);
-        final List<Edge> nearest = graph.findNearestEdges("embedding", new 
float[]{1.0f, 0.0f, 0.0f}, 2);
+        final List<TinkerEdge> nearest = 
graph.findNearestEdgesOnly("embedding", new float[]{1.0f, 0.0f, 0.0f}, 2);
         assertNotNull(nearest);
         assertEquals(0, nearest.size());
     }
 
     @Test(expected = IllegalStateException.class)
     public void shouldThrowExceptionWhenIndexNotCreatedForNearestVertices() {
-        graph.findNearestVertices("embedding", new float[]{1.0f, 0.0f, 0.0f}, 
2);
+        graph.findNearestVerticesOnly("embedding", new float[]{1.0f, 0.0f, 
0.0f}, 2);
     }
 
     @Test(expected = IllegalStateException.class)
     public void shouldThrowExceptionWhenIndexNotCreatedForNearestEdges() {
-        graph.findNearestEdges("embedding", new float[]{1.0f, 0.0f, 0.0f}, 2);
+        graph.findNearestEdgesOnly("embedding", new float[]{1.0f, 0.0f, 0.0f}, 
2);
     }
 
     @Test(expected = IllegalStateException.class)
     public void 
shouldThrowExceptionWhenIndexNotCreatedForNearestVerticesNoDefaultCount() {
-        graph.findNearestVertices("embedding", new float[]{1.0f, 0.0f, 0.0f});
+        graph.findNearestVerticesOnly("embedding", new float[]{1.0f, 0.0f, 
0.0f});
     }
 
     @Test(expected = IllegalStateException.class)
     public void 
shouldThrowExceptionWhenIndexNotCreatedForNearestEdgesNoDefaultCount() {
+        graph.findNearestEdgesOnly("embedding", new float[]{1.0f, 0.0f, 0.0f});
+    }
+
+    @Test
+    public void shouldFindNearestVerticesWithDefaultK() {
+        final GraphTraversalSource g = traversal().with(graph);
+        g.addV("person").property("name", "Alice").property("embedding", new 
float[]{1.0f, 0.0f, 0.0f}).iterate();
+        g.addV("person").property("name", "Bob").property("embedding", new 
float[]{0.0f, 1.0f, 0.0f}).iterate();
+        g.addV("person").property("name", "Charlie").property("embedding", new 
float[]{0.0f, 0.0f, 1.0f}).iterate();
+        g.addV("person").property("name", "Dave").property("embedding", new 
float[]{0.9f, 0.1f, 0.0f}).iterate();
+
+        tryCommitChanges(graph);
+
+        graph.createIndex(TinkerIndexType.VECTOR, "embedding", Vertex.class, 
indexConfig);
+
+        final List<TinkerIndexElement<TinkerVertex>> nearest = 
graph.findNearestVertices("embedding", new float[]{1.0f, 0.0f, 0.0f});
+        assertNotNull(nearest);
+        assertEquals(4, nearest.size());
+
+        // Sort by distance first, then by "name" to ensure deterministic 
order
+        nearest.sort((e1, e2) -> {
+            int distanceComparison = Float.compare(e1.getDistance(), 
e2.getDistance());
+            if (distanceComparison != 0) return distanceComparison;
+            return 
e1.getElement().value("name").toString().compareTo(e2.getElement().value("name"));
+        });
+
+        assertEquals("Alice", nearest.get(0).getElement().value("name"));
+        assertEquals("Dave", nearest.get(1).getElement().value("name"));
+        assertEquals("Bob", nearest.get(2).getElement().value("name"));
+        assertEquals("Charlie", nearest.get(3).getElement().value("name"));
+
+        // ensure that the results are in non-decreasing order of distance
+        for (int i = 0; i < nearest.size() - 1; i++) {
+            assertThat(nearest.get(i).getDistance(), 
is(lessThanOrEqualTo(nearest.get(i + 1).getDistance())));
+        }
+    }
+
+    @Test
+    public void shouldFindNearestVerticesWithSpecifiedK() {
+        final GraphTraversalSource g = traversal().with(graph);
+        g.addV("person").property("name", "Alice").property("embedding", new 
float[]{1.0f, 0.0f, 0.0f}).iterate();
+        g.addV("person").property("name", "Bob").property("embedding", new 
float[]{0.0f, 1.0f, 0.0f}).iterate();
+        g.addV("person").property("name", "Charlie").property("embedding", new 
float[]{0.0f, 0.0f, 1.0f}).iterate();
+        g.addV("person").property("name", "Dave").property("embedding", new 
float[]{0.9f, 0.1f, 0.0f}).iterate();
+
+        tryCommitChanges(graph);
+
+        graph.createIndex(TinkerIndexType.VECTOR, "embedding", Vertex.class, 
indexConfig);
+
+        final List<TinkerIndexElement<TinkerVertex>> nearest = 
graph.findNearestVertices("embedding", new float[]{1.0f, 0.0f, 0.0f}, 2);
+        assertNotNull(nearest);
+        assertEquals(2, nearest.size());
+        assertEquals("Alice", nearest.get(0).getElement().value("name"));
+        assertEquals("Dave", nearest.get(1).getElement().value("name"));
+
+        // ensure that the results are in non-decreasing order of distance
+        for (int i = 0; i < nearest.size() - 1; i++) {
+            assertThat(nearest.get(i).getDistance(), 
is(lessThanOrEqualTo(nearest.get(i + 1).getDistance())));
+        }
+    }
+
+    @Test(expected = IllegalStateException.class)
+    public void 
shouldThrowExceptionWhenIndexNotCreatedForFindNearestVertices() {
+        graph.findNearestVertices("embedding", new float[]{1.0f, 0.0f, 0.0f});
+    }
+
+    @Test(expected = IllegalStateException.class)
+    public void 
shouldThrowExceptionWhenIndexNotCreatedForFindNearestVerticesWithK() {
+        graph.findNearestVertices("embedding", new float[]{1.0f, 0.0f, 0.0f}, 
2);
+    }
+
+    @Test
+    public void shouldHandleEmptyGraphForFindNearestVertices() {
+        graph.createIndex(TinkerIndexType.VECTOR, "embedding", Vertex.class, 
indexConfig);
+        final List<TinkerIndexElement<TinkerVertex>> nearest = 
graph.findNearestVertices("embedding", new float[]{1.0f, 0.0f, 0.0f});
+        assertNotNull(nearest);
+        assertEquals(0, nearest.size());
+    }
+
+    @Test
+    public void shouldHandleEmptyGraphForFindNearestVerticesWithK() {
+        graph.createIndex(TinkerIndexType.VECTOR, "embedding", Vertex.class, 
indexConfig);
+        final List<TinkerIndexElement<TinkerVertex>> nearest = 
graph.findNearestVertices("embedding", new float[]{1.0f, 0.0f, 0.0f}, 2);
+        assertNotNull(nearest);
+        assertEquals(0, nearest.size());
+    }
+
+    @Test
+    public void shouldFindNearestEdgesWithDefaultK() {
+        final GraphTraversalSource g = traversal().with(graph);
+        final Vertex alice = g.addV("person").property("name", "Alice").next();
+        final Vertex bob = g.addV("person").property("name", "Bob").next();
+        final Vertex charlie = g.addV("person").property("name", 
"Charlie").next();
+        final Vertex dave = g.addV("person").property("name", "Dave").next();
+        g.addE("knows").from(alice).to(bob).property("embedding", new 
float[]{1.0f, 0.0f, 0.0f}).property("strength", 8).iterate();
+        g.addE("knows").from(bob).to(charlie).property("embedding", new 
float[]{0.0f, 1.0f, 0.0f}).property("strength", 6).iterate();
+        g.addE("knows").from(charlie).to(dave).property("embedding", new 
float[]{0.0f, 0.0f, 1.0f}).property("strength", 7).iterate();
+        g.addE("knows").from(alice).to(dave).property("embedding", new 
float[]{0.9f, 0.1f, 0.0f}).property("strength", 9).iterate();
+
+        tryCommitChanges(graph);
+
+        graph.createIndex(TinkerIndexType.VECTOR, "embedding", Edge.class, 
indexConfig);
+
+        final List<TinkerIndexElement<TinkerEdge>> nearest = 
graph.findNearestEdges("embedding", new float[]{1.0f, 0.0f, 0.0f});
+        assertNotNull(nearest);
+        assertEquals(4, nearest.size());
+
+        // Sort by distance first, then by "strength" to ensure deterministic 
order
+        nearest.sort((e1, e2) -> {
+            int distanceComparison = Float.compare(e1.getDistance(), 
e2.getDistance());
+            if (distanceComparison != 0) return distanceComparison;
+            return Integer.compare((int) e1.getElement().value("strength"), 
(int) e2.getElement().value("strength"));
+        });
+
+        // Assert the sorted results
+        assertEquals(8, (int) nearest.get(0).getElement().value("strength"));
+        assertEquals(9, (int) nearest.get(1).getElement().value("strength"));
+        assertEquals(6, (int) nearest.get(2).getElement().value("strength"));
+        assertEquals(7, (int) nearest.get(3).getElement().value("strength"));
+
+        // Validate distances are in non-decreasing order
+        for (int i = 0; i < nearest.size() - 1; i++) {
+            assertThat(nearest.get(i).getDistance(), 
is(lessThanOrEqualTo(nearest.get(i + 1).getDistance())));
+        }
+    }
+
+    @Test
+    public void shouldFindNearestEdgesWithSpecifiedK() {
+        final GraphTraversalSource g = traversal().with(graph);
+        final Vertex alice = g.addV("person").property("name", "Alice").next();
+        final Vertex bob = g.addV("person").property("name", "Bob").next();
+        final Vertex charlie = g.addV("person").property("name", 
"Charlie").next();
+        final Vertex dave = g.addV("person").property("name", "Dave").next();
+        g.addE("knows").from(alice).to(bob).property("embedding", new 
float[]{1.0f, 0.0f, 0.0f}).property("strength", 8).iterate();
+        g.addE("knows").from(bob).to(charlie).property("embedding", new 
float[]{0.0f, 1.0f, 0.0f}).property("strength", 6).iterate();
+        g.addE("knows").from(charlie).to(dave).property("embedding", new 
float[]{0.0f, 0.0f, 1.0f}).property("strength", 7).iterate();
+        g.addE("knows").from(alice).to(dave).property("embedding", new 
float[]{0.9f, 0.1f, 0.0f}).property("strength", 9).iterate();
+
+        tryCommitChanges(graph);
+
+        graph.createIndex(TinkerIndexType.VECTOR, "embedding", Edge.class, 
indexConfig);
+
+        final List<TinkerIndexElement<TinkerEdge>> nearest = 
graph.findNearestEdges("embedding", new float[]{1.0f, 0.0f, 0.0f}, 2);
+        assertNotNull(nearest);
+        assertEquals(2, nearest.size());
+        assertEquals(8, (int) nearest.get(0).getElement().value("strength"));
+        assertEquals(9, (int) nearest.get(1).getElement().value("strength"));
+
+        // Validate distances are in non-decreasing order
+        for (int i = 0; i < nearest.size() - 1; i++) {
+            assertThat(nearest.get(i).getDistance(), 
is(lessThanOrEqualTo(nearest.get(i + 1).getDistance())));
+        }
+    }
+
+    @Test(expected = IllegalStateException.class)
+    public void shouldThrowExceptionWhenIndexNotCreatedForFindNearestEdges() {
         graph.findNearestEdges("embedding", new float[]{1.0f, 0.0f, 0.0f});
     }
 
+    @Test(expected = IllegalStateException.class)
+    public void 
shouldThrowExceptionWhenIndexNotCreatedForFindNearestEdgesWithK() {
+        graph.findNearestEdges("embedding", new float[]{1.0f, 0.0f, 0.0f}, 2);
+    }
+
+    @Test
+    public void shouldHandleEmptyGraphForFindNearestEdges() {
+        graph.createIndex(TinkerIndexType.VECTOR, "embedding", Edge.class, 
indexConfig);
+        final List<TinkerIndexElement<TinkerEdge>> nearest = 
graph.findNearestEdges("embedding", new float[]{1.0f, 0.0f, 0.0f});
+        assertNotNull(nearest);
+        assertEquals(0, nearest.size());
+    }
+
+    @Test
+    public void shouldHandleEmptyGraphForFindNearestEdgesWithK() {
+        graph.createIndex(TinkerIndexType.VECTOR, "embedding", Edge.class, 
indexConfig);
+        final List<TinkerIndexElement<TinkerEdge>> nearest = 
graph.findNearestEdges("embedding", new float[]{1.0f, 0.0f, 0.0f}, 2);
+        assertNotNull(nearest);
+        assertEquals(0, nearest.size());
+    }
+
     private void tryCommitChanges(final Graph graph) {
         if (graph.features().graph().supportsTransactions())
             graph.tx().commit();


Reply via email to