This is an automated email from the ASF dual-hosted git repository.

spmallette pushed a commit to branch TINKERPOP-3158
in repository https://gitbox.apache.org/repos/asf/tinkerpop.git
commit 1b29a4c9f6131b044bfe9a1679df2bb394a70515
Author: Stephen Mallette <[email protected]>
AuthorDate: Tue Jun 2 17:57:55 2026 -0400

    Switch TinkerGraph vector index from hnswlib-core to JVector 3.0.6

    Replaces com.github.jelmerk:hnswlib-core with io.github.jbellis:jvector:3.0.6
    (stable release, Apache 2.0). JVector uses a DiskANN+HNSW hybrid algorithm,
    has no fixed capacity (removing the maxItems/growthRate/resize machinery),
    and uses the Panama Vector API on Java 20+ for SIMD without requiring any
    code change.

    Supported distance functions are now COSINE (default), EUCLIDEAN, and
    INNER_PRODUCT. MANHATTAN, BRAY_CURTIS, CANBERRA, and CORRELATION are
    removed. Distance values returned by the index are normalized to [0,1]
    lower-is-better via 1-similarity rather than raw library-specific values.

    The maxItems, growthRate, and ef configuration keys are removed. Explicit
    dimension validation is added so mismatched-dimension vectors throw
    IllegalArgumentException as before.

    Excludes jvector's transitive snakeyaml pull to satisfy the TinkerPop
    convergence enforcer. Bumps gremlin-benchmark JMH to 1.37 and adds
    commons-math3 exclusion in spark-gremlin to resolve remaining convergence
    conflicts introduced by jvector.

    (tinkerpop-1w2)

    Assisted-by: Claude Code:claude-sonnet-4-6
---
 CHANGELOG.asciidoc                                 |   2 +-
 .../reference/implementations-tinkergraph.asciidoc |  28 +-
 gremlin-tools/gremlin-benchmark/pom.xml            |   2 +-
 spark-gremlin/pom.xml                              |   6 +
 tinkergraph-gremlin/pom.xml                        |  12 +-
 .../services/TinkerVectorDistanceFactory.java      |   6 +-
 .../structure/AbstractTinkerVectorIndex.java       |   4 -
 .../tinkergraph/structure/TinkerIndexType.java     |  78 +++--
 .../structure/TinkerTransactionVectorIndex.java    | 387 +++++++--------------
 .../tinkergraph/structure/TinkerVectorIndex.java   | 358 ++++++--------------
 .../structure/TinkerGraphServiceTest.java          |  22 +-
 .../structure/TinkerGraphVectorIndexTest.java      |  79 +----
 12 files changed, 325 insertions(+), 659 deletions(-)

diff --git a/CHANGELOG.asciidoc b/CHANGELOG.asciidoc
index e09a5fa514..e34cd89ea9 100644
--- a/CHANGELOG.asciidoc
+++ b/CHANGELOG.asciidoc
@@ -25,7 +25,7 @@ image::https://raw.githubusercontent.com/apache/tinkerpop/master/docs/static/ima
 [[release-4-0-0]]
 === TinkerPop 4.0.0 (Release Date: NOT OFFICIALLY RELEASED YET)
 
-* Added vector indexing to TinkerGraph with search services for `tinker.search.vector.topKByElement` and `tinker.search.vector.topKByEmbedding`.
+* Added vector indexing to TinkerGraph with search services for `tinker.search.vector.topK.byElement` and `tinker.search.vector.topK.byEmbedding`.
 * Added vector distance calculation functions for TinkerGraph.
 * Renamed the regex based search service in TinkerGraph to `tinker.search.text`.
 * Added typed numeric wrappers and `preciseNumbers` connection option to `gremlin-javascript` for explicit control over numeric type serialization and deserialization.
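The commit message says that only COSINE, EUCLIDEAN and INNER_PRODUCT remain, that index distances are reported lower-is-better via 1 - similarity, and that mismatched dimensions throw IllegalArgumentException. The sketch below is a hypothetical, dependency-free helper (`BruteForceVectorSearch` is an illustrative name, not a TinkerPop or JVector class). It mirrors the plain-Java formulas this diff adds to `TinkerIndexType.Vector#distance` and the `1.0f - ns.score` conversion used in `IndexState.search`, and adds an exact full-scan top-k that could serve as a ground-truth baseline when spot-checking the approximate JVector index. It assumes Java 9 or later for `Map.entry`.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;

/**
 * Hypothetical reference helper (not part of TinkerPop): exact distances and top-k
 * using the same formulas as the new TinkerIndexType.Vector#distance bodies.
 */
final class BruteForceVectorSearch {

    enum Distance { COSINE, EUCLIDEAN, INNER_PRODUCT }

    /** Lower values mean more similar, as in the TinkerIndexType.Vector Javadoc. */
    static float distance(final Distance fn, final float[] a, final float[] b) {
        // explicit dimension check, in the spirit of IndexState.add in the diff
        if (a.length != b.length)
            throw new IllegalArgumentException(
                    "Vector dimension " + b.length + " does not match " + a.length);
        float acc = 0, normA = 0, normB = 0;
        for (int i = 0; i < a.length; i++) {
            if (fn == Distance.EUCLIDEAN) {
                final float d = a[i] - b[i];
                acc += d * d;              // sum of squared differences
            } else {
                acc += a[i] * b[i];        // dot product
                normA += a[i] * a[i];
                normB += b[i] * b[i];
            }
        }
        switch (fn) {
            case EUCLIDEAN:
                return (float) Math.sqrt(acc);
            case INNER_PRODUCT:
                return 1.0f - acc;
            default: { // COSINE; a zero-length vector yields the maximum distance of 1.0
                final float denom = (float) (Math.sqrt(normA) * Math.sqrt(normB));
                return denom == 0 ? 1.0f : 1.0f - (acc / denom);
            }
        }
    }

    /** Turns a higher-is-better similarity score into a lower-is-better distance. */
    static float toDistance(final float similarityScore) {
        return 1.0f - similarityScore;
    }

    /** Exact top-k by full scan: keys of the k nearest vectors, nearest first. */
    static List<String> topK(final Distance fn, final Map<String, float[]> data,
                             final float[] query, final int k) {
        // max-heap on distance so the current farthest candidate is evicted first
        final PriorityQueue<Map.Entry<String, Float>> heap = new PriorityQueue<>(
                Comparator.comparing((Map.Entry<String, Float> e) -> e.getValue()).reversed());
        for (final Map.Entry<String, float[]> e : data.entrySet()) {
            heap.add(Map.entry(e.getKey(), distance(fn, query, e.getValue())));
            if (heap.size() > k)
                heap.poll();
        }
        final List<String> result = new ArrayList<>();
        while (!heap.isEmpty())
            result.add(heap.poll().getKey());
        Collections.reverse(result); // the heap drains farthest-first
        return result;
    }
}
```

The scan touches every stored vector for each query, so it only suits small test graphs; its value is as a ground truth for recall checks. These formulas are also not all confined to [0,1]. EUCLIDEAN is the raw straight-line distance and INNER_PRODUCT is `1 - dot`, so values from `TinkerVectorDistanceFactory` are not necessarily on the same scale as the normalized distances the commit message describes for the index.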
diff --git a/docs/src/reference/implementations-tinkergraph.asciidoc b/docs/src/reference/implementations-tinkergraph.asciidoc
index d61d9eee86..84a63e35dd 100644
--- a/docs/src/reference/implementations-tinkergraph.asciidoc
+++ b/docs/src/reference/implementations-tinkergraph.asciidoc
@@ -327,15 +327,11 @@ The `call()` step returns a list of maps, each containing:
 
 TIP: Vector indices can also be created for edges.
 
-TinkerGraph supports various distance functions for vector similarity search:
+TinkerGraph supports the following distance functions for vector similarity search:
 
 * `COSINE`: Measures the cosine of the angle between two vectors (default)
 * `EUCLIDEAN`: Measures the straight-line distance between two points
-* `MANHATTAN`: Measures the sum of absolute differences between coordinates
-* `INNER_PRODUCT`: Measures the dot product of two vectors
-* `BRAY_CURTIS`: Measures the Bray-Curtis dissimilarity
-* `CANBERRA`: Measures the weighted Manhattan distance
-* `CORRELATION`: Measures the correlation distance
+* `INNER_PRODUCT`: Measures the inner product of two vectors
 
 You can specify the distance function when creating the vector index:
 
@@ -352,33 +348,27 @@ These options are specified when creating a vector index:
 |=========================================================
 |Configuration Option |Description |Default Value
 |`dimension` |The dimension of the vector embeddings. This is a required parameter and must match the length of the vector embeddings stored in the graph. |N/A (Required)
-|`distanceFunction` |The distance function to use for similarity calculations. Must be one of the `TinkerIndexType.Vector` enum values (COSINE, EUCLIDEAN, MANHATTAN, INNER_PRODUCT, BRAY_CURTIS, CANBERRA, CORRELATION). |COSINE
-|`growthRate`| The rate at which the index will automatically increase in size once it is full. If set to `0` the index will not grow automatically and will throw `SizeLimitExceededException` when its maximum size is reached. |0.10
-|`m` |The maximum number of connections per node in the HNSW graph. Higher values provide better search quality at the cost of increased memory usage and index build time. |16
+|`distanceFunction` |The distance function to use for similarity calculations. Must be one of the `TinkerIndexType.Vector` enum values (COSINE, EUCLIDEAN, INNER_PRODUCT). |COSINE
+|`m` |The maximum number of connections per node in the graph. Higher values provide better search quality at the cost of increased memory usage and index build time. |16
 |`efConstruction` |The size of the dynamic candidate list during index construction. Higher values improve index quality at the cost of longer build times. |200
-|`ef` |The size of the dynamic candidate list during search. Higher values improve search accuracy at the cost of slower search times. |10
-|`maxItems` |The maximum number of items expected to be stored in the index. Use in conjuction with `growthRate`. |10000
 |=========================================================
 
 Here's an example of creating a vector index with custom configuration options:
 
 [source,groovy]
 ----
-graph.getServiceRegistry().registerService(new TinkerVectorSearchFactory(graph))
+graph.getServiceRegistry().registerService(new TinkerVectorSearchByElementFactory(graph))
 indexConfig = [
     dimension : 128,
-    distance : TinkerIndexType.Vector.COSINE,
-    growthRate : 0.15,
+    distanceFunction: TinkerIndexType.Vector.COSINE,
     m : 32,
-    efConstruction : 300,
-    ef : 20,
-    maxItems : 1000
+    efConstruction : 300
 ]
 graph.createIndex(TinkerIndexType.VECTOR, "embedding", Vertex.class, indexConfig)
 ----
 
-TIP: Constants for all the configuration values can be found in `TinkerVectorIndex`. They are prefixed with "CONFIG_".
-For example, "dimension" can be referenced as `TinkerVectorIndex.CONFIG_DIMENSION`.
+TIP: Constants for all the configuration values can be found in `AbstractTinkerVectorIndex`. They are prefixed with "CONFIG_".
+For example, "dimension" can be referenced as `AbstractTinkerVectorIndex.CONFIG_DIMENSION`.
 
 Note that the distance functions can be used directly with the `TinkerVectorDistanceFactory` service. It allows
 calculation of the distance between the starting and ending elements in a `Path`.
diff --git a/gremlin-tools/gremlin-benchmark/pom.xml b/gremlin-tools/gremlin-benchmark/pom.xml
index c2537a8e1f..e2384259b7 100644
--- a/gremlin-tools/gremlin-benchmark/pom.xml
+++ b/gremlin-tools/gremlin-benchmark/pom.xml
@@ -27,7 +27,7 @@ limitations under the License.
     <artifactId>gremlin-benchmark</artifactId>
     <name>Apache TinkerPop :: Gremlin Benchmark</name>
     <properties>
-        <jmh.version>1.36</jmh.version>
+        <jmh.version>1.37</jmh.version>
         <!-- Skip benchmarks by default because they are time consuming. -->
         <skipBenchmarks>true</skipBenchmarks>
         <skipTests>${skipBenchmarks}</skipTests>
diff --git a/spark-gremlin/pom.xml b/spark-gremlin/pom.xml
index d80e0f3e16..c54156c5c3 100644
--- a/spark-gremlin/pom.xml
+++ b/spark-gremlin/pom.xml
@@ -294,6 +294,12 @@ limitations under the License.
             <artifactId>tinkergraph-gremlin</artifactId>
             <version>${project.version}</version>
             <scope>test</scope>
+            <exclusions>
+                <exclusion>
+                    <groupId>org.apache.commons</groupId>
+                    <artifactId>commons-math3</artifactId>
+                </exclusion>
+            </exclusions>
         </dependency>
         <dependency>
             <groupId>ch.qos.logback</groupId>
diff --git a/tinkergraph-gremlin/pom.xml b/tinkergraph-gremlin/pom.xml
index 024fce3a03..e9ee973294 100644
--- a/tinkergraph-gremlin/pom.xml
+++ b/tinkergraph-gremlin/pom.xml
@@ -36,9 +36,15 @@ limitations under the License.
             <artifactId>commons-lang3</artifactId>
         </dependency>
         <dependency>
-            <groupId>com.github.jelmerk</groupId>
-            <artifactId>hnswlib-core</artifactId>
-            <version>1.2.1</version>
+            <groupId>io.github.jbellis</groupId>
+            <artifactId>jvector</artifactId>
+            <version>3.0.6</version>
+            <exclusions>
+                <exclusion>
+                    <groupId>org.yaml</groupId>
+                    <artifactId>snakeyaml</artifactId>
+                </exclusion>
+            </exclusions>
         </dependency>
         <dependency>
             <groupId>com.google.inject</groupId>
diff --git a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/services/TinkerVectorDistanceFactory.java b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/services/TinkerVectorDistanceFactory.java
index 5c00cbbdfb..a8ee79e657 100644
--- a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/services/TinkerVectorDistanceFactory.java
+++ b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/services/TinkerVectorDistanceFactory.java
@@ -18,7 +18,6 @@
  */
 package org.apache.tinkerpop.gremlin.tinkergraph.services;
 
-import com.github.jelmerk.hnswlib.core.DistanceFunction;
 import org.apache.tinkerpop.gremlin.process.traversal.Path;
 import org.apache.tinkerpop.gremlin.process.traversal.Traverser;
 import org.apache.tinkerpop.gremlin.structure.Element;
@@ -104,21 +103,18 @@ public class TinkerVectorDistanceFactory extends TinkerServiceRegistry.TinkerSer
         final TinkerIndexType.Vector vector = TinkerIndexType.Vector.valueOf(
                 params.getOrDefault(Params.DISTANCE_FUNCTION, TinkerIndexType.Vector.COSINE).toString());
-        final DistanceFunction<float[], Float> distanceFunction = vector.getDistanceFunction();
 
         final Path path = in.get();
         final int pathLength = path.size();
         final Element start = path.get(0);
         final Element end = path.get(pathLength - 1);
 
-        // if the elements do not have the specified key, then return no results because there's nothing we can
-        // calculate distance on
         if (!start.keys().contains(key) || !end.keys().contains(key))
             return CloseableIterator.empty();
 
         final float[] startEmbedding = start.value(key);
         final float[] endEmbedding = end.value(key);
 
-        return CloseableIterator.of(Collections.singleton(distanceFunction.distance(startEmbedding, endEmbedding)).iterator());
+        return CloseableIterator.of(Collections.singleton(vector.distance(startEmbedding, endEmbedding)).iterator());
     }
 
     @Override
diff --git a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/AbstractTinkerVectorIndex.java b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/AbstractTinkerVectorIndex.java
index 51cc1d4487..603e26f39a 100644
--- a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/AbstractTinkerVectorIndex.java
+++ b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/AbstractTinkerVectorIndex.java
@@ -33,16 +33,12 @@ abstract class AbstractTinkerVectorIndex<T extends Element> extends AbstractTink
     static final int DEFAULT_M = 16;
     static final int DEFAULT_EF_CONSTRUCTION = 200;
     static final int DEFAULT_EF = 10;
-    static final int DEFAULT_MAX_ITEMS = 10000;
-    static final double DEFAULT_GROWTH_RATE = 0.1;
 
     public static final String CONFIG_DIMENSION = "dimension";
     public static final String CONFIG_M = "m";
     public static final String CONFIG_EF_CONSTRUCTION = "efConstruction";
     public static final String CONFIG_EF = "ef";
-    public static final String CONFIG_MAX_ITEMS = "maxItems";
     public static final String CONFIG_DISTANCE_FUNCTION = "distanceFunction";
-    public static final String CONFIG_GROWTH_RATE = "growthRate";
 
     protected AbstractTinkerVectorIndex(final AbstractTinkerGraph graph, final Class<T> indexClass) {
         super(graph, indexClass);
diff --git a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerIndexType.java b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerIndexType.java
index 84004345e2..90cc356dd4 100644
--- a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerIndexType.java
+++ b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerIndexType.java
@@ -18,8 +18,7 @@
  */
 package org.apache.tinkerpop.gremlin.tinkergraph.structure;
 
-import com.github.jelmerk.hnswlib.core.DistanceFunction;
-import com.github.jelmerk.hnswlib.core.DistanceFunctions;
+import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
 
 /**
  * Enum for the different types of indices supported by TinkerGraph
@@ -38,29 +37,66 @@ public enum TinkerIndexType {
     /**
      * Distance functions for vector index.
      */
-    public enum Vector implements VectorDistance<float[], Float> {
+    public enum Vector {
 
-        COSINE(DistanceFunctions.FLOAT_COSINE_DISTANCE),
-        EUCLIDEAN(DistanceFunctions.FLOAT_EUCLIDEAN_DISTANCE),
-        MANHATTAN(DistanceFunctions.FLOAT_MANHATTAN_DISTANCE),
-        INNER_PRODUCT(DistanceFunctions.FLOAT_INNER_PRODUCT),
-        BRAY_CURTIS(DistanceFunctions.FLOAT_BRAY_CURTIS_DISTANCE),
-        CANBERRA(DistanceFunctions.FLOAT_CANBERRA_DISTANCE),
-        CORRELATION(DistanceFunctions.FLOAT_CORRELATION_DISTANCE);
+        COSINE {
+            @Override
+            public VectorSimilarityFunction toJVectorFunction() {
+                return VectorSimilarityFunction.COSINE;
+            }
 
-        private final DistanceFunction<float[], Float> distanceFunction;
+            @Override
+            public float distance(final float[] a, final float[] b) {
+                float dot = 0, normA = 0, normB = 0;
+                for (int i = 0; i < a.length; i++) {
+                    dot += a[i] * b[i];
+                    normA += a[i] * a[i];
+                    normB += b[i] * b[i];
+                }
+                final float denom = (float) (Math.sqrt(normA) * Math.sqrt(normB));
+                return denom == 0 ? 1.0f : 1.0f - (dot / denom);
+            }
+        },
+        EUCLIDEAN {
+            @Override
+            public VectorSimilarityFunction toJVectorFunction() {
+                return VectorSimilarityFunction.EUCLIDEAN;
+            }
 
-        Vector(final DistanceFunction<float[], Float> distanceFunction) {
-            this.distanceFunction = distanceFunction;
-        }
+            @Override
+            public float distance(final float[] a, final float[] b) {
+                float sum = 0;
+                for (int i = 0; i < a.length; i++) {
+                    final float d = a[i] - b[i];
+                    sum += d * d;
+                }
+                return (float) Math.sqrt(sum);
+            }
+        },
+        INNER_PRODUCT {
+            @Override
+            public VectorSimilarityFunction toJVectorFunction() {
+                return VectorSimilarityFunction.DOT_PRODUCT;
+            }
 
-        @Override
-        public DistanceFunction<float[], Float> getDistanceFunction() {
-            return distanceFunction;
-        }
-    }
+            @Override
+            public float distance(final float[] a, final float[] b) {
+                float dot = 0;
+                for (int i = 0; i < a.length; i++) {
+                    dot += a[i] * b[i];
+                }
+                return 1.0f - dot;
+            }
+        };
+
+        /**
+         * Returns the corresponding JVector similarity function for index construction and search.
+         */
+        public abstract VectorSimilarityFunction toJVectorFunction();
 
-    interface VectorDistance<V, T> {
-        DistanceFunction<V, T> getDistanceFunction();
+        /**
+         * Computes the distance between two vectors. Lower values indicate greater similarity.
+         */
+        public abstract float distance(final float[] a, final float[] b);
     }
 }
diff --git a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerTransactionVectorIndex.java b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerTransactionVectorIndex.java
index 4a651bf37f..5f984536c6 100644
--- a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerTransactionVectorIndex.java
+++ b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerTransactionVectorIndex.java
@@ -18,16 +18,21 @@
  */
 package org.apache.tinkerpop.gremlin.tinkergraph.structure;
 
-import com.github.jelmerk.hnswlib.core.Item;
-import com.github.jelmerk.hnswlib.core.SearchResult;
-import com.github.jelmerk.hnswlib.core.hnsw.HnswIndex;
-import com.github.jelmerk.hnswlib.core.Index;
-import com.github.jelmerk.hnswlib.core.hnsw.SizeLimitExceededException;
+import io.github.jbellis.jvector.graph.GraphIndexBuilder;
+import io.github.jbellis.jvector.graph.GraphSearcher;
+import io.github.jbellis.jvector.graph.ListRandomAccessVectorValues;
+import io.github.jbellis.jvector.graph.SearchResult;
+import io.github.jbellis.jvector.graph.similarity.SearchScoreProvider;
+import io.github.jbellis.jvector.vector.VectorizationProvider;
+import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
+import io.github.jbellis.jvector.vector.types.VectorFloat;
+import io.github.jbellis.jvector.vector.types.VectorTypeSupport;
 import org.apache.tinkerpop.gremlin.structure.Graph;
 import org.apache.tinkerpop.gremlin.structure.Property;
 import org.apache.tinkerpop.gremlin.structure.Vertex;
 
-import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
@@ -36,38 +41,20 @@ import java.util.concurrent.ConcurrentHashMap;
 import java.util.stream.Collectors;
 
 /**
- * A vector index implementation for TinkerTransactionGraph using hnswlib.
+ * A vector index implementation for TinkerTransactionGraph using JVector.
  *
  * @param <T> the element type (Vertex or Edge)
  */
 final class TinkerTransactionVectorIndex<T extends TinkerElement> extends AbstractTinkerVectorIndex<T> {
 
-    /**
-     * Map of property key to vector index
-     */
-    protected Map<String, Index<Object, float[], ElementItem, Float>> vectorIndices = new ConcurrentHashMap<>();
+    private static final VectorTypeSupport VTS = VectorizationProvider.getInstance().getVectorTypeSupport();
 
-    /**
-     * Map of property key to growth rate
-     */
-    private final Map<String, Double> growthRates = new ConcurrentHashMap<>();
+    protected Map<String, IndexState<T>> vectorIndices = new ConcurrentHashMap<>();
 
-    /**
-     * Creates a new vector index for the specified graph and element class.
-     *
-     * @param graph the graph
-     * @param indexClass the element class
-     */
-    public TinkerTransactionVectorIndex(final TinkerTransactionGraph graph, final Class<T> indexClass) {
+    TinkerTransactionVectorIndex(final TinkerTransactionGraph graph, final Class<T> indexClass) {
         super(graph, indexClass);
     }
 
-    /**
-     * Creates a vector index for the specified property key with the given configuration options.
-     *
-     * @param key the property key
-     * @param configuration the configuration options
-     */
     @Override
     public void createIndex(final String key, final Map<String, Object> configuration) {
         if (null == key)
@@ -75,18 +62,13 @@ final class TinkerTransactionVectorIndex<T extends TinkerElement> extends Abstra
         if (key.isEmpty())
             throw new IllegalArgumentException("The key for the index cannot be an empty string");
 
-        // Get dimension from configuration or throw exception if not provided
         if (!configuration.containsKey(CONFIG_DIMENSION))
             throw new IllegalArgumentException("The dimension must be provided in the configuration");
 
-        final int dimension;
         final Object dimObj = configuration.get(CONFIG_DIMENSION);
-        if (dimObj instanceof Number) {
-            dimension = ((Number) dimObj).intValue();
-        } else {
+        if (!(dimObj instanceof Number))
             throw new IllegalArgumentException("The dimension must be a number");
-        }
-
+        final int dimension = ((Number) dimObj).intValue();
         if (dimension <= 0)
             throw new IllegalArgumentException("The dimension must be greater than 0");
 
@@ -95,317 +77,192 @@ final class TinkerTransactionVectorIndex<T extends TinkerElement> extends Abstra
         this.indexedKeys.add(key);
 
         int m = DEFAULT_M;
-        if (configuration.containsKey(CONFIG_M)) {
-            final Object mObj = configuration.get(CONFIG_M);
-            if (mObj instanceof Number) {
-                m = ((Number) mObj).intValue();
-            }
-        }
+        if (configuration.containsKey(CONFIG_M) && configuration.get(CONFIG_M) instanceof Number)
+            m = ((Number) configuration.get(CONFIG_M)).intValue();
 
         int efConstruction = DEFAULT_EF_CONSTRUCTION;
-        if (configuration.containsKey(CONFIG_EF_CONSTRUCTION)) {
-            final Object efObj = configuration.get(CONFIG_EF_CONSTRUCTION);
-            if (efObj instanceof Number) {
-                efConstruction = ((Number) efObj).intValue();
-            }
-        }
+        if (configuration.containsKey(CONFIG_EF_CONSTRUCTION) && configuration.get(CONFIG_EF_CONSTRUCTION) instanceof Number)
+            efConstruction = ((Number) configuration.get(CONFIG_EF_CONSTRUCTION)).intValue();
 
-        int ef = DEFAULT_EF;
-        if (configuration.containsKey(CONFIG_EF)) {
-            final Object efObj = configuration.get(CONFIG_EF);
-            if (efObj instanceof Number) {
-                ef = ((Number) efObj).intValue();
-            }
-        }
+        TinkerIndexType.Vector distFunc = TinkerIndexType.Vector.COSINE;
+        if (configuration.containsKey(CONFIG_DISTANCE_FUNCTION) && configuration.get(CONFIG_DISTANCE_FUNCTION) instanceof TinkerIndexType.Vector)
+            distFunc = (TinkerIndexType.Vector) configuration.get(CONFIG_DISTANCE_FUNCTION);
 
-        int maxItems = DEFAULT_MAX_ITEMS;
-        if (configuration.containsKey(CONFIG_MAX_ITEMS)) {
-            final Object maxObj = configuration.get(CONFIG_MAX_ITEMS);
-            if (maxObj instanceof Number) {
-                maxItems = ((Number) maxObj).intValue();
-            }
-        }
+        final IndexState<T> state = new IndexState<>(dimension, m, efConstruction, distFunc.toJVectorFunction());
+        this.vectorIndices.put(key, state);
 
-        TinkerIndexType.Vector vector = TinkerIndexType.Vector.COSINE;
-        if (configuration.containsKey(CONFIG_DISTANCE_FUNCTION)) {
-            final Object vec = configuration.get(CONFIG_DISTANCE_FUNCTION);
-            if (vec instanceof TinkerIndexType.Vector) {
-                vector = ((TinkerIndexType.Vector) vec);
-            }
-        }
-
-        double growthRate = DEFAULT_GROWTH_RATE;
-        if (configuration.containsKey(CONFIG_GROWTH_RATE)) {
-            final Object growthObj = configuration.get(CONFIG_GROWTH_RATE);
-            if (growthObj instanceof Number) {
-                growthRate = ((Number) growthObj).doubleValue();
-            }
-        }
-        this.growthRates.put(key, growthRate);
-
-        // Create a new HNSW index for this property key
-        final Index<Object, float[], ElementItem, Float> index = HnswIndex
-                .newBuilder(dimension, vector.getDistanceFunction(), Float::compare, maxItems)
-                .withM(m)
-                .withEfConstruction(efConstruction)
-                .withEf(ef)
-                .withRemoveEnabled()
-                .build();
-
-        this.vectorIndices.put(key, index);
-
-        // Index existing elements
-        final Map elements =
+        final Map<?, ?> elementMap =
                 Vertex.class.isAssignableFrom(indexClass) ?
                         ((TinkerTransactionGraph) graph).getVertices() :
                         ((TinkerTransactionGraph) graph).getEdges();
-        for (Object element : elements.values()) {
-            TinkerElementContainer container = (TinkerElementContainer) element;
-            Object e = container.get();
+        for (final Object raw : elementMap.values()) {
+            final TinkerElementContainer container = (TinkerElementContainer) raw;
+            final Object e = container.get();
             if (e != null && indexClass.isInstance(e)) {
-                T tinkerElement = (T) e;
-                Property property = tinkerElement.property(key);
-                if (property.isPresent()) {
-                    Object value = property.value();
-                    if (value instanceof float[] && ((float[]) value).length == dimension) {
-                        this.addToIndex(key, (float[]) value, tinkerElement);
-                    }
+                final T element = (T) e;
+                final Property<?> prop = element.property(key);
+                if (prop.isPresent() && prop.value() instanceof float[]) {
+                    final float[] v = (float[]) prop.value();
+                    if (v.length == dimension)
+                        state.add(element, v);
                 }
             }
         }
     }
 
-    /**
-     * Adds an element with a vector to the index.
-     *
-     * @param key the property key
-     * @param vector the vector
-     * @param element the element
-     */
     public void addToIndex(final String key, final float[] vector, final T element) {
-        if (!this.indexedKeys.contains(key) || !this.vectorIndices.containsKey(key))
+        final IndexState<T> state = this.vectorIndices.get(key);
+        if (state == null)
             return;
-
-        final Index<Object, float[], ElementItem, Float> index = this.vectorIndices.get(key);
-        final ElementItem item = new ElementItem(element.id(), vector, element);
-        index.add(item);
+        state.add(element, vector);
     }
 
-    /**
-     * Searches for nearest neighbors in the vector index.
-     *
-     * @param key the property key
-     * @param vector the query vector
-     * @param k the number of nearest neighbors to return
-     * @return a list of elements sorted by distance
-     */
+    @Override
     public List<TinkerIndexElement<T>> findNearest(final String key, final float[] vector, final int k) {
-        if (!this.indexedKeys.contains(key) || !this.vectorIndices.containsKey(key))
+        final IndexState<T> state = this.vectorIndices.get(key);
+        if (state == null)
             throw new IllegalArgumentException("The key '" + key + "' is not indexed");
-
-        final Index<Object, float[], ElementItem, Float> index = this.vectorIndices.get(key);
-        final List<SearchResult<ElementItem, Float>> nearest = index.findNearest(vector, k);
-        return nearest.stream().map(sr ->
-                new TinkerIndexElement<>(sr.item().element, sr.distance())).collect(Collectors.toList());
+        return state.search(vector, k);
     }
 
-    /**
-     * Searches for nearest neighbors in the vector index.
-     *
-     * @param key the property key
-     * @param vector the query vector
-     * @param k the number of nearest neighbors to return
-     * @return a list of elements sorted by distance
-     */
+    @Override
     public List<T> findNearestElements(final String key, final float[] vector, final int k) {
-        if (!this.indexedKeys.contains(key) || !this.vectorIndices.containsKey(key))
-            throw new IllegalArgumentException("The key '" + key + "' is not indexed");
-
-        final Index<Object, float[], ElementItem, Float> index = this.vectorIndices.get(key);
-        final List<SearchResult<ElementItem, Float>> nearest = index.findNearest(vector, k);
-        return nearest.stream().map(sr -> sr.item().element).collect(Collectors.toList());
+        return findNearest(key, vector, k).stream()
+                .map(TinkerIndexElement::getElement)
+                .collect(Collectors.toList());
     }
 
-    /**
-     * Removes an element from the vector index.
-     *
-     * @param key the property key
-     * @param element the element
-     */
     public void removeFromIndex(final String key, final T element) {
-        if (!this.indexedKeys.contains(key) || !this.vectorIndices.containsKey(key))
+        final IndexState<T> state = this.vectorIndices.get(key);
+        if (state == null)
             return;
-
-        final Index<Object, float[], ElementItem, Float> index = this.vectorIndices.get(key);
-        try {
-            index.remove(element.id(), 0);
-        } catch (Exception e) {
-            // If the element is not in the index, just ignore the exception
-        }
+        state.remove(element);
     }
 
-    /**
-     * Updates the vector index when an element's property changes.
-     *
-     * @param key the property key
-     * @param newValue the new vector value
-     * @param element the element
-     */
     public void updateIndex(final String key, final float[] newValue, final T element) {
-        if (!this.indexedKeys.contains(key) || !this.vectorIndices.containsKey(key))
+        final IndexState<T> state = this.vectorIndices.get(key);
+        if (state == null)
             return;
-
-        final Index<Object, float[], ElementItem, Float> index = this.vectorIndices.get(key);
-        try {
-            index.remove(element.id(), 0);
-        } catch (Exception e) {
-            // If the element is not in the index, just ignore the exception
-        }
-        final ElementItem item = new ElementItem(element.id(), newValue, element);
-        addWithResize(key, index, item);
+        state.remove(element);
+        state.add(element, newValue);
     }
 
-    /**
-     * Drops the vector index for the specified property key.
-     *
-     * @param key the property key
-     */
     @Override
     public void dropIndex(final String key) {
-        if (this.vectorIndices.containsKey(key)) {
-            this.vectorIndices.remove(key);
-        }
-
+        final IndexState<T> state = this.vectorIndices.remove(key);
+        if (state != null)
+            state.close();
         this.indexedKeys.remove(key);
     }
 
-    /**
-     * A class that wraps an element with its vector for use in the HNSW index.
-     */
-    private class ElementItem implements Item<Object, float[]>, Serializable {
-        private final Object id;
-        private final float[] vector;
-        private final T element;
-
-        public ElementItem(final Object id, final float[] vector, final T element) {
-            this.id = id;
-            this.vector = vector;
-            this.element = element;
-        }
-
-        @Override
-        public Object id() {
-            return id;
-        }
-
-        @Override
-        public float[] vector() {
-            return vector;
-        }
-
-        @Override
-        public int dimensions() {
-            return vector.length;
-        }
-    }
-
-    // AbstractTinkerIndex implementation methods
-
     @Override
     public List<T> get(final String key, final Object value) {
-        // This method is for regular indices, not vector indices
         return Collections.emptyList();
     }
 
     @Override
     public long count(final String key, final Object value) {
-        // This method is for regular indices, not vector indices
         return 0;
     }
 
     @Override
     public void remove(final String key, final Object value, final T element) {
-        // only make changes to index tx close
+        // index changes only applied on tx commit
     }
 
     @Override
     public void removeElement(final T element) {
-        // only make changes to index tx close
+        // index changes only applied on tx commit
     }
 
     @Override
     public void autoUpdate(final String key, final Object newValue, final Object oldValue, final T element) {
-        // only make changes to index tx close
+        // index changes only applied on tx commit
     }
 
-    /**
-     * Commit changes to the index.
-     *
-     * @param updatedElements the set of updated elements
-     */
     public void commit(final Set<TinkerElementContainer> updatedElements) {
         for (final TinkerElementContainer container : updatedElements) {
-            Object element = container.get();
+            final Object element = container.get();
             if (element != null && !container.isDeleted() && indexClass.isInstance(element)) {
-                T tinkerElement = (T) element;
-                for (String key : this.indexedKeys) {
-                    Property property = tinkerElement.property(key);
-                    if (property.isPresent() && property.value() instanceof float[]) {
-                        updateIndex(key, (float[]) property.value(), tinkerElement);
-                    }
+                final T tinkerElement = (T) element;
+                for (final String key : this.indexedKeys) {
+                    final Property<?> prop = tinkerElement.property(key);
+                    if (prop.isPresent() && prop.value() instanceof float[])
+                        updateIndex(key, (float[]) prop.value(), tinkerElement);
                 }
             } else if (container.isDeleted()) {
-                Object oldElement = container.getUnmodified();
+                final Object oldElement = container.getUnmodified();
                 if (oldElement != null && indexClass.isInstance(oldElement)) {
-                    T tinkerOldElement = (T) oldElement;
-                    for (String key : this.indexedKeys) {
+                    final T tinkerOldElement = (T) oldElement;
+                    for (final String key : this.indexedKeys)
                         removeFromIndex(key, tinkerOldElement);
-                    }
                 }
             }
         }
     }
 
-    /**
-     * Rollback changes to the index.
-     */
     public void rollback() {
-        // No specific action needed for rollback in the current implementation
+        // no action needed; index changes are deferred to commit
     }
 
     /**
-     * Helper method to add an item to the index with automatic resizing if needed.
-     *
-     * @param key the property key
-     * @param index the vector index
-     * @param item the item to add
+     * Per-key index state backed by JVector.
      */
-    private void addWithResize(final String key, final Index<Object, float[], ElementItem, Float> index,
-                               final ElementItem item) {
-        try {
-            index.add(item);
-        } catch (SizeLimitExceededException e) {
-            // Get the growth rate for this index
-            final Double growthRate = this.growthRates.getOrDefault(key, 0.0d);
-
-            // If growth rate is 0 or not set, rethrow the exception
-            if (growthRate <= 0) {
-                throw e;
-            }
+    private static final class IndexState<T extends TinkerElement> {
+        private final int dimension;
+        private final VectorSimilarityFunction similarityFunction;
+        private final List<VectorFloat<?>> vectors = new ArrayList<>();
+        private final List<T> elements = new ArrayList<>();
+        private final Map<Object, Integer> idToOrdinal = new ConcurrentHashMap<>();
+        private final GraphIndexBuilder builder;
+
+        IndexState(final int dimension, final int m, final int efConstruction,
+                   final VectorSimilarityFunction similarityFunction) {
+            this.dimension = dimension;
+            this.similarityFunction = similarityFunction;
+            final ListRandomAccessVectorValues ravv = new ListRandomAccessVectorValues(vectors, dimension);
+            this.builder = new GraphIndexBuilder(ravv, similarityFunction, m, efConstruction, 1.4f, 1.2f);
+        }
 
-            // Calculate new size based on growth rate
-            final int currentSize = ((HnswIndex<Object, float[], ElementItem, Float>) index).getMaxItemCount();
-            final int newSize = currentSize + (int) Math.ceil(currentSize * growthRate);
+        synchronized void add(final T element, final float[] vector) {
+            if (vector.length != dimension)
+                throw new IllegalArgumentException(
+                        "Vector dimension " + vector.length + " does not match index dimension " + dimension);
+            final int ordinal = vectors.size();
+            final VectorFloat<?> vf = VTS.createFloatVector(vector);
+            vectors.add(vf);
+            elements.add(element);
+            idToOrdinal.put(element.id(), ordinal);
+            builder.addGraphNode(ordinal, vf);
+        }
 
-            // Resize the index
-            ((HnswIndex<Object, float[], ElementItem, Float>) index).resize(newSize);
+        synchronized void remove(final T element) {
+            final Integer ordinal = idToOrdinal.remove(element.id());
+            if (ordinal != null)
+                builder.markNodeDeleted(ordinal);
+        }
+
+        List<TinkerIndexElement<T>> search(final float[] queryVector, final int k) {
+            if (vectors.isEmpty())
+                return Collections.emptyList();
+            final VectorFloat<?> query = VTS.createFloatVector(queryVector);
+            final ListRandomAccessVectorValues ravv = new ListRandomAccessVectorValues(vectors, dimension);
+            final var ssp = SearchScoreProvider.exact(query, similarityFunction, ravv);
+            try (final GraphSearcher searcher = new GraphSearcher(builder.getGraph())) {
+                final SearchResult result = searcher.search(ssp, k, io.github.jbellis.jvector.util.Bits.ALL);
+                return Arrays.stream(result.getNodes())
+                        .map(ns -> new TinkerIndexElement<>(elements.get(ns.node), 1.0f - ns.score))
+                        .collect(Collectors.toList());
+            } catch (Exception e) {
+                throw new RuntimeException("Vector search failed", e);
+            }
+        }
 
-            // Try adding the item again
-            index.add(item);
-        } catch (Exception e) {
-            // If it's not a size limit exception, rethrow it
-            throw e;
+        void close() {
+            try {
+                builder.close();
+            } catch (Exception ignored) {}
         }
     }
 }
diff --git a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerVectorIndex.java b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerVectorIndex.java
index b470264dab..0b6382ad7f 100644
--- a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerVectorIndex.java
+++ b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerVectorIndex.java
@@ -18,17 +18,21 @@
  */
 package org.apache.tinkerpop.gremlin.tinkergraph.structure;
 
-import com.github.jelmerk.hnswlib.core.Item;
-import com.github.jelmerk.hnswlib.core.SearchResult;
-import com.github.jelmerk.hnswlib.core.hnsw.HnswIndex;
-import com.github.jelmerk.hnswlib.core.hnsw.SizeLimitExceededException;
-import com.github.jelmerk.hnswlib.core.Index;
+import io.github.jbellis.jvector.graph.GraphIndexBuilder;
+import io.github.jbellis.jvector.graph.GraphSearcher;
+import io.github.jbellis.jvector.graph.ListRandomAccessVectorValues;
+import io.github.jbellis.jvector.graph.SearchResult;
+import io.github.jbellis.jvector.graph.similarity.SearchScoreProvider;
+import io.github.jbellis.jvector.vector.VectorizationProvider;
+import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
+import io.github.jbellis.jvector.vector.types.VectorFloat;
+import io.github.jbellis.jvector.vector.types.VectorTypeSupport;
 import org.apache.tinkerpop.gremlin.structure.Element;
 import org.apache.tinkerpop.gremlin.structure.Graph;
 import org.apache.tinkerpop.gremlin.structure.Property;
 import org.apache.tinkerpop.gremlin.structure.Vertex;
 
-import java.io.Serializable;
+import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
@@ -36,38 +40,20 @@ import java.util.concurrent.ConcurrentHashMap;
 import java.util.stream.Collectors;
 
 /**
- * A vector index implementation for TinkerGraph using hnswlib.
+ * A vector index implementation for TinkerGraph using JVector.
  *
  * @param <T> the element type (Vertex or Edge)
  */
 final class TinkerVectorIndex<T extends Element> extends AbstractTinkerVectorIndex<T> {
 
-    /**
-     * Map of the property key to vector index
-     */
-    private final Map<String, Index<Object, float[], ElementItem, Float>> vectorIndices = new ConcurrentHashMap<>();
+    private static final VectorTypeSupport VTS = VectorizationProvider.getInstance().getVectorTypeSupport();
 
-    /**
-     * Map of property key to growth rate
-     */
-    private final Map<String, Double> growthRates = new ConcurrentHashMap<>();
+    private final Map<String, IndexState<T>> vectorIndices = new ConcurrentHashMap<>();
 
-    /**
-     * Creates a new vector index for the specified graph and element class.
-     *
-     * @param graph the graph
-     * @param indexClass the element class
-     */
-    public TinkerVectorIndex(final TinkerGraph graph, final Class<T> indexClass) {
+    TinkerVectorIndex(final TinkerGraph graph, final Class<T> indexClass) {
         super(graph, indexClass);
     }
 
-    /**
-     * Creates a vector index for the specified property key with the given configuration options.
-     *
-     * @param key the property key
-     * @param configuration the configuration options
-     */
     @Override
     public void createIndex(final String key, final Map<String, Object> configuration) {
         if (null == key)
@@ -75,18 +61,13 @@ final class TinkerVectorIndex<T extends Element> extends AbstractTinkerVectorInd
         if (key.isEmpty())
             throw new IllegalArgumentException("The key for the index cannot be an empty string");
 
-        // Get dimension from configuration or throw exception if not provided
         if (!configuration.containsKey(CONFIG_DIMENSION))
             throw new IllegalArgumentException("The dimension must be provided in the configuration");
 
-        final int dimension;
         final Object dimObj = configuration.get(CONFIG_DIMENSION);
-        if (dimObj instanceof Number) {
-            dimension = ((Number) dimObj).intValue();
-        } else {
+        if (!(dimObj instanceof Number))
             throw new IllegalArgumentException("The dimension must be a number");
-        }
-
+        final int dimension = ((Number) dimObj).intValue();
         if (dimension <= 0)
             throw new IllegalArgumentException("The dimension must be greater than 0");
 
@@ -95,290 +76,165 @@ final class TinkerVectorIndex<T extends Element> extends AbstractTinkerVectorInd
         this.indexedKeys.add(key);
 
         int m = DEFAULT_M;
-        if (configuration.containsKey(CONFIG_M)) {
-            final Object mObj = configuration.get(CONFIG_M);
-            if (mObj instanceof Number) {
-                m = ((Number) mObj).intValue();
-            }
-        }
+        if (configuration.containsKey(CONFIG_M) && configuration.get(CONFIG_M) instanceof Number)
+            m = ((Number) configuration.get(CONFIG_M)).intValue();
 
         int efConstruction = DEFAULT_EF_CONSTRUCTION;
-        if (configuration.containsKey(CONFIG_EF_CONSTRUCTION)) {
-            final Object efObj = configuration.get(CONFIG_EF_CONSTRUCTION);
-            if (efObj instanceof Number) {
-                efConstruction = ((Number) efObj).intValue();
-            }
-        }
+        if (configuration.containsKey(CONFIG_EF_CONSTRUCTION) && configuration.get(CONFIG_EF_CONSTRUCTION) instanceof Number)
+            efConstruction = ((Number) configuration.get(CONFIG_EF_CONSTRUCTION)).intValue();
 
-        int ef = DEFAULT_EF;
-        if (configuration.containsKey(CONFIG_EF)) {
-            final Object efObj = configuration.get(CONFIG_EF);
-            if (efObj instanceof Number) {
-                ef = ((Number) efObj).intValue();
-            }
-        }
+        TinkerIndexType.Vector distFunc = TinkerIndexType.Vector.COSINE;
+        if (configuration.containsKey(CONFIG_DISTANCE_FUNCTION) && configuration.get(CONFIG_DISTANCE_FUNCTION) instanceof TinkerIndexType.Vector)
+            distFunc = (TinkerIndexType.Vector) configuration.get(CONFIG_DISTANCE_FUNCTION);
 
-        int maxItems = DEFAULT_MAX_ITEMS;
-        if (configuration.containsKey(CONFIG_MAX_ITEMS)) {
-            final Object maxObj = configuration.get(CONFIG_MAX_ITEMS);
-            if (maxObj instanceof Number) {
-                maxItems = ((Number) maxObj).intValue();
-            }
-        }
+        final IndexState<T> state = new IndexState<>(dimension, m, efConstruction, distFunc.toJVectorFunction());
+        this.vectorIndices.put(key, state);
 
-        TinkerIndexType.Vector vector = TinkerIndexType.Vector.COSINE;
-        if (configuration.containsKey(CONFIG_DISTANCE_FUNCTION)) {
-            final Object vec = configuration.get(CONFIG_DISTANCE_FUNCTION);
-            if (vec instanceof TinkerIndexType.Vector) {
-                vector = ((TinkerIndexType.Vector) vec);
-            }
-        }
-
-        double growthRate = DEFAULT_GROWTH_RATE;
-        if (configuration.containsKey(CONFIG_GROWTH_RATE)) {
-            final Object growthObj = configuration.get(CONFIG_GROWTH_RATE);
-            if (growthObj instanceof Number) {
-                growthRate = ((Number) growthObj).doubleValue();
-            }
-        }
-        this.growthRates.put(key, growthRate);
-
-        // Create a new HNSW index for this property key
-        final Index<Object, float[], ElementItem, Float> index = HnswIndex
-                .newBuilder(dimension, vector.getDistanceFunction(), Float::compare, maxItems)
-                .withM(m)
-                .withEfConstruction(efConstruction)
-                .withEf(ef)
-                .withRemoveEnabled()
-                .build();
-
-        this.vectorIndices.put(key, index);
-
-        // Index existing elements
         (Vertex.class.isAssignableFrom(this.indexClass) ?
                 ((TinkerGraph) this.graph).vertices.values().parallelStream() :
                 ((TinkerGraph) this.graph).edges.values().parallelStream())
-                .map(e -> new Object[]{((T) e).property(key), e})
-                .filter(a -> ((Property) a[0]).isPresent())
-                .forEach(a -> {
-                    // values for the key that don't match the dimensions of the index won't be added
-                    final Object value = ((Property) a[0]).value();
-                    if (value instanceof float[] && ((float[]) value).length == dimension) {
-                        this.addToIndex(key, (float[]) value, (T) a[1]);
+                .forEach(e -> {
+                    final Property<?> prop = ((T) e).property(key);
+                    if (prop.isPresent() && prop.value() instanceof float[]) {
+                        final float[] v = (float[]) prop.value();
+                        if (v.length == dimension)
+                            state.add((T) e, v);
                     }
                 });
     }
 
-    /**
-     * Adds an element with a vector to the index.
-     *
-     * @param key the property key
-     * @param vector the vector
-     * @param element the element
-     */
     public void addToIndex(final String key, final float[] vector, final T element) {
-        if (!this.indexedKeys.contains(key) || !this.vectorIndices.containsKey(key))
+        final IndexState<T> state = this.vectorIndices.get(key);
+        if (state == null)
             return;
-
-        final Index<Object, float[], ElementItem, Float> index = this.vectorIndices.get(key);
-        final ElementItem item = new ElementItem(element.id(), vector, element);
-
-        addWithResize(key, index, item);
+        state.add(element, vector);
     }
 
-    /**
-     * Searches for nearest neighbors in the vector index.
-     *
-     * @param key the property key
-     * @param vector the query vector
-     * @param k the number of nearest neighbors to return
-     * @return a list of elements sorted by distance
-     */
+    @Override
     public List<TinkerIndexElement<T>> findNearest(final String key, final float[] vector, final int k) {
-        if (!this.indexedKeys.contains(key) || !this.vectorIndices.containsKey(key))
+        final IndexState<T> state = this.vectorIndices.get(key);
+        if (state == null)
             throw new IllegalArgumentException("The key '" + key + "' is not indexed");
-
-        final Index<Object, float[], ElementItem, Float> index = this.vectorIndices.get(key);
-        final List<SearchResult<ElementItem, Float>> nearest = index.findNearest(vector, k);
-        return nearest.stream().map(sr ->
-                new TinkerIndexElement<>(sr.item().element, sr.distance())).collect(Collectors.toList());
+        return state.search(vector, k);
     }
 
-    /**
-     * Searches for nearest neighbors in the vector index.
-     *
-     * @param key the property key
-     * @param vector the query vector
-     * @param k the number of nearest neighbors to return
-     * @return a list of elements sorted by distance
-     */
+    @Override
     public List<T> findNearestElements(final String key, final float[] vector, final int k) {
-        if (!this.indexedKeys.contains(key) || !this.vectorIndices.containsKey(key))
-            throw new IllegalArgumentException("The key '" + key + "' is not indexed");
-
-        final Index<Object, float[], ElementItem, Float> index = this.vectorIndices.get(key);
-        final List<SearchResult<ElementItem, Float>> nearest = index.findNearest(vector, k);
-        return nearest.stream().map(sr -> sr.item().element).collect(Collectors.toList());
+        return findNearest(key, vector, k).stream()
+                .map(TinkerIndexElement::getElement)
+                .collect(Collectors.toList());
     }
 
-    /**
-     * Removes an element from the vector index.
-     *
-     * @param key the property key
-     * @param element the element
-     */
     public void removeFromIndex(final String key, final T element) {
-        if (!this.indexedKeys.contains(key) || !this.vectorIndices.containsKey(key))
+        final IndexState<T> state = this.vectorIndices.get(key);
+        if (state == null)
             return;
-
-        final Index<Object, float[], ElementItem, Float> index = this.vectorIndices.get(key);
-        try {
-            index.remove(element.id(), 0);
-        } catch (Exception e) {
-            // If the element is not in the index, just ignore the exception
-        }
+        state.remove(element);
     }
 
-    /**
-     * Updates the vector index when an element's property changes.
-     *
-     * @param key the property key
-     * @param newValue the new vector value
-     * @param element the element
-     */
     public void updateIndex(final String key, final float[] newValue, final T element) {
-        if (!this.indexedKeys.contains(key) || !this.vectorIndices.containsKey(key))
+        final IndexState<T> state = this.vectorIndices.get(key);
+        if (state == null)
             return;
-
-        final Index<Object, float[], ElementItem, Float> index = this.vectorIndices.get(key);
-        try {
-            index.remove(element.id(), 0);
-        } catch (Exception e) {
-            // If the element is not in the index, just ignore the exception
-        }
-        final ElementItem item = new ElementItem(element.id(), newValue, element);
-
-        addWithResize(key, index, item);
+        state.remove(element);
+        state.add(element, newValue);
     }
 
-    /**
-     * Drops the vector index for the specified property key.
-     *
-     * @param key the property key
-     */
     @Override
     public void dropIndex(final String key) {
-        if (this.vectorIndices.containsKey(key)) {
-            this.vectorIndices.remove(key);
-        }
-
-        if (this.growthRates.containsKey(key)) {
-            this.growthRates.remove(key);
-        }
-
+        final IndexState<T> state = this.vectorIndices.remove(key);
+        if (state != null)
+            state.close();
         this.indexedKeys.remove(key);
     }
 
-    // AbstractTinkerIndex implementation methods
-
     @Override
     public List<T> get(final String key, final Object value) {
-        // This method is for regular indices, not vector indices
         return Collections.emptyList();
     }
 
     @Override
     public long count(final String key, final Object value) {
-        // This method is for regular indices, not vector indices
         return 0;
     }
 
     @Override
     public void remove(final String key, final Object value, final T element) {
-        // For vector indices, we use removeFromIndex
-        if (value instanceof float[]) {
+        if (value instanceof float[])
             removeFromIndex(key, element);
-        }
     }
 
     @Override
     public void removeElement(final T element) {
         if (this.indexClass.isAssignableFrom(element.getClass())) {
-            for (String key : this.indexedKeys) {
+            for (final String key : this.indexedKeys)
                 removeFromIndex(key, element);
-            }
         }
     }
 
     @Override
     public void autoUpdate(final String key, final Object newValue, final Object oldValue, final T element) {
-        if (this.indexedKeys.contains(key) && newValue instanceof float[]) {
+        if (this.indexedKeys.contains(key) && newValue instanceof float[])
             updateIndex(key, (float[]) newValue, element);
-        }
     }
 
     /**
-     * Helper method to add an item to the index with automatic resizing if needed.
-     *
-     * @param key the property key
-     * @param index the vector index
-     * @param item the item to add
+     * Per-key index state backed by JVector.
      */
-    private void addWithResize(final String key, final Index<Object, float[], ElementItem, Float> index,
-                               final ElementItem item) {
-        try {
-            index.add(item);
-        } catch (SizeLimitExceededException e) {
-            // Get the growth rate for this index
-            final Double growthRate = this.growthRates.getOrDefault(key, 0.0d);
-
-            // If growth rate is 0 or not set, rethrow the exception
-            if (growthRate <= 0) {
-                throw e;
+    private static final class IndexState<T extends Element> {
+        private final int dimension;
+        private final VectorSimilarityFunction similarityFunction;
+        private final List<VectorFloat<?>> vectors = new ArrayList<>();
+        private final List<T> elements = new ArrayList<>();
+        private final Map<Object, Integer> idToOrdinal = new ConcurrentHashMap<>();
+        private final GraphIndexBuilder builder;
+
+        IndexState(final int dimension, final int m, final int efConstruction,
+                   final VectorSimilarityFunction similarityFunction) {
+            this.dimension = dimension;
+            this.similarityFunction = similarityFunction;
+            final ListRandomAccessVectorValues ravv = new ListRandomAccessVectorValues(vectors, dimension);
+            this.builder = new GraphIndexBuilder(ravv, similarityFunction, m, efConstruction, 1.4f, 1.2f);
+        }
+
+        synchronized void add(final T element, final float[] vector) {
+            if (vector.length != dimension)
+                throw new IllegalArgumentException(
+                        "Vector dimension " + vector.length + " does not match index dimension " + dimension);
+            final int ordinal = vectors.size();
+            final VectorFloat<?> vf = VTS.createFloatVector(vector);
+            vectors.add(vf);
+            elements.add(element);
+            idToOrdinal.put(element.id(), ordinal);
+            builder.addGraphNode(ordinal, vf);
+        }
+
+        synchronized void remove(final T element) {
+            final Integer ordinal = idToOrdinal.remove(element.id());
+            if (ordinal != null)
+                builder.markNodeDeleted(ordinal);
+        }
+
+        List<TinkerIndexElement<T>> search(final float[] queryVector, final int k) {
+            if (vectors.isEmpty())
+                return Collections.emptyList();
+            final VectorFloat<?> query = VTS.createFloatVector(queryVector);
+            final ListRandomAccessVectorValues ravv = new ListRandomAccessVectorValues(vectors, dimension);
+            final var ssp = SearchScoreProvider.exact(query, similarityFunction, ravv);
+            try (final GraphSearcher searcher = new GraphSearcher(builder.getGraph())) {
+                final SearchResult result = searcher.search(ssp, k, io.github.jbellis.jvector.util.Bits.ALL);
+                return java.util.Arrays.stream(result.getNodes())
+                        .map(ns -> new TinkerIndexElement<>(elements.get(ns.node), 1.0f - ns.score))
+                        .collect(Collectors.toList());
+            } catch (Exception e) {
+                throw new RuntimeException("Vector search failed", e);
             }
-
-            // Calculate new size based on growth rate
-            final int currentSize = ((HnswIndex<Object, float[], ElementItem, Float>) index).getMaxItemCount();
-            final int newSize = currentSize + (int) Math.ceil(currentSize * growthRate);
-
-            // Resize the index
-            ((HnswIndex<Object, float[], ElementItem, Float>) index).resize(newSize);
-
-            // Try adding the item again
-            index.add(item);
-        } catch (Exception e) {
-            // If it's not a size limit exception, rethrow it
-            throw e;
-        }
-    }
-
-    /**
-     * A class that wraps an element with its vector for use in the HNSW index.
-     */
-    private class ElementItem implements Item<Object, float[]>, Serializable {
-        private final Object id;
-        private final float[] vector;
-        private final T element;
-
-        public ElementItem(final Object id, final float[] vector, final T element) {
-            this.id = id;
-            this.vector = vector;
-            this.element = element;
-        }
-
-        @Override
-        public Object id() {
-            return id;
-        }
-
-        @Override
-        public float[] vector() {
-            return vector;
         }
 
-        @Override
-        public int dimensions() {
-            return vector.length;
+        void close() {
+            try {
+                builder.close();
+            } catch (Exception ignored) {}
         }
     }
 }
diff --git a/tinkergraph-gremlin/src/test/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerGraphServiceTest.java b/tinkergraph-gremlin/src/test/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerGraphServiceTest.java
index 2b062d430b..1843cb0990 100644
--- a/tinkergraph-gremlin/src/test/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerGraphServiceTest.java
+++ b/tinkergraph-gremlin/src/test/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerGraphServiceTest.java
@@ -471,9 +471,9 @@ public class TinkerGraphServiceTest {
         final List<Object> list = gv.V(vAlice).call(TinkerVectorSearchByElementFactory.NAME, m).toList();
 
         final List<Map<String,Object>> expected = new ArrayList<>();
-        expected.add(asMap("distance", 0.006116271f, "element", vDave));
-        expected.add(asMap("distance", 0.9500624f, "element", vBob));
-        expected.add(asMap("distance", 1.0f, "element", vCharlie));
+        expected.add(asMap("distance", 0.0030581355f, "element", vDave));
+        expected.add(asMap("distance", 0.4750312f, "element", vBob));
+        expected.add(asMap("distance", 0.5f, "element", vCharlie));
 
         // Use a custom comparison to ensure the lists are equal
         assertEquals(expected.size(), list.size());
@@ -504,8 +504,8 @@ public class TinkerGraphServiceTest {
         final List<Object> list = gv.E(e1).call(TinkerVectorSearchByElementFactory.NAME, m).toList();
 
         final List<Map<String, Object>> expected = new ArrayList<>();
-        expected.add(asMap("distance", 0.29289323f, "element", e3));
-        expected.add(asMap("distance", 1.0f, "element", e2));
+        expected.add(asMap("distance", 0.14644659f, "element", e3));
+        expected.add(asMap("distance", 0.5f, "element", e2));
 
         // Use a custom comparison to ensure the lists are equal
         assertEquals(expected.size(), list.size());
@@ -534,7 +534,7 @@ public class TinkerGraphServiceTest {
         final List<Object> list = gv.V(vAlice).call(TinkerVectorSearchByElementFactory.NAME, m).toList();
 
         final List<Map<String,Object>> expected = new ArrayList<>();
-        expected.add(asMap("distance", 0.006116271f, "element", vDave));
+        expected.add(asMap("distance", 0.0030581355f, "element", vDave));
 
         // Use a custom comparison to ensure the lists are equal
         assertEquals(expected.size(), list.size());
@@ -616,9 +616,9 @@ public class TinkerGraphServiceTest {
 
         final List<Map<String,Object>> expected = new ArrayList<>();
         expected.add(asMap("distance", 0.0f, "element", vAlice));
-        expected.add(asMap("distance", 0.006116271f, "element", vDave));
-        expected.add(asMap("distance", 0.9500624f, "element", vBob));
-        expected.add(asMap("distance", 1.0f, "element", vCharlie));
+        expected.add(asMap("distance", 0.0030581355f, "element", vDave));
+        expected.add(asMap("distance", 0.4750312f, "element", vBob));
+        expected.add(asMap("distance", 0.5f, "element", vCharlie));
 
         // Use a custom comparison to ensure the lists are equal
         assertEquals(expected.size(), list.size());
@@ -655,8 +655,8 @@ public class TinkerGraphServiceTest {
 
         final List<Map<String, Object>> expected = new ArrayList<>();
         expected.add(asMap("distance", 0.0f, "element", e1));
-        expected.add(asMap("distance", 0.29289323f, "element", e3));
-        expected.add(asMap("distance", 1.0f, "element", e2));
+        expected.add(asMap("distance", 0.14644659f, "element", e3));
+        expected.add(asMap("distance", 0.5f, "element", e2));
 
         // Use a custom comparison to ensure the lists are equal
         assertEquals(expected.size(), list.size());
diff --git a/tinkergraph-gremlin/src/test/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerGraphVectorIndexTest.java b/tinkergraph-gremlin/src/test/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerGraphVectorIndexTest.java
index b3fbfe8e80..5650834dc1 100644
--- a/tinkergraph-gremlin/src/test/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerGraphVectorIndexTest.java
+++ b/tinkergraph-gremlin/src/test/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerGraphVectorIndexTest.java
@@ -26,11 +26,8 @@ import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
 
-import com.github.jelmerk.hnswlib.core.hnsw.SizeLimitExceededException;
-
 import java.util.Arrays;
 import java.util.Collection;
-import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -48,9 +45,7 @@ import static org.junit.Assert.fail;
 @RunWith(Parameterized.class)
 public class TinkerGraphVectorIndexTest {
 
-    protected static final Map<String,Object> indexConfig = new HashMap<String,Object>() {{
-        put(TinkerVectorIndex.CONFIG_DIMENSION, 3);
-    }};
+    protected static final Map<String, Object> indexConfig = Map.of(TinkerVectorIndex.CONFIG_DIMENSION, 3);
 
     @Parameterized.Parameter
     public AbstractTinkerGraph graph;
@@ -506,78 +501,6 @@ public class TinkerGraphVectorIndexTest {
         assertEquals(0, nearest.size());
     }
 
-    @Test
-    public void shouldGrowIndexWhenCapacityReached() {
-        final GraphTraversalSource g = traversal().with(graph);
-
-        // Create a small index with only 5 items capacity and 50% growth rate
-        final Map<String, Object> smallIndexConfig = new HashMap<>(indexConfig);
-        smallIndexConfig.put(TinkerVectorIndex.CONFIG_MAX_ITEMS, 5);
-        smallIndexConfig.put(TinkerVectorIndex.CONFIG_GROWTH_RATE, 0.5); // 50% growth
-
-        graph.createIndex(TinkerIndexType.VECTOR, "embedding", Vertex.class, smallIndexConfig);
-
-        // Add 5 vertices (fills the index to capacity)
-        for (int i = 0; i < 5; i++) {
-            g.addV("person").property("name", "Person" + i)
-                    .property("embedding", new float[]{(float)i, 0.0f, 0.0f}).iterate();
-        }
-        tryCommitChanges(graph);
-
-        // Add one more vertex with a very distinctive vector - this should trigger the index growth
-        g.addV("person").property("name", "PersonExtra")
-                .property("embedding", new float[]{10.0f, 0.0f, 0.0f}).iterate();
-        tryCommitChanges(graph);
-
-        // Verify we can find all 6 vertices
-        final List<TinkerVertex> allVertices = graph.findNearestVerticesOnly("embedding", new float[]{0.0f, 0.0f, 0.0f}, 10);
-        assertEquals(6, allVertices.size());
-
-        // Verify the extra vertex exists in the results
-        boolean foundExtra = false;
-        for (TinkerVertex vertex : allVertices) {
-            if ("PersonExtra".equals(vertex.value("name"))) {
-                foundExtra = true;
-                break;
-            }
-        }
-        assertThat("Should find the extra vertex", foundExtra, is(true));
-    }
-
-    @Test
-    public void shouldThrowExceptionWhenGrowthRateIsZero() {
-        final GraphTraversalSource g = traversal().with(graph);
-
-        // Create a small index with only 5 items capacity and 0 growth rate
-        final Map<String, Object> smallIndexConfig = new HashMap<>(indexConfig);
-        smallIndexConfig.put(TinkerVectorIndex.CONFIG_MAX_ITEMS, 5);
-        smallIndexConfig.put(TinkerVectorIndex.CONFIG_GROWTH_RATE, 0.0); // No growth
-
-        graph.createIndex(TinkerIndexType.VECTOR, "embedding", Vertex.class, smallIndexConfig);
-
-        // Add 5 vertices (fills the index to capacity)
-        for (int i = 0; i < 5; i++) {
-            g.addV("person").property("name", "Person" + i)
-                    .property("embedding", new float[]{(float)i, 0.0f, 0.0f}).iterate();
-        }
-        tryCommitChanges(graph);
-
-        try {
-            // Add one more vertex - this should throw SizeLimitExceededException
-            g.addV("person").property("name", "PersonExtra")
-                    .property("embedding", new float[]{5.0f, 0.0f, 0.0f}).iterate();
-            tryCommitChanges(graph);
-            fail("Should have thrown SizeLimitExceededException");
-        } catch (Exception e) {
-            // Verify that the exception is caused by SizeLimitExceededException
-            Throwable cause = e;
-            while (cause != null && !(cause instanceof SizeLimitExceededException)) {
-                cause = cause.getCause();
-            }
-            assertNotNull("Expected SizeLimitExceededException", cause);
-        }
-    }
-
     private void tryCommitChanges(final Graph graph) {
         if (graph.features().graph().supportsTransactions())
             graph.tx().commit();
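The revised expected values in TinkerGraphServiceTest are exactly half of the old hnswlib cosine distances (1.0f becomes 0.5f, 0.29289323f becomes 0.14644659f, 0.9500624f becomes 0.4750312f). That is consistent with JVector scoring COSINE similarity as (1 + cos)/2, so the `1.0f - ns.score` conversion in `IndexState.search` yields (1 - cos)/2 instead of hnswlib's 1 - cos. The sketch below reproduces that arithmetic in plain Java without calling JVector. The three score formulas are assumptions modeled on JVector's `VectorSimilarityFunction` (COSINE, EUCLIDEAN, DOT_PRODUCT), and the class `DistanceNormalizationSketch` is hypothetical, not part of TinkerGraph.

```java
/**
 * Hypothetical sketch, not TinkerGraph or JVector code: recomputes the
 * "distance = 1 - similarity" conversion under assumed JVector score formulas.
 */
public final class DistanceNormalizationSketch {

    static float dot(final float[] a, final float[] b) {
        float sum = 0f;
        for (int i = 0; i < a.length; i++)
            sum += a[i] * b[i];
        return sum;
    }

    // Assumed COSINE score: (1 + cos) / 2, so the distance is (1 - cos) / 2.
    static float cosineDistance(final float[] a, final float[] b) {
        final float cos = dot(a, b) / (float) Math.sqrt(dot(a, a) * dot(b, b));
        return 1.0f - (1.0f + cos) / 2.0f;
    }

    // Assumed EUCLIDEAN score: 1 / (1 + squared distance), so the distance lies in [0, 1).
    static float euclideanDistance(final float[] a, final float[] b) {
        float sq = 0f;
        for (int i = 0; i < a.length; i++) {
            final float d = a[i] - b[i];
            sq += d * d;
        }
        return 1.0f - 1.0f / (1.0f + sq);
    }

    // Assumed DOT_PRODUCT score: (1 + dot) / 2, which presumes unit-length vectors.
    static float innerProductDistance(final float[] a, final float[] b) {
        return 1.0f - (1.0f + dot(a, b)) / 2.0f;
    }

    public static void main(final String[] args) {
        final float[] x = {1f, 0f};
        final float[] y = {0f, 1f};
        final float[] diagonal = {1f, 1f};

        // Orthogonal pair: hnswlib reported 1.0 (1 - cos); this prints 0.5.
        System.out.println(cosineDistance(x, y));
        // 45 degrees apart: hnswlib reported about 0.2929; this is roughly half of that.
        System.out.println(cosineDistance(x, diagonal));
        // Squared distance 2, so the result is 1 - 1/3.
        System.out.println(euclideanDistance(x, y));
        // Non-unit vectors with dot product 4 give a negative "distance" (-1.5).
        System.out.println(innerProductDistance(new float[]{2f, 0f}, new float[]{2f, 0f}));
    }
}
```

Under those assumed formulas, the [0,1] lower-is-better range described in the commit message holds for COSINE and EUCLIDEAN, but if INNER_PRODUCT maps to JVector's DOT_PRODUCT, its distance only stays in [0,1] when the indexed and query vectors are unit length, as the last line of `main` illustrates.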
