This is an automated email from the ASF dual-hosted git repository.
adelapena pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/cassandra.git
The following commit(s) were added to refs/heads/trunk by this push:
new 5a82c04fd3 Add vector similarity functions
5a82c04fd3 is described below
commit 5a82c04fd363030d712fb2935b6c591577ba25ec
Author: Andrés de la Peña <[email protected]>
AuthorDate: Wed Jun 28 16:43:58 2023 +0100
Add vector similarity functions
patch by Andrés de la Peña; reviewed by Ekaterina Dimitrova and Maxwell Guo
for CASSANDRA-18640
Co-authored-by: Jonathan Ellis <[email protected]>
Co-authored-by: Andrés de la Peña <[email protected]>
---
CHANGES.txt | 1 +
NEWS.txt | 1 +
.../cassandra/pages/developing/cql/changes.adoc | 1 +
.../cassandra/pages/developing/cql/functions.adoc | 7 +
.../cassandra/partials/vector_functions.adoc | 41 +++++
src/java/org/apache/cassandra/cql3/CQL3Type.java | 14 +-
.../cassandra/cql3/functions/FunctionFactory.java | 14 +-
.../cql3/functions/FunctionParameter.java | 86 ++++++++--
.../cassandra/cql3/functions/NativeFunctions.java | 1 +
.../cassandra/cql3/functions/VectorFcts.java | 105 ++++++++++++
.../functions/masking/ReplaceMaskingFunction.java | 4 +-
.../cassandra/cql3/functions/types/VectorType.java | 2 +-
.../cassandra/cql3/selection/Selectable.java | 25 ++-
.../apache/cassandra/db/marshal/VectorType.java | 2 +-
.../cassandra/cql3/functions/VectorFctsTest.java | 179 +++++++++++++++++++++
.../cql3/validation/operations/CQLVectorTest.java | 4 +-
.../cassandra/utils/AbstractTypeGenerators.java | 6 +-
17 files changed, 453 insertions(+), 40 deletions(-)
diff --git a/CHANGES.txt b/CHANGES.txt
index 12691ced64..3f78b9769e 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -1,4 +1,5 @@
5.0
+ * Add vector similarity functions (CASSANDRA-18640)
* Lift MessagingService.minimum_version to 40 (CASSANDRA-18314)
* Introduce pluggable crypto providers and default to
AmazonCorrettoCryptoProvider (CASSANDRA-18624)
* Improved DeletionTime serialization (CASSANDRA-18648)
diff --git a/NEWS.txt b/NEWS.txt
index 1c5be180e8..378ae6e93c 100644
--- a/NEWS.txt
+++ b/NEWS.txt
@@ -84,6 +84,7 @@ New features
class. Further details and documentation can be found in
src/java/org/apache/cassandra/db/compaction/UnifiedCompactionStrategy.md
- New `VectorType` (cql `vector<element_type, dimension>`) which adds new
fixed-length element arrays. See CASSANDRA-18504
+ - Added new vector similarity functions `similarity_cosine`,
`similarity_euclidean` and `similarity_dot_product`.
- Removed UDT type migration logic for 3.6+ clusters upgrading to 4.0. If
migration has been disabled, it must be
enabled before upgrading to 5.0 if the cluster used UDTs. See
CASSANDRA-18504
- Extended max expiration time from 2038-01-19T03:14:06+00:00 to
2106-02-07T06:28:13+00:00
diff --git a/doc/modules/cassandra/pages/developing/cql/changes.adoc
b/doc/modules/cassandra/pages/developing/cql/changes.adoc
index 9dc3af00cb..7bf6bd6d03 100644
--- a/doc/modules/cassandra/pages/developing/cql/changes.adoc
+++ b/doc/modules/cassandra/pages/developing/cql/changes.adoc
@@ -4,6 +4,7 @@ The following describes the changes in each version of CQL.
== 3.4.7
+* Add vector similarity functions (`18640`)
* Remove deprecated functions `dateOf` and `unixTimestampOf`, replaced by
`toTimestamp` and `toUnixTimestamp` (`18328`)
* Added support for attaching masking functions to table columns (`18068`)
* Add UNMASK permission (`18069`)
diff --git a/doc/modules/cassandra/pages/developing/cql/functions.adoc
b/doc/modules/cassandra/pages/developing/cql/functions.adoc
index 6474b72cb5..244af3e7cf 100644
--- a/doc/modules/cassandra/pages/developing/cql/functions.adoc
+++ b/doc/modules/cassandra/pages/developing/cql/functions.adoc
@@ -284,6 +284,13 @@ A number of functions allow to obscure the real contents
of a column containing
include::partial$masking_functions.adoc[]
+[[vector-similarity-functions]]
+===== Vector similarity functions
+
+A number of functions allow obtaining the similarity score between vectors of
floats.
+
+include::partial$vector_functions.adoc[]
+
[[user-defined-scalar-functions]]
==== User-defined functions
diff --git a/doc/modules/cassandra/partials/vector_functions.adoc
b/doc/modules/cassandra/partials/vector_functions.adoc
new file mode 100644
index 0000000000..daa4b2b8ce
--- /dev/null
+++ b/doc/modules/cassandra/partials/vector_functions.adoc
@@ -0,0 +1,41 @@
+[cols=",",options="header",]
+|===
+|Function | Description
+
+| `similarity_cosine(vector, vector)` | Calculates the cosine similarity score
between two float vectors of the same dimension.
+
+Examples:
+
+`similarity_cosine([0.1, 0.2], null)` -> `null`
+
+`similarity_cosine([0.1, 0.2], [0.1, 0.2])` -> `1`
+
+`similarity_cosine([0.1, 0.2], [-0.1, -0.2])` -> `0`
+
+`similarity_cosine([0.1, 0.2], [0.9, 0.8])` -> `0.964238`
+
+| `similarity_euclidean(vector, vector)` | Calculates the similarity score based on the euclidean distance
between two float vectors of the same dimension.
+
+Examples:
+
+`similarity_euclidean([0.1, 0.2], null)` -> `null`
+
+`similarity_euclidean([0.1, 0.2], [0.1, 0.2])` -> `1`
+
+`similarity_euclidean([0.1, 0.2], [-0.1, -0.2])` -> `0.833333`
+
+`similarity_euclidean([0.1, 0.2], [0.9, 0.8])` -> `0.5`
+
+| `similarity_dot_product(vector, vector)` | Calculates the similarity score based on the dot product
between two float vectors of the same dimension.
+
+Examples:
+
+`similarity_dot_product([0.1, 0.2], null)` -> `null`
+
+`similarity_dot_product([0.1, 0.2], [0.1, 0.2])` -> `0.525`
+
+`similarity_dot_product([0.1, 0.2], [-0.1, -0.2])` -> `0.475`
+
+`similarity_dot_product([0.1, 0.2], [0.9, 0.8])` -> `0.625`
+
+|===
\ No newline at end of file
diff --git a/src/java/org/apache/cassandra/cql3/CQL3Type.java
b/src/java/org/apache/cassandra/cql3/CQL3Type.java
index 57474103bb..acb8bc97e2 100644
--- a/src/java/org/apache/cassandra/cql3/CQL3Type.java
+++ b/src/java/org/apache/cassandra/cql3/CQL3Type.java
@@ -834,13 +834,13 @@ public interface CQL3Type
private static class RawVector extends Raw
{
private final CQL3Type.Raw element;
- private final int dimention;
+ private final int dimension;
- private RawVector(Raw element, int dimention)
+ private RawVector(Raw element, int dimension)
{
super(true);
this.element = element;
- this.dimention = dimention;
+ this.dimension = dimension;
}
@Override
@@ -865,7 +865,13 @@ public interface CQL3Type
public CQL3Type prepare(String keyspace, Types udts) throws
InvalidRequestException
{
CQL3Type type = element.prepare(keyspace, udts);
- return new Vector(type.getType(), dimention);
+ return new Vector(type.getType(), dimension);
+ }
+
+ @Override
+ public String toString()
+ {
+ return "vector<" + element.toString() + ", " + dimension + '>';
}
}
diff --git a/src/java/org/apache/cassandra/cql3/functions/FunctionFactory.java
b/src/java/org/apache/cassandra/cql3/functions/FunctionFactory.java
index 32db9fed20..90a2c69d77 100644
--- a/src/java/org/apache/cassandra/cql3/functions/FunctionFactory.java
+++ b/src/java/org/apache/cassandra/cql3/functions/FunctionFactory.java
@@ -90,9 +90,19 @@ public abstract class FunctionFactory
if (numArgs < numMandatoryParameters || numArgs > numParameters)
throw invalidNumberOfArgumentsException();
- // try to infer the types of the arguments
+ // Do a first pass trying to infer the types of the arguments
individually, without any context about the types
+ // of the other arguments. We don't do any validation during this
first pass.
List<AbstractType<?>> types = new ArrayList<>(args.size());
for (int i = 0; i < args.size(); i++)
+ {
+ AssignmentTestable arg = args.get(i);
+ FunctionParameter parameter = parameters.get(i);
+
types.add(parameter.inferType(SchemaConstants.SYSTEM_KEYSPACE_NAME, arg,
receiverType, null));
+ }
+
+ // Do a second pass trying to infer the types of the arguments
considering the types inferred for the other arguments.
+ // We can validate the inferred types during this second pass.
+ for (int i = 0; i < args.size(); i++)
{
AssignmentTestable arg = args.get(i);
FunctionParameter parameter = parameters.get(i);
@@ -103,7 +113,7 @@ public abstract class FunctionFactory
arg, this));
parameter.validateType(name, arg, type);
type = type.udfType();
- types.add(type);
+ types.set(i, type);
}
return doGetOrCreateFunction(types, receiverType);
diff --git
a/src/java/org/apache/cassandra/cql3/functions/FunctionParameter.java
b/src/java/org/apache/cassandra/cql3/functions/FunctionParameter.java
index 083b0f61d4..708e4fdd7d 100644
--- a/src/java/org/apache/cassandra/cql3/functions/FunctionParameter.java
+++ b/src/java/org/apache/cassandra/cql3/functions/FunctionParameter.java
@@ -25,12 +25,14 @@ import javax.annotation.Nullable;
import org.apache.cassandra.cql3.AssignmentTestable;
import org.apache.cassandra.cql3.CQL3Type;
+import org.apache.cassandra.cql3.selection.Selectable;
import org.apache.cassandra.db.marshal.AbstractType;
import org.apache.cassandra.db.marshal.CollectionType;
import org.apache.cassandra.db.marshal.ListType;
import org.apache.cassandra.db.marshal.MapType;
import org.apache.cassandra.db.marshal.NumberType;
import org.apache.cassandra.db.marshal.SetType;
+import org.apache.cassandra.db.marshal.VectorType;
import org.apache.cassandra.exceptions.InvalidRequestException;
import static java.lang.String.format;
@@ -48,13 +50,14 @@ public interface FunctionParameter
* @param keyspace the current keyspace
* @param arg a parameter value in a specific function call
* @param receiverType the type of the object that will receive the result
of the function call
+ * @param inferredTypes the types that have been inferred for the other
parameters
* @return the inferred data type of the parameter, or {@code null} if it
isn't possible to infer it
*/
@Nullable
default AbstractType<?> inferType(String keyspace,
AssignmentTestable arg,
@Nullable AbstractType<?> receiverType,
- List<AbstractType<?>> previousTypes)
+ @Nullable List<AbstractType<?>>
inferredTypes)
{
return arg.getCompatibleTypeIfKnown(keyspace);
}
@@ -83,9 +86,9 @@ public interface FunctionParameter
public AbstractType<?> inferType(String keyspace,
AssignmentTestable arg,
@Nullable AbstractType<?>
receiverType,
- List<AbstractType<?>>
previousTypes)
+ @Nullable List<AbstractType<?>>
inferredTypes)
{
- return wrapped.inferType(keyspace, arg, receiverType,
previousTypes);
+ return wrapped.inferType(keyspace, arg, receiverType,
inferredTypes);
}
@Override
@@ -130,7 +133,7 @@ public interface FunctionParameter
public AbstractType<?> inferType(String keyspace,
AssignmentTestable arg,
@Nullable AbstractType<?>
receiverType,
- List<AbstractType<?>>
previousTypes)
+ @Nullable List<AbstractType<?>>
inferredTypes)
{
AbstractType<?> inferred =
arg.getCompatibleTypeIfKnown(keyspace);
return inferred != null ? inferred : types[0].getType();
@@ -168,7 +171,7 @@ public interface FunctionParameter
public AbstractType<?> inferType(String keyspace,
AssignmentTestable arg,
@Nullable AbstractType<?>
receiverType,
- List<AbstractType<?>>
previousTypes)
+ @Nullable List<AbstractType<?>>
inferredTypes)
{
AbstractType<?> type = arg.getCompatibleTypeIfKnown(keyspace);
return type == null && inferFromReceiver ? receiverType : type;
@@ -189,9 +192,12 @@ public interface FunctionParameter
}
/**
- * @return a function parameter definition that accepts values with the
same type as the first parameter
+ * @param index the index of the function argument that this parameter is
associated with
+ * @param preferOther whether the parameter should prefer the type of the
other parameter over its own type
+ * @param parameter the type of this parameter when the type of the
associated parameter is unknown
+ * @return a function parameter definition that is expected to have the
same type as another parameter
*/
- static FunctionParameter sameAsFirst()
+ static FunctionParameter sameAs(int index, boolean preferOther,
FunctionParameter parameter)
{
return new FunctionParameter()
{
@@ -199,21 +205,28 @@ public interface FunctionParameter
public AbstractType<?> inferType(String keyspace,
AssignmentTestable arg,
@Nullable AbstractType<?>
receiverType,
- List<AbstractType<?>>
previousTypes)
+ @Nullable List<AbstractType<?>>
inferredTypes)
{
- return previousTypes.get(0);
+ if (preferOther)
+ {
+ AbstractType<?> other = inferredTypes == null ? null :
inferredTypes.get(index);
+ return other == null ? parameter.inferType(keyspace, arg,
receiverType, inferredTypes) : other;
+ }
+
+ AbstractType<?> inferred = parameter.inferType(keyspace, arg,
receiverType, inferredTypes);
+ return inferred == null && inferredTypes != null ?
inferredTypes.get(index) : inferred;
}
@Override
public void validateType(FunctionName name, AssignmentTestable
arg, AbstractType<?> argType)
{
- // nothing to do here, all types are accepted
+ parameter.validateType(name, arg, argType);
}
@Override
public String toString()
{
- return "same";
+ return parameter.toString();
}
};
}
@@ -336,4 +349,55 @@ public interface FunctionParameter
}
};
}
+
+ /**
+ * @param type the type of the vector elements
+ * @return a function parameter definition that accepts values of type
{@link VectorType} with elements of the
+ * specified {@code type} and any dimensions.
+ */
+ static FunctionParameter vector(CQL3Type type)
+ {
+ return new FunctionParameter()
+ {
+ @Override
+ public AbstractType<?> inferType(String keyspace,
+ AssignmentTestable arg,
+ @Nullable AbstractType<?>
receiverType,
+ @Nullable List<AbstractType<?>>
inferredTypes)
+ {
+ if (arg instanceof Selectable.WithArrayLiteral)
+ return VectorType.getInstance(type.getType(),
((Selectable.WithArrayLiteral) arg).getSize());
+
+ AbstractType<?> inferred =
arg.getCompatibleTypeIfKnown(keyspace);
+ return inferred == null ? receiverType : inferred;
+ }
+
+ @Override
+ public void validateType(FunctionName name, AssignmentTestable
arg, AbstractType<?> argType)
+ {
+ if (argType.isVector())
+ {
+ VectorType<?> vectorType = (VectorType<?>) argType;
+ if (vectorType.elementType.asCQL3Type() == type)
+ return;
+ }
+ else if (argType instanceof ListType) // if it's terminal it
will be a list
+ {
+ ListType<?> listType = (ListType<?>) argType;
+ if
(listType.getElementsType().testAssignment(type.getType()) == NOT_ASSIGNABLE)
+ return;
+ }
+
+ throw new InvalidRequestException(format("Function %s requires
a %s vector argument, " +
+ "but found argument
%s of type %s",
+ name, type, arg,
argType.asCQL3Type()));
+ }
+
+ @Override
+ public String toString()
+ {
+ return format("vector<%s, n>", type);
+ }
+ };
+ }
}
diff --git a/src/java/org/apache/cassandra/cql3/functions/NativeFunctions.java
b/src/java/org/apache/cassandra/cql3/functions/NativeFunctions.java
index 02939e9896..2100fe3f89 100644
--- a/src/java/org/apache/cassandra/cql3/functions/NativeFunctions.java
+++ b/src/java/org/apache/cassandra/cql3/functions/NativeFunctions.java
@@ -46,6 +46,7 @@ public class NativeFunctions
BytesConversionFcts.addFunctionsTo(this);
MathFcts.addFunctionsTo(this);
MaskingFcts.addFunctionsTo(this);
+ VectorFcts.addFunctionsTo(this);
}
};
diff --git a/src/java/org/apache/cassandra/cql3/functions/VectorFcts.java
b/src/java/org/apache/cassandra/cql3/functions/VectorFcts.java
new file mode 100644
index 0000000000..31136dcfd4
--- /dev/null
+++ b/src/java/org/apache/cassandra/cql3/functions/VectorFcts.java
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.cql3.functions;
+
+import java.nio.ByteBuffer;
+import java.util.List;
+
+import org.apache.cassandra.cql3.CQL3Type;
+import org.apache.cassandra.db.marshal.AbstractType;
+import org.apache.cassandra.db.marshal.FloatType;
+import org.apache.cassandra.db.marshal.VectorType;
+import org.apache.cassandra.exceptions.InvalidRequestException;
+import org.apache.cassandra.transport.ProtocolVersion;
+import org.apache.lucene.index.VectorSimilarityFunction;
+
+public class VectorFcts
+{
+ public static void addFunctionsTo(NativeFunctions functions)
+ {
+ functions.add(createSimilarityFunctionFactory("similarity_cosine",
VectorSimilarityFunction.COSINE, false));
+ functions.add(createSimilarityFunctionFactory("similarity_euclidean",
VectorSimilarityFunction.EUCLIDEAN, true));
+
functions.add(createSimilarityFunctionFactory("similarity_dot_product",
VectorSimilarityFunction.DOT_PRODUCT, true));
+ }
+
+ private static FunctionFactory createSimilarityFunctionFactory(String name,
+
VectorSimilarityFunction luceneFunction,
+ boolean
supportsZeroVectors)
+ {
+ return new FunctionFactory(name,
+ FunctionParameter.sameAs(1, false,
FunctionParameter.vector(CQL3Type.Native.FLOAT)),
+ FunctionParameter.sameAs(0, false,
FunctionParameter.vector(CQL3Type.Native.FLOAT)))
+ {
+ @Override
+ @SuppressWarnings("unchecked")
+ protected NativeFunction
doGetOrCreateFunction(List<AbstractType<?>> argTypes, AbstractType<?>
receiverType)
+ {
+ // check that all arguments have the same vector dimensions
+ VectorType<Float> firstArgType = (VectorType<Float>)
argTypes.get(0);
+ int dimensions = firstArgType.dimension;
+ if (!argTypes.stream().allMatch(t -> ((VectorType<?>)
t).dimension == dimensions))
+ throw new InvalidRequestException("All arguments must have
the same vector dimensions");
+ return createSimilarityFunction(name.name, firstArgType,
luceneFunction, supportsZeroVectors);
+ }
+ };
+ }
+
+ private static NativeFunction createSimilarityFunction(String name,
+ VectorType<Float>
type,
+
VectorSimilarityFunction f,
+ boolean
supportsZeroVectors)
+ {
+ return new NativeScalarFunction(name, FloatType.instance, type, type)
+ {
+ @Override
+ public Arguments newArguments(ProtocolVersion version)
+ {
+ return new FunctionArguments(version,
+ (v, b) -> type.composeAsFloat(b),
+ (v, b) -> type.composeAsFloat(b));
+ }
+
+ @Override
+ public ByteBuffer execute(Arguments arguments) throws
InvalidRequestException
+ {
+ if (arguments.containsNulls())
+ return null;
+
+ float[] v1 = arguments.get(0);
+ float[] v2 = arguments.get(1);
+
+ if (!supportsZeroVectors)
+ {
+ if (isAllZero(v1) || isAllZero(v2))
+ throw new InvalidRequestException("Function " + name +
" doesn't support all-zero vectors.");
+ }
+
+ return FloatType.instance.decompose(f.compare(v1, v2));
+ }
+
+ private boolean isAllZero(float[] v)
+ {
+ for (float f : v)
+ if (f != 0)
+ return false;
+ return true;
+ }
+ };
+ }
+}
diff --git
a/src/java/org/apache/cassandra/cql3/functions/masking/ReplaceMaskingFunction.java
b/src/java/org/apache/cassandra/cql3/functions/masking/ReplaceMaskingFunction.java
index 880f941abf..373743daf4 100644
---
a/src/java/org/apache/cassandra/cql3/functions/masking/ReplaceMaskingFunction.java
+++
b/src/java/org/apache/cassandra/cql3/functions/masking/ReplaceMaskingFunction.java
@@ -62,7 +62,9 @@ public class ReplaceMaskingFunction extends MaskingFunction
/** @return a {@link FunctionFactory} to build new {@link
ReplaceMaskingFunction}s. */
public static FunctionFactory factory()
{
- return new MaskingFunction.Factory(NAME,
FunctionParameter.anyType(true), FunctionParameter.sameAsFirst())
+ return new MaskingFunction.Factory(NAME,
+ FunctionParameter.anyType(true),
+ FunctionParameter.sameAs(0, true,
FunctionParameter.anyType(true)))
{
@Override
protected NativeFunction
doGetOrCreateFunction(List<AbstractType<?>> argTypes, AbstractType<?>
receiverType)
diff --git a/src/java/org/apache/cassandra/cql3/functions/types/VectorType.java
b/src/java/org/apache/cassandra/cql3/functions/types/VectorType.java
index 8d22322424..1c546aa78f 100644
--- a/src/java/org/apache/cassandra/cql3/functions/types/VectorType.java
+++ b/src/java/org/apache/cassandra/cql3/functions/types/VectorType.java
@@ -31,7 +31,7 @@ public class VectorType extends DataType
VectorType(DataType subtype, int dimensions)
{
super(Name.VECTOR);
- assert dimensions > 0 : "vectors may only have positive dimentions;
given " + dimensions;
+ assert dimensions > 0 : "vectors may only have positive dimensions;
given " + dimensions;
this.subtype = subtype;
this.dimensions = dimensions;
}
diff --git a/src/java/org/apache/cassandra/cql3/selection/Selectable.java
b/src/java/org/apache/cassandra/cql3/selection/Selectable.java
index 5c06da64e8..56f7f85025 100644
--- a/src/java/org/apache/cassandra/cql3/selection/Selectable.java
+++ b/src/java/org/apache/cassandra/cql3/selection/Selectable.java
@@ -726,7 +726,7 @@ public interface Selectable extends AssignmentTestable
/**
* The list elements
*/
- private final List<Selectable> selectables;
+ protected final List<Selectable> selectables;
public WithArrayLiteral(List<Selectable> selectables)
{
@@ -787,6 +787,11 @@ public interface Selectable extends AssignmentTestable
return Lists.listToString(selectables);
}
+ public int getSize()
+ {
+ return selectables.size();
+ }
+
public static class Raw implements Selectable.Raw
{
private final List<Selectable.Raw> raws;
@@ -804,16 +809,11 @@ public interface Selectable extends AssignmentTestable
}
}
- public static class WithList implements Selectable
+ public static class WithList extends WithArrayLiteral
{
- /**
- * The list elements
- */
- private final List<Selectable> selectables;
-
public WithList(List<Selectable> selectables)
{
- this.selectables = selectables;
+ super(selectables);
}
@Override
@@ -876,16 +876,11 @@ public interface Selectable extends AssignmentTestable
}
}
- public static class WithVector implements Selectable
+ public static class WithVector extends WithArrayLiteral
{
- /**
- * The vector elements
- */
- private final List<Selectable> selectables;
-
public WithVector(List<Selectable> selectables)
{
- this.selectables = selectables;
+ super(selectables);
}
@Override
diff --git a/src/java/org/apache/cassandra/db/marshal/VectorType.java
b/src/java/org/apache/cassandra/db/marshal/VectorType.java
index 631d990c6c..7caf02e3bb 100644
--- a/src/java/org/apache/cassandra/db/marshal/VectorType.java
+++ b/src/java/org/apache/cassandra/db/marshal/VectorType.java
@@ -84,7 +84,7 @@ public final class VectorType<T> extends AbstractType<List<T>>
{
super(ComparisonType.CUSTOM);
if (dimension <= 0)
- throw new InvalidRequestException(String.format("vectors may only
have positive dimentions; given %d", dimension));
+ throw new InvalidRequestException(String.format("vectors may only
have positive dimensions; given %d", dimension));
this.elementType = elementType;
this.dimension = dimension;
this.elementSerializer = elementType.getSerializer();
diff --git a/test/unit/org/apache/cassandra/cql3/functions/VectorFctsTest.java
b/test/unit/org/apache/cassandra/cql3/functions/VectorFctsTest.java
new file mode 100644
index 0000000000..207c00ea87
--- /dev/null
+++ b/test/unit/org/apache/cassandra/cql3/functions/VectorFctsTest.java
@@ -0,0 +1,179 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.cql3.functions;
+
+import java.util.Arrays;
+import java.util.Collection;
+
+import org.apache.commons.lang3.ArrayUtils;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import org.apache.cassandra.cql3.CQLTester;
+import org.apache.lucene.index.VectorSimilarityFunction;
+
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+@RunWith(Parameterized.class)
+public class VectorFctsTest extends CQLTester
+{
+ @Parameterized.Parameter
+ public String function;
+
+ @Parameterized.Parameter(1)
+ public VectorSimilarityFunction luceneFunction;
+
+ @Parameterized.Parameters(name = "{index}: function={0}")
+ public static Collection<Object[]> data()
+ {
+ return Arrays.asList(new Object[][]{
+ { "system.similarity_cosine", VectorSimilarityFunction.COSINE },
+ { "system.similarity_euclidean", VectorSimilarityFunction.EUCLIDEAN },
+ { "system.similarity_dot_product",
VectorSimilarityFunction.DOT_PRODUCT }
+ });
+ }
+
+ @Test
+ public void testVectorSimilarityFunction()
+ {
+ createTable(KEYSPACE, "CREATE TABLE %s (pk int PRIMARY KEY, value
vector<float, 2>, " +
+ "l list<float>, " + // lists shouldn't be
accepted by the functions
+ "fl frozen<list<float>>, " + // frozen lists
shouldn't be accepted by the functions
+ "v1 vector<float, 1>, " + // 1-dimension vector
to test mismatching dimensions
+ "v_int vector<int, 2>, " + // int vectors
shouldn't be accepted by the functions
+ "v_double vector<double, 2>)");// double vectors
shouldn't be accepted by the functions
+
+ float[] values = new float[]{ 1f, 2f };
+ CQLTester.Vector<Float> vector = vector(ArrayUtils.toObject(values));
+ Object[] similarity = row(luceneFunction.compare(values, values));
+
+ // basic functionality
+ execute("INSERT INTO %s (pk, value, l, fl, v1, v_int, v_double) VALUES
(0, ?, ?, ?, ?, ?, ?)",
+ vector, list(1f, 2f), list(1f, 2f), vector(1f), vector(1, 2),
vector(1d, 2d));
+ assertRows(execute("SELECT " + function + "(value, value) FROM %s"),
similarity);
+
+ // literals
+ assertRows(execute("SELECT " + function + "(value, [1, 2]) FROM %s"),
similarity);
+ assertRows(execute("SELECT " + function + "([1, 2], value) FROM %s"),
similarity);
+ assertRows(execute("SELECT " + function + "([1, 2], [1, 2]) FROM %s"),
similarity);
+
+ // bind markers
+ assertRows(execute("SELECT " + function + "(value, ?) FROM %s",
vector), similarity);
+ assertRows(execute("SELECT " + function + "(?, value) FROM %s",
vector), similarity);
+ assertThatThrownBy(() -> execute("SELECT " + function + "(?, ?) FROM
%s", vector, vector))
+ .hasMessageContaining("Cannot infer type of argument ?");
+
+ // bind markers with type hints
+ assertRows(execute("SELECT " + function + "((vector<float, 2>) ?, ?)
FROM %s", vector, vector), similarity);
+ assertRows(execute("SELECT " + function + "(?, (vector<float, 2>) ?)
FROM %s", vector, vector), similarity);
+ assertRows(execute("SELECT " + function + "((vector<float, 2>) ?,
(vector<float, 2>) ?) FROM %s", vector, vector), similarity);
+
+ // bind markers and literals
+ assertRows(execute("SELECT " + function + "([1, 2], ?) FROM %s",
vector), similarity);
+ assertRows(execute("SELECT " + function + "(?, [1, 2]) FROM %s",
vector), similarity);
+ assertRows(execute("SELECT " + function + "([1, 2], ?) FROM %s",
vector), similarity);
+
+ // wrong column types with columns
+ assertThatThrownBy(() -> execute("SELECT " + function + "(l, value)
FROM %s"))
+ .hasMessageContaining("Function " + function + " requires a float
vector argument, but found argument l of type list<float>");
+ assertThatThrownBy(() -> execute("SELECT " + function + "(fl, value)
FROM %s"))
+ .hasMessageContaining("Function " + function + " requires a float
vector argument, but found argument fl of type frozen<list<float>>");
+ assertThatThrownBy(() -> execute("SELECT " + function + "(value, l)
FROM %s"))
+ .hasMessageContaining("Function " + function + " requires a float
vector argument, but found argument l of type list<float>");
+ assertThatThrownBy(() -> execute("SELECT " + function + "(value, fl)
FROM %s"))
+ .hasMessageContaining("Function " + function + " requires a float
vector argument, but found argument fl of type frozen<list<float>>");
+
+ // wrong column types with columns and literals
+ assertThatThrownBy(() -> execute("SELECT " + function + "(l, [1, 2])
FROM %s"))
+ .hasMessageContaining("Function " + function + " requires a float
vector argument, but found argument l of type list<float>");
+ assertThatThrownBy(() -> execute("SELECT " + function + "(fl, [1, 2])
FROM %s"))
+ .hasMessageContaining("Function " + function + " requires a float
vector argument, but found argument fl of type frozen<list<float>>");
+ assertThatThrownBy(() -> execute("SELECT " + function + "([1, 2], l)
FROM %s"))
+ .hasMessageContaining("Function " + function + " requires a float
vector argument, but found argument l of type list<float>");
+ assertThatThrownBy(() -> execute("SELECT " + function + "([1, 2], fl)
FROM %s"))
+ .hasMessageContaining("Function " + function + " requires a float
vector argument, but found argument fl of type frozen<list<float>>");
+
+ // wrong column types with cast literals
+ assertThatThrownBy(() -> execute("SELECT " + function +
"((List<Float>)[1, 2], [3, 4]) FROM %s"))
+ .hasMessageContaining("Function " + function + " requires a float
vector argument, but found argument (list<float>)[1, 2] of type
frozen<list<float>>");
+ assertThatThrownBy(() -> execute("SELECT " + function +
"((List<Float>)[1, 2], (List<Float>)[3, 4]) FROM %s"))
+ .hasMessageContaining("Function " + function + " requires a float
vector argument, but found argument (list<float>)[1, 2] of type
frozen<list<float>>");
+ assertThatThrownBy(() -> execute("SELECT " + function + "([1, 2],
(List<Float>)[3, 4]) FROM %s"))
+ .hasMessageContaining("Function " + function + " requires a float
vector argument, but found argument (list<float>)[3, 4] of type
frozen<list<float>>");
+
+ // wrong non-float vectors
+ assertThatThrownBy(() -> execute("SELECT " + function + "(v_int, [1,
2]) FROM %s"))
+ .hasMessageContaining("Function " + function + " requires a float
vector argument, but found argument v_int of type vector<int, 2>");
+ assertThatThrownBy(() -> execute("SELECT " + function + "(v_double,
[1, 2]) FROM %s"))
+ .hasMessageContaining("Function " + function + " requires a float
vector argument, but found argument v_double of type vector<double, 2>");
+ assertThatThrownBy(() -> execute("SELECT " + function + "([1, 2],
v_int) FROM %s"))
+ .hasMessageContaining("Function " + function + " requires a float
vector argument, but found argument v_int of type vector<int, 2>");
+ assertThatThrownBy(() -> execute("SELECT " + function + "([1, 2],
v_double) FROM %s"))
+ .hasMessageContaining("Function " + function + " requires a float
vector argument, but found argument v_double of type vector<double, 2>");
+
+ // mismatching dimensions with literals
+ assertThatThrownBy(() -> execute("SELECT " + function + "([1, 2], [3])
FROM %s", vector(1)))
+ .hasMessageContaining("All arguments must have the same vector
dimensions");
+ assertThatThrownBy(() -> execute("SELECT " + function + "(value, [1])
FROM %s", vector(1)))
+ .hasMessageContaining("All arguments must have the same vector
dimensions");
+ assertThatThrownBy(() -> execute("SELECT " + function + "([1], value)
FROM %s", vector(1)))
+ .hasMessageContaining("All arguments must have the same vector
dimensions");
+
+ // mismatching dimensions with bind markers
+ assertThatThrownBy(() -> execute("SELECT " + function +
"((vector<float, 1>) ?, value) FROM %s", vector(1)))
+ .hasMessageContaining("All arguments must have the same vector
dimensions");
+ assertThatThrownBy(() -> execute("SELECT " + function + "(value,
(vector<float, 1>) ?) FROM %s", vector(1)))
+ .hasMessageContaining("All arguments must have the same vector
dimensions");
+ assertThatThrownBy(() -> execute("SELECT " + function +
"((vector<float, 2>) ?, (vector<float, 1>) ?) FROM %s", vector(1, 2),
vector(1)))
+ .hasMessageContaining("All arguments must have the same vector
dimensions");
+
+ // mismatching dimensions with columns
+ assertThatThrownBy(() -> execute("SELECT " + function + "(value, v1)
FROM %s"))
+ .hasMessageContaining("All arguments must have the same vector
dimensions");
+ assertThatThrownBy(() -> execute("SELECT " + function + "(v1, value)
FROM %s"))
+ .hasMessageContaining("All arguments must have the same vector
dimensions");
+
+ // null arguments with literals
+ assertRows(execute("SELECT " + function + "(value, null) FROM %s"),
row((Float) null));
+ assertRows(execute("SELECT " + function + "(null, value) FROM %s"),
row((Float) null));
+ assertThatThrownBy(() -> execute("SELECT " + function + "(null, null)
FROM %s"))
+ .hasMessageContaining("Cannot infer type of argument NULL in call to
function " + function);
+
+ // null arguments with bind markers
+ assertRows(execute("SELECT " + function + "(value, ?) FROM %s",
(CQLTester.Vector<Float>) null), row((Float) null));
+ assertRows(execute("SELECT " + function + "(?, value) FROM %s",
(CQLTester.Vector<Float>) null), row((Float) null));
+ assertThatThrownBy(() -> execute("SELECT " + function + "(?, ?) FROM
%s", null, null))
+ .hasMessageContaining("Cannot infer type of argument ? in call to
function " + function);
+
+ // test all-zero vectors, only cosine similarity should reject them
+ if (luceneFunction == VectorSimilarityFunction.COSINE)
+ {
+ String expected = "Function " + function + " doesn't support
all-zero vectors";
+ assertThatThrownBy(() -> execute("SELECT " + function + "(value,
[0, 0]) FROM %s")) .hasMessageContaining(expected);
+ assertThatThrownBy(() -> execute("SELECT " + function + "([0, 0],
value) FROM %s")).hasMessageContaining(expected);
+ }
+ else
+ {
+ float expected = luceneFunction.compare(values, new float[]{ 0, 0
});
+ assertRows(execute("SELECT " + function + "(value, [0, 0]) FROM
%s"), row(expected));
+ assertRows(execute("SELECT " + function + "([0, 0], value) FROM
%s"), row(expected));
+ }
+ }
+}
diff --git
a/test/unit/org/apache/cassandra/cql3/validation/operations/CQLVectorTest.java
b/test/unit/org/apache/cassandra/cql3/validation/operations/CQLVectorTest.java
index 773a08cf38..cf88d32d93 100644
---
a/test/unit/org/apache/cassandra/cql3/validation/operations/CQLVectorTest.java
+++
b/test/unit/org/apache/cassandra/cql3/validation/operations/CQLVectorTest.java
@@ -457,14 +457,14 @@ public class CQLVectorTest extends CQLTester.InMemory
format("SELECT %s([1, 2, 3], [4, 5, 6]) FROM
%%s", f));
// Test wrong types on function creation
- assertInvalidThrowMessage("vectors may only have positive dimentions;
given 0",
+ assertInvalidThrowMessage("vectors may only have positive dimensions;
given 0",
InvalidRequestException.class,
"CREATE FUNCTION %s (x vector<int, 0>) " +
"CALLED ON NULL INPUT " +
"RETURNS vector<int, 2> " +
"LANGUAGE java " +
"AS 'return x;'");
- assertInvalidThrowMessage("vectors may only have positive dimentions;
given 0",
+ assertInvalidThrowMessage("vectors may only have positive dimensions;
given 0",
InvalidRequestException.class,
"CREATE FUNCTION %s (x vector<int, 2>) " +
"CALLED ON NULL INPUT " +
diff --git a/test/unit/org/apache/cassandra/utils/AbstractTypeGenerators.java
b/test/unit/org/apache/cassandra/utils/AbstractTypeGenerators.java
index d9e3f80508..0b23a51c5d 100644
--- a/test/unit/org/apache/cassandra/utils/AbstractTypeGenerators.java
+++ b/test/unit/org/apache/cassandra/utils/AbstractTypeGenerators.java
@@ -520,15 +520,15 @@ public final class AbstractTypeGenerators
return vectorTypeGen(typeGen, SourceDSL.integers().between(1, 100));
}
- public static Gen<VectorType<?>> vectorTypeGen(Gen<AbstractType<?>>
typeGen, Gen<Integer> dimentionGen)
+ public static Gen<VectorType<?>> vectorTypeGen(Gen<AbstractType<?>>
typeGen, Gen<Integer> dimensionGen)
{
return rnd -> {
- int dimention = dimentionGen.generate(rnd);
+ int dimension = dimensionGen.generate(rnd);
AbstractType<?> element = typeGen.generate(rnd);
// empty type not supported
while (element == EmptyType.instance)
element = typeGen.generate(rnd);
- return VectorType.getInstance(element.freeze(), dimention);
+ return VectorType.getInstance(element.freeze(), dimension);
};
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]