This is an automated email from the ASF dual-hosted git repository.

adelapena pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/cassandra.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 5a82c04fd3 Add vector similarity functions
5a82c04fd3 is described below

commit 5a82c04fd363030d712fb2935b6c591577ba25ec
Author: Andrés de la Peña <[email protected]>
AuthorDate: Wed Jun 28 16:43:58 2023 +0100

    Add vector similarity functions
    
    patch by Andrés de la Peña; reviewed by Ekaterina Dimitrova and Maxwell Guo for CASSANDRA-18640
    
    Co-authored-by: Jonathan Ellis <[email protected]>
    Co-authored-by: Andrés de la Peña <[email protected]>
---
 CHANGES.txt                                        |   1 +
 NEWS.txt                                           |   1 +
 .../cassandra/pages/developing/cql/changes.adoc    |   1 +
 .../cassandra/pages/developing/cql/functions.adoc  |   7 +
 .../cassandra/partials/vector_functions.adoc       |  41 +++++
 src/java/org/apache/cassandra/cql3/CQL3Type.java   |  14 +-
 .../cassandra/cql3/functions/FunctionFactory.java  |  14 +-
 .../cql3/functions/FunctionParameter.java          |  86 ++++++++--
 .../cassandra/cql3/functions/NativeFunctions.java  |   1 +
 .../cassandra/cql3/functions/VectorFcts.java       | 105 ++++++++++++
 .../functions/masking/ReplaceMaskingFunction.java  |   4 +-
 .../cassandra/cql3/functions/types/VectorType.java |   2 +-
 .../cassandra/cql3/selection/Selectable.java       |  25 ++-
 .../apache/cassandra/db/marshal/VectorType.java    |   2 +-
 .../cassandra/cql3/functions/VectorFctsTest.java   | 179 +++++++++++++++++++++
 .../cql3/validation/operations/CQLVectorTest.java  |   4 +-
 .../cassandra/utils/AbstractTypeGenerators.java    |   6 +-
 17 files changed, 453 insertions(+), 40 deletions(-)

diff --git a/CHANGES.txt b/CHANGES.txt
index 12691ced64..3f78b9769e 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -1,4 +1,5 @@
 5.0
+ * Add vector similarity functions (CASSANDRA-18640)
  * Lift MessagingService.minimum_version to 40 (CASSANDRA-18314)
  * Introduce pluggable crypto providers and default to AmazonCorrettoCryptoProvider (CASSANDRA-18624)
  * Improved DeletionTime serialization (CASSANDRA-18648)
diff --git a/NEWS.txt b/NEWS.txt
index 1c5be180e8..378ae6e93c 100644
--- a/NEWS.txt
+++ b/NEWS.txt
@@ -84,6 +84,7 @@ New features
       class. Further details and documentation can be found in
       src/java/org/apache/cassandra/db/compaction/UnifiedCompactionStrategy.md
     - New `VectorType` (cql `vector<element_type, dimension>`) which adds new fixed-length element arrays. See CASSANDRA-18504
+    - Added new vector similarity functions `similarity_cosine`, `similarity_euclidean` and `similarity_dot_product`.
     - Removed UDT type migration logic for 3.6+ clusters upgrading to 4.0.  If migration has been disabled, it must be
       enabled before upgrading to 5.0 if the cluster used UDTs. See CASSANDRA-18504
     - Entended max expiration time from 2038-01-19T03:14:06+00:00 to 2106-02-07T06:28:13+00:00
diff --git a/doc/modules/cassandra/pages/developing/cql/changes.adoc b/doc/modules/cassandra/pages/developing/cql/changes.adoc
index 9dc3af00cb..7bf6bd6d03 100644
--- a/doc/modules/cassandra/pages/developing/cql/changes.adoc
+++ b/doc/modules/cassandra/pages/developing/cql/changes.adoc
@@ -4,6 +4,7 @@ The following describes the changes in each version of CQL.
 
 == 3.4.7
 
+* Add vector similarity functions (`18640`)
 * Remove deprecated functions `dateOf` and `unixTimestampOf`, replaced by `toTimestamp` and `toUnixTimestamp` (`18328`)
 * Added support for attaching masking functions to table columns (`18068`)
 * Add UNMASK permission (`18069`)
diff --git a/doc/modules/cassandra/pages/developing/cql/functions.adoc b/doc/modules/cassandra/pages/developing/cql/functions.adoc
index 6474b72cb5..244af3e7cf 100644
--- a/doc/modules/cassandra/pages/developing/cql/functions.adoc
+++ b/doc/modules/cassandra/pages/developing/cql/functions.adoc
@@ -284,6 +284,13 @@ A number of functions allow to obscure the real contents of a column containing
 
 include::partial$masking_functions.adoc[]
 
+[[vector-similarity-functions]]
+===== Vector similarity functions
+
+A number of functions allow to obtain the similarity score between vectors of floats.
+
+include::partial$vector_functions.adoc[]
+
 [[user-defined-scalar-functions]]
 ==== User-defined functions
 
diff --git a/doc/modules/cassandra/partials/vector_functions.adoc b/doc/modules/cassandra/partials/vector_functions.adoc
new file mode 100644
index 0000000000..daa4b2b8ce
--- /dev/null
+++ b/doc/modules/cassandra/partials/vector_functions.adoc
@@ -0,0 +1,41 @@
+[cols=",",options="header",]
+|===
+|Function | Description
+
+| `similarity_cosine(vector, vector)` | Calculates the cosine similarity score between two float vectors of the same dimension.
+
+Examples:
+
+`similarity_cosine([0.1, 0.2], null)` -> `null`
+
+`similarity_cosine([0.1, 0.2], [0.1, 0.2])` -> `1`
+
+`similarity_cosine([0.1, 0.2], [-0.1, -0.2])` -> `0`
+
+`similarity_cosine([0.1, 0.2], [0.9, 0.8])` -> `0.964238`
+
+| `similarity_euclidean(vector, vector)` | Calculates the euclidian distance between two float vectors of the same dimension.
+
+Examples:
+
+`similarity_euclidean([0.1, 0.2], null)` -> `null`
+
+`similarity_euclidean([0.1, 0.2], [0.1, 0.2])` -> `1`
+
+`similarity_euclidean([0.1, 0.2], [-0.1, -0.2])` -> `0.833333`
+
+`similarity_euclidean([0.1, 0.2], [0.9, 0.8])` -> `0.5`
+
+| `similarity_dot_product(vector, vector)` | Calculates the dot product between two float vectors of the same dimension.
+
+Examples:
+
+`similarity_dot_product([0.1, 0.2], null)` -> `null`
+
+`similarity_dot_product([0.1, 0.2], [0.1, 0.2])` -> `0.525`
+
+`similarity_dot_product([0.1, 0.2], [-0.1, -0.2])` -> `0.475`
+
+`similarity_dot_product([0.1, 0.2], [0.9, 0.8])` -> `0.625`
+
+|===
\ No newline at end of file
diff --git a/src/java/org/apache/cassandra/cql3/CQL3Type.java b/src/java/org/apache/cassandra/cql3/CQL3Type.java
index 57474103bb..acb8bc97e2 100644
--- a/src/java/org/apache/cassandra/cql3/CQL3Type.java
+++ b/src/java/org/apache/cassandra/cql3/CQL3Type.java
@@ -834,13 +834,13 @@ public interface CQL3Type
         private static class RawVector extends Raw
         {
             private final CQL3Type.Raw element;
-            private final int dimention;
+            private final int dimension;
 
-            private RawVector(Raw element, int dimention)
+            private RawVector(Raw element, int dimension)
             {
                 super(true);
                 this.element = element;
-                this.dimention = dimention;
+                this.dimension = dimension;
             }
 
             @Override
@@ -865,7 +865,13 @@ public interface CQL3Type
             public CQL3Type prepare(String keyspace, Types udts) throws InvalidRequestException
             {
                 CQL3Type type = element.prepare(keyspace, udts);
-                return new Vector(type.getType(), dimention);
+                return new Vector(type.getType(), dimension);
+            }
+
+            @Override
+            public String toString()
+            {
+                return "vector<" + element.toString() + ", " + dimension + '>';
             }
         }
 
diff --git a/src/java/org/apache/cassandra/cql3/functions/FunctionFactory.java b/src/java/org/apache/cassandra/cql3/functions/FunctionFactory.java
index 32db9fed20..90a2c69d77 100644
--- a/src/java/org/apache/cassandra/cql3/functions/FunctionFactory.java
+++ b/src/java/org/apache/cassandra/cql3/functions/FunctionFactory.java
@@ -90,9 +90,19 @@ public abstract class FunctionFactory
         if (numArgs < numMandatoryParameters || numArgs > numParameters)
             throw invalidNumberOfArgumentsException();
 
-        // try to infer the types of the arguments
+        // Do a first pass trying to infer the types of the arguments individually, without any context about the types
+        // of the other arguments. We don't do any validation during this first pass.
         List<AbstractType<?>> types = new ArrayList<>(args.size());
         for (int i = 0; i < args.size(); i++)
+        {
+            AssignmentTestable arg = args.get(i);
+            FunctionParameter parameter = parameters.get(i);
+            types.add(parameter.inferType(SchemaConstants.SYSTEM_KEYSPACE_NAME, arg, receiverType, null));
+        }
+
+        // Do a second pass trying to infer the types of the arguments considering the types of other inferred types.
+        // We can validate the inferred types during this second pass.
+        for (int i = 0; i < args.size(); i++)
         {
             AssignmentTestable arg = args.get(i);
             FunctionParameter parameter = parameters.get(i);
@@ -103,7 +113,7 @@ public abstract class FunctionFactory
                                                                 arg, this));
             parameter.validateType(name, arg, type);
             type = type.udfType();
-            types.add(type);
+            types.set(i, type);
         }
 
         return doGetOrCreateFunction(types, receiverType);
diff --git a/src/java/org/apache/cassandra/cql3/functions/FunctionParameter.java b/src/java/org/apache/cassandra/cql3/functions/FunctionParameter.java
index 083b0f61d4..708e4fdd7d 100644
--- a/src/java/org/apache/cassandra/cql3/functions/FunctionParameter.java
+++ b/src/java/org/apache/cassandra/cql3/functions/FunctionParameter.java
@@ -25,12 +25,14 @@ import javax.annotation.Nullable;
 
 import org.apache.cassandra.cql3.AssignmentTestable;
 import org.apache.cassandra.cql3.CQL3Type;
+import org.apache.cassandra.cql3.selection.Selectable;
 import org.apache.cassandra.db.marshal.AbstractType;
 import org.apache.cassandra.db.marshal.CollectionType;
 import org.apache.cassandra.db.marshal.ListType;
 import org.apache.cassandra.db.marshal.MapType;
 import org.apache.cassandra.db.marshal.NumberType;
 import org.apache.cassandra.db.marshal.SetType;
+import org.apache.cassandra.db.marshal.VectorType;
 import org.apache.cassandra.exceptions.InvalidRequestException;
 
 import static java.lang.String.format;
@@ -48,13 +50,14 @@ public interface FunctionParameter
      * @param keyspace the current keyspace
      * @param arg a parameter value in a specific function call
      * @param receiverType the type of the object that will receive the result of the function call
+     * @param inferredTypes the types that have been inferred for the other parameters
      * @return the inferred data type of the parameter, or {@link null} it isn't possible to infer it
      */
     @Nullable
     default AbstractType<?> inferType(String keyspace,
                                       AssignmentTestable arg,
                                       @Nullable AbstractType<?> receiverType,
-                                      List<AbstractType<?>> previousTypes)
+                                      @Nullable List<AbstractType<?>> inferredTypes)
     {
         return arg.getCompatibleTypeIfKnown(keyspace);
     }
@@ -83,9 +86,9 @@ public interface FunctionParameter
             public AbstractType<?> inferType(String keyspace,
                                              AssignmentTestable arg,
                                              @Nullable AbstractType<?> receiverType,
-                                             List<AbstractType<?>> previousTypes)
+                                             @Nullable List<AbstractType<?>> inferredTypes)
             {
-                return wrapped.inferType(keyspace, arg, receiverType, previousTypes);
+                return wrapped.inferType(keyspace, arg, receiverType, inferredTypes);
             }
 
             @Override
@@ -130,7 +133,7 @@ public interface FunctionParameter
             public AbstractType<?> inferType(String keyspace,
                                              AssignmentTestable arg,
                                              @Nullable AbstractType<?> receiverType,
-                                             List<AbstractType<?>> previousTypes)
+                                             @Nullable List<AbstractType<?>> inferredTypes)
             {
                 AbstractType<?> inferred = arg.getCompatibleTypeIfKnown(keyspace);
                 return inferred != null ? inferred : types[0].getType();
@@ -168,7 +171,7 @@ public interface FunctionParameter
             public AbstractType<?> inferType(String keyspace,
                                              AssignmentTestable arg,
                                              @Nullable AbstractType<?> receiverType,
-                                             List<AbstractType<?>> previousTypes)
+                                             @Nullable List<AbstractType<?>> inferredTypes)
             {
                 AbstractType<?> type = arg.getCompatibleTypeIfKnown(keyspace);
                 return type == null && inferFromReceiver ? receiverType : type;
@@ -189,9 +192,12 @@ public interface FunctionParameter
     }
 
     /**
-     * @return a function parameter definition that accepts values with the same type as the first parameter
+     * @param index the index of the function argument that this parameter is associated with
+     * @param preferOther whether the parameter should prefer the type of the other parameter over its own type
+     * @param parameter the type of this parameter when the type of the associated parameter is unknown
+     * @return a function parameter definition that is expected to have the same type as another parameter
      */
-    static FunctionParameter sameAsFirst()
+    static FunctionParameter sameAs(int index, boolean preferOther, FunctionParameter parameter)
     {
         return new FunctionParameter()
         {
@@ -199,21 +205,28 @@ public interface FunctionParameter
             public AbstractType<?> inferType(String keyspace,
                                              AssignmentTestable arg,
                                              @Nullable AbstractType<?> receiverType,
-                                             List<AbstractType<?>> previousTypes)
+                                             @Nullable List<AbstractType<?>> inferredTypes)
             {
-                return previousTypes.get(0);
+                if (preferOther)
+                {
+                    AbstractType<?> other = inferredTypes == null ? null : inferredTypes.get(index);
+                    return other == null ? parameter.inferType(keyspace, arg, receiverType, inferredTypes) : other;
+                }
+
+                AbstractType<?> inferred = parameter.inferType(keyspace, arg, receiverType, inferredTypes);
+                return inferred == null && inferredTypes != null ? inferredTypes.get(index) : inferred;
             }
 
             @Override
             public void validateType(FunctionName name, AssignmentTestable arg, AbstractType<?> argType)
             {
-                // nothing to do here, all types are accepted
+                parameter.validateType(name, arg, argType);
             }
 
             @Override
             public String toString()
             {
-                return "same";
+                return parameter.toString();
             }
         };
     }
@@ -336,4 +349,55 @@ public interface FunctionParameter
             }
         };
     }
+
+    /**
+     * @param type the type of the vector elements
+     * @return a function parameter definition that accepts values of type {@link VectorType} with elements of the
+     * specified {@code type} and any dimensions.
+     */
+    static FunctionParameter vector(CQL3Type type)
+    {
+        return new FunctionParameter()
+        {
+            @Override
+            public AbstractType<?> inferType(String keyspace,
+                                             AssignmentTestable arg,
+                                             @Nullable AbstractType<?> receiverType,
+                                             @Nullable List<AbstractType<?>> inferredTypes)
+            {
+                if (arg instanceof Selectable.WithArrayLiteral)
+                    return VectorType.getInstance(type.getType(), ((Selectable.WithArrayLiteral) arg).getSize());
+
+                AbstractType<?> inferred = arg.getCompatibleTypeIfKnown(keyspace);
+                return inferred == null ? receiverType : inferred;
+            }
+
+            @Override
+            public void validateType(FunctionName name, AssignmentTestable arg, AbstractType<?> argType)
+            {
+                if (argType.isVector())
+                {
+                    VectorType<?> vectorType = (VectorType<?>) argType;
+                    if (vectorType.elementType.asCQL3Type() == type)
+                        return;
+                }
+                else if (argType instanceof ListType) // if it's terminal it will be a list
+                {
+                    ListType<?> listType = (ListType<?>) argType;
+                    if (listType.getElementsType().testAssignment(type.getType()) == NOT_ASSIGNABLE)
+                        return;
+                }
+
+                throw new InvalidRequestException(format("Function %s requires a %s vector argument, " +
+                                                         "but found argument %s of type %s",
+                                                         name, type, arg, argType.asCQL3Type()));
+            }
+
+            @Override
+            public String toString()
+            {
+                return format("vector<%s, n>", type);
+            }
+        };
+    }
 }
diff --git a/src/java/org/apache/cassandra/cql3/functions/NativeFunctions.java b/src/java/org/apache/cassandra/cql3/functions/NativeFunctions.java
index 02939e9896..2100fe3f89 100644
--- a/src/java/org/apache/cassandra/cql3/functions/NativeFunctions.java
+++ b/src/java/org/apache/cassandra/cql3/functions/NativeFunctions.java
@@ -46,6 +46,7 @@ public class NativeFunctions
             BytesConversionFcts.addFunctionsTo(this);
             MathFcts.addFunctionsTo(this);
             MaskingFcts.addFunctionsTo(this);
+            VectorFcts.addFunctionsTo(this);
         }
     };
 
diff --git a/src/java/org/apache/cassandra/cql3/functions/VectorFcts.java b/src/java/org/apache/cassandra/cql3/functions/VectorFcts.java
new file mode 100644
index 0000000000..31136dcfd4
--- /dev/null
+++ b/src/java/org/apache/cassandra/cql3/functions/VectorFcts.java
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.cql3.functions;
+
+import java.nio.ByteBuffer;
+import java.util.List;
+
+import org.apache.cassandra.cql3.CQL3Type;
+import org.apache.cassandra.db.marshal.AbstractType;
+import org.apache.cassandra.db.marshal.FloatType;
+import org.apache.cassandra.db.marshal.VectorType;
+import org.apache.cassandra.exceptions.InvalidRequestException;
+import org.apache.cassandra.transport.ProtocolVersion;
+import org.apache.lucene.index.VectorSimilarityFunction;
+
+public class VectorFcts
+{
+    public static void addFunctionsTo(NativeFunctions functions)
+    {
+        functions.add(createSimilarityFunctionFactory("similarity_cosine", VectorSimilarityFunction.COSINE, false));
+        functions.add(createSimilarityFunctionFactory("similarity_euclidean", VectorSimilarityFunction.EUCLIDEAN, true));
+        functions.add(createSimilarityFunctionFactory("similarity_dot_product", VectorSimilarityFunction.DOT_PRODUCT, true));
+    }
+
+    private static FunctionFactory createSimilarityFunctionFactory(String name,
+                                                                   VectorSimilarityFunction luceneFunction,
+                                                                   boolean supportsZeroVectors)
+    {
+        return new FunctionFactory(name,
+                                   FunctionParameter.sameAs(1, false, FunctionParameter.vector(CQL3Type.Native.FLOAT)),
+                                   FunctionParameter.sameAs(0, false, FunctionParameter.vector(CQL3Type.Native.FLOAT)))
+        {
+            @Override
+            @SuppressWarnings("unchecked")
+            protected NativeFunction doGetOrCreateFunction(List<AbstractType<?>> argTypes, AbstractType<?> receiverType)
+            {
+                // check that all arguments have the same vector dimensions
+                VectorType<Float> firstArgType = (VectorType<Float>) argTypes.get(0);
+                int dimensions = firstArgType.dimension;
+                if (!argTypes.stream().allMatch(t -> ((VectorType<?>) t).dimension == dimensions))
+                    throw new InvalidRequestException("All arguments must have the same vector dimensions");
+                return createSimilarityFunction(name.name, firstArgType, luceneFunction, supportsZeroVectors);
+            }
+        };
+    }
+
+    private static NativeFunction createSimilarityFunction(String name,
+                                                           VectorType<Float> type,
+                                                           VectorSimilarityFunction f,
+                                                           boolean supportsZeroVectors)
+    {
+        return new NativeScalarFunction(name, FloatType.instance, type, type)
+        {
+            @Override
+            public Arguments newArguments(ProtocolVersion version)
+            {
+                return new FunctionArguments(version,
+                                             (v, b) -> type.composeAsFloat(b),
+                                             (v, b) -> type.composeAsFloat(b));
+            }
+
+            @Override
+            public ByteBuffer execute(Arguments arguments) throws InvalidRequestException
+            {
+                if (arguments.containsNulls())
+                    return null;
+
+                float[] v1 = arguments.get(0);
+                float[] v2 = arguments.get(1);
+
+                if (!supportsZeroVectors)
+                {
+                    if (isAllZero(v1) || isAllZero(v2))
+                        throw new InvalidRequestException("Function " + name + " doesn't support all-zero vectors.");
+                }
+
+                return FloatType.instance.decompose(f.compare(v1, v2));
+            }
+
+            private boolean isAllZero(float[] v)
+            {
+                for (float f : v)
+                    if (f != 0)
+                        return false;
+                return true;
+            }
+        };
+    }
+}
diff --git a/src/java/org/apache/cassandra/cql3/functions/masking/ReplaceMaskingFunction.java b/src/java/org/apache/cassandra/cql3/functions/masking/ReplaceMaskingFunction.java
index 880f941abf..373743daf4 100644
--- a/src/java/org/apache/cassandra/cql3/functions/masking/ReplaceMaskingFunction.java
+++ b/src/java/org/apache/cassandra/cql3/functions/masking/ReplaceMaskingFunction.java
@@ -62,7 +62,9 @@ public class ReplaceMaskingFunction extends MaskingFunction
     /** @return a {@link FunctionFactory} to build new {@link ReplaceMaskingFunction}s. */
     public static FunctionFactory factory()
     {
-        return new MaskingFunction.Factory(NAME, FunctionParameter.anyType(true), FunctionParameter.sameAsFirst())
+        return new MaskingFunction.Factory(NAME,
+                                           FunctionParameter.anyType(true),
+                                           FunctionParameter.sameAs(0, true, FunctionParameter.anyType(true)))
         {
             @Override
             protected NativeFunction doGetOrCreateFunction(List<AbstractType<?>> argTypes, AbstractType<?> receiverType)
diff --git a/src/java/org/apache/cassandra/cql3/functions/types/VectorType.java b/src/java/org/apache/cassandra/cql3/functions/types/VectorType.java
index 8d22322424..1c546aa78f 100644
--- a/src/java/org/apache/cassandra/cql3/functions/types/VectorType.java
+++ b/src/java/org/apache/cassandra/cql3/functions/types/VectorType.java
@@ -31,7 +31,7 @@ public class VectorType extends DataType
     VectorType(DataType subtype, int dimensions)
     {
         super(Name.VECTOR);
-        assert dimensions > 0 : "vectors may only have positive dimentions; given " + dimensions;
+        assert dimensions > 0 : "vectors may only have positive dimensions; given " + dimensions;
         this.subtype = subtype;
         this.dimensions = dimensions;
     }
diff --git a/src/java/org/apache/cassandra/cql3/selection/Selectable.java b/src/java/org/apache/cassandra/cql3/selection/Selectable.java
index 5c06da64e8..56f7f85025 100644
--- a/src/java/org/apache/cassandra/cql3/selection/Selectable.java
+++ b/src/java/org/apache/cassandra/cql3/selection/Selectable.java
@@ -726,7 +726,7 @@ public interface Selectable extends AssignmentTestable
         /**
          * The list elements
          */
-        private final List<Selectable> selectables;
+        protected final List<Selectable> selectables;
 
         public WithArrayLiteral(List<Selectable> selectables)
         {
@@ -787,6 +787,11 @@ public interface Selectable extends AssignmentTestable
             return Lists.listToString(selectables);
         }
 
+        public int getSize()
+        {
+            return selectables.size();
+        }
+
         public static class Raw implements Selectable.Raw
         {
             private final List<Selectable.Raw> raws;
@@ -804,16 +809,11 @@ public interface Selectable extends AssignmentTestable
         }
     }
 
-    public static class WithList implements Selectable
+    public static class WithList extends WithArrayLiteral
     {
-        /**
-         * The list elements
-         */
-        private final List<Selectable> selectables;
-
         public WithList(List<Selectable> selectables)
         {
-            this.selectables = selectables;
+            super(selectables);
         }
 
         @Override
@@ -876,16 +876,11 @@ public interface Selectable extends AssignmentTestable
         }
     }
 
-    public static class WithVector implements Selectable
+    public static class WithVector extends WithArrayLiteral
     {
-        /**
-         * The vector elements
-         */
-        private final List<Selectable> selectables;
-
         public WithVector(List<Selectable> selectables)
         {
-            this.selectables = selectables;
+            super(selectables);
         }
 
         @Override
diff --git a/src/java/org/apache/cassandra/db/marshal/VectorType.java b/src/java/org/apache/cassandra/db/marshal/VectorType.java
index 631d990c6c..7caf02e3bb 100644
--- a/src/java/org/apache/cassandra/db/marshal/VectorType.java
+++ b/src/java/org/apache/cassandra/db/marshal/VectorType.java
@@ -84,7 +84,7 @@ public final class VectorType<T> extends AbstractType<List<T>>
     {
         super(ComparisonType.CUSTOM);
         if (dimension <= 0)
-            throw new InvalidRequestException(String.format("vectors may only have positive dimentions; given %d", dimension));
+            throw new InvalidRequestException(String.format("vectors may only have positive dimensions; given %d", dimension));
         this.elementType = elementType;
         this.dimension = dimension;
         this.elementSerializer = elementType.getSerializer();
diff --git a/test/unit/org/apache/cassandra/cql3/functions/VectorFctsTest.java 
b/test/unit/org/apache/cassandra/cql3/functions/VectorFctsTest.java
new file mode 100644
index 0000000000..207c00ea87
--- /dev/null
+++ b/test/unit/org/apache/cassandra/cql3/functions/VectorFctsTest.java
@@ -0,0 +1,179 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.cql3.functions;
+
+import java.util.Arrays;
+import java.util.Collection;
+
+import org.apache.commons.lang3.ArrayUtils;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import org.apache.cassandra.cql3.CQLTester;
+import org.apache.lucene.index.VectorSimilarityFunction;
+
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+@RunWith(Parameterized.class)
+public class VectorFctsTest extends CQLTester
+{
+    @Parameterized.Parameter
+    public String function;
+
+    @Parameterized.Parameter(1)
+    public VectorSimilarityFunction luceneFunction;
+
+    @Parameterized.Parameters(name = "{index}: function={0}")
+    public static Collection<Object[]> data()
+    {
+        return Arrays.asList(new Object[][]{
+        { "system.similarity_cosine", VectorSimilarityFunction.COSINE },
+        { "system.similarity_euclidean", VectorSimilarityFunction.EUCLIDEAN },
+        { "system.similarity_dot_product", 
VectorSimilarityFunction.DOT_PRODUCT }
+        });
+    }
+
+    @Test
+    public void testVectorSimilarityFunction()
+    {
+        createTable(KEYSPACE, "CREATE TABLE %s (pk int PRIMARY KEY, value 
vector<float, 2>, " +
+                              "l list<float>, " + // lists shouldn't be 
accepted by the functions
+                              "fl frozen<list<float>>, " + // frozen lists 
shouldn't be accepted by the functions
+                              "v1 vector<float, 1>, " + // 1-dimension vector 
to test mismatching dimensions
+                              "v_int vector<int, 2>, " + // int vectors 
shouldn't be accepted by the functions
+                              "v_double vector<double, 2>)");// double vectors 
shouldn't be accepted by the functions
+
+        float[] values = new float[]{ 1f, 2f };
+        CQLTester.Vector<Float> vector = vector(ArrayUtils.toObject(values));
+        Object[] similarity = row(luceneFunction.compare(values, values));
+
+        // basic functionality
+        execute("INSERT INTO %s (pk, value, l, fl, v1, v_int, v_double) VALUES 
(0, ?, ?, ?, ?, ?, ?)",
+                vector, list(1f, 2f), list(1f, 2f), vector(1f), vector(1, 2), 
vector(1d, 2d));
+        assertRows(execute("SELECT " + function + "(value, value) FROM %s"), 
similarity);
+
+        // literals
+        assertRows(execute("SELECT " + function + "(value, [1, 2]) FROM %s"), 
similarity);
+        assertRows(execute("SELECT " + function + "([1, 2], value) FROM %s"), 
similarity);
+        assertRows(execute("SELECT " + function + "([1, 2], [1, 2]) FROM %s"), 
similarity);
+
+        // bind markers
+        assertRows(execute("SELECT " + function + "(value, ?) FROM %s", 
vector), similarity);
+        assertRows(execute("SELECT " + function + "(?, value) FROM %s", 
vector), similarity);
+        assertThatThrownBy(() -> execute("SELECT " + function + "(?, ?) FROM 
%s", vector, vector))
+        .hasMessageContaining("Cannot infer type of argument ?");
+
+        // bind markers with type hints
+        assertRows(execute("SELECT " + function + "((vector<float, 2>) ?, ?) 
FROM %s", vector, vector), similarity);
+        assertRows(execute("SELECT " + function + "(?, (vector<float, 2>) ?) 
FROM %s", vector, vector), similarity);
+        assertRows(execute("SELECT " + function + "((vector<float, 2>) ?, 
(vector<float, 2>) ?) FROM %s", vector, vector), similarity);
+
+        // bind markers and literals
+        assertRows(execute("SELECT " + function + "([1, 2], ?) FROM %s", 
vector), similarity);
+        assertRows(execute("SELECT " + function + "(?, [1, 2]) FROM %s", 
vector), similarity);
+        assertRows(execute("SELECT " + function + "([1, 2], ?) FROM %s", 
vector), similarity);
+
+        // wrong column types with columns
+        assertThatThrownBy(() -> execute("SELECT " + function + "(l, value) 
FROM %s"))
+        .hasMessageContaining("Function " + function + " requires a float 
vector argument, but found argument l of type list<float>");
+        assertThatThrownBy(() -> execute("SELECT " + function + "(fl, value) 
FROM %s"))
+        .hasMessageContaining("Function " + function + " requires a float 
vector argument, but found argument fl of type frozen<list<float>>");
+        assertThatThrownBy(() -> execute("SELECT " + function + "(value, l) 
FROM %s"))
+        .hasMessageContaining("Function " + function + " requires a float 
vector argument, but found argument l of type list<float>");
+        assertThatThrownBy(() -> execute("SELECT " + function + "(value, fl) 
FROM %s"))
+        .hasMessageContaining("Function " + function + " requires a float 
vector argument, but found argument fl of type frozen<list<float>>");
+
+        // wrong column types with columns and literals
+        assertThatThrownBy(() -> execute("SELECT " + function + "(l, [1, 2]) 
FROM %s"))
+        .hasMessageContaining("Function " + function + " requires a float 
vector argument, but found argument l of type list<float>");
+        assertThatThrownBy(() -> execute("SELECT " + function + "(fl, [1, 2]) 
FROM %s"))
+        .hasMessageContaining("Function " + function + " requires a float 
vector argument, but found argument fl of type frozen<list<float>>");
+        assertThatThrownBy(() -> execute("SELECT " + function + "([1, 2], l) 
FROM %s"))
+        .hasMessageContaining("Function " + function + " requires a float 
vector argument, but found argument l of type list<float>");
+        assertThatThrownBy(() -> execute("SELECT " + function + "([1, 2], fl) 
FROM %s"))
+        .hasMessageContaining("Function " + function + " requires a float 
vector argument, but found argument fl of type frozen<list<float>>");
+
+        // wrong column types with cast literals
+        assertThatThrownBy(() -> execute("SELECT " + function + 
"((List<Float>)[1, 2], [3, 4]) FROM %s"))
+        .hasMessageContaining("Function " + function + " requires a float 
vector argument, but found argument (list<float>)[1, 2] of type 
frozen<list<float>>");
+        assertThatThrownBy(() -> execute("SELECT " + function + 
"((List<Float>)[1, 2], (List<Float>)[3, 4]) FROM %s"))
+        .hasMessageContaining("Function " + function + " requires a float 
vector argument, but found argument (list<float>)[1, 2] of type 
frozen<list<float>>");
+        assertThatThrownBy(() -> execute("SELECT " + function + "([1, 2], 
(List<Float>)[3, 4]) FROM %s"))
+        .hasMessageContaining("Function " + function + " requires a float 
vector argument, but found argument (list<float>)[3, 4] of type 
frozen<list<float>>");
+
+        // wrong non-float vectors
+        assertThatThrownBy(() -> execute("SELECT " + function + "(v_int, [1, 
2]) FROM %s"))
+        .hasMessageContaining("Function " + function + " requires a float 
vector argument, but found argument v_int of type vector<int, 2>");
+        assertThatThrownBy(() -> execute("SELECT " + function + "(v_double, 
[1, 2]) FROM %s"))
+        .hasMessageContaining("Function " + function + " requires a float 
vector argument, but found argument v_double of type vector<double, 2>");
+        assertThatThrownBy(() -> execute("SELECT " + function + "([1, 2], 
v_int) FROM %s"))
+        .hasMessageContaining("Function " + function + " requires a float 
vector argument, but found argument v_int of type vector<int, 2>");
+        assertThatThrownBy(() -> execute("SELECT " + function + "([1, 2], 
v_double) FROM %s"))
+        .hasMessageContaining("Function " + function + " requires a float 
vector argument, but found argument v_double of type vector<double, 2>");
+
+        // mismatching dimensions with literals
+        assertThatThrownBy(() -> execute("SELECT " + function + "([1, 2], [3]) 
FROM %s", vector(1)))
+        .hasMessageContaining("All arguments must have the same vector 
dimensions");
+        assertThatThrownBy(() -> execute("SELECT " + function + "(value, [1]) 
FROM %s", vector(1)))
+        .hasMessageContaining("All arguments must have the same vector 
dimensions");
+        assertThatThrownBy(() -> execute("SELECT " + function + "([1], value) 
FROM %s", vector(1)))
+        .hasMessageContaining("All arguments must have the same vector 
dimensions");
+
+        // mismatching dimensions with bind markers
+        assertThatThrownBy(() -> execute("SELECT " + function + 
"((vector<float, 1>) ?, value) FROM %s", vector(1)))
+        .hasMessageContaining("All arguments must have the same vector 
dimensions");
+        assertThatThrownBy(() -> execute("SELECT " + function + "(value, 
(vector<float, 1>) ?) FROM %s", vector(1)))
+        .hasMessageContaining("All arguments must have the same vector 
dimensions");
+        assertThatThrownBy(() -> execute("SELECT " + function + 
"((vector<float, 2>) ?, (vector<float, 1>) ?) FROM %s", vector(1, 2), 
vector(1)))
+        .hasMessageContaining("All arguments must have the same vector 
dimensions");
+
+        // mismatching dimensions with columns
+        assertThatThrownBy(() -> execute("SELECT " + function + "(value, v1) 
FROM %s"))
+        .hasMessageContaining("All arguments must have the same vector 
dimensions");
+        assertThatThrownBy(() -> execute("SELECT " + function + "(v1, value) 
FROM %s"))
+        .hasMessageContaining("All arguments must have the same vector 
dimensions");
+
+        // null arguments with literals
+        assertRows(execute("SELECT " + function + "(value, null) FROM %s"), 
row((Float) null));
+        assertRows(execute("SELECT " + function + "(null, value) FROM %s"), 
row((Float) null));
+        assertThatThrownBy(() -> execute("SELECT " + function + "(null, null) 
FROM %s"))
+        .hasMessageContaining("Cannot infer type of argument NULL in call to 
function " + function);
+
+        // null arguments with bind markers
+        assertRows(execute("SELECT " + function + "(value, ?) FROM %s", 
(CQLTester.Vector<Float>) null), row((Float) null));
+        assertRows(execute("SELECT " + function + "(?, value) FROM %s", 
(CQLTester.Vector<Float>) null), row((Float) null));
+        assertThatThrownBy(() -> execute("SELECT " + function + "(?, ?) FROM 
%s", null, null))
+        .hasMessageContaining("Cannot infer type of argument ? in call to 
function " + function);
+
+        // test all-zero vectors, only cosine similarity should reject them
+        if (luceneFunction == VectorSimilarityFunction.COSINE)
+        {
+            String expected = "Function " + function + " doesn't support 
all-zero vectors";
+            assertThatThrownBy(() -> execute("SELECT " + function + "(value, 
[0, 0]) FROM %s")) .hasMessageContaining(expected);
+            assertThatThrownBy(() -> execute("SELECT " + function + "([0, 0], 
value) FROM %s")).hasMessageContaining(expected);
+        }
+        else
+        {
+            float expected = luceneFunction.compare(values, new float[]{ 0, 0 
});
+            assertRows(execute("SELECT " + function + "(value, [0, 0]) FROM 
%s"), row(expected));
+            assertRows(execute("SELECT " + function + "([0, 0], value) FROM 
%s"), row(expected));
+        }
+    }
+}
diff --git 
a/test/unit/org/apache/cassandra/cql3/validation/operations/CQLVectorTest.java 
b/test/unit/org/apache/cassandra/cql3/validation/operations/CQLVectorTest.java
index 773a08cf38..cf88d32d93 100644
--- 
a/test/unit/org/apache/cassandra/cql3/validation/operations/CQLVectorTest.java
+++ 
b/test/unit/org/apache/cassandra/cql3/validation/operations/CQLVectorTest.java
@@ -457,14 +457,14 @@ public class CQLVectorTest extends CQLTester.InMemory
                                   format("SELECT %s([1, 2, 3], [4, 5, 6]) FROM 
%%s", f));
 
         // Test wrong types on function creation
-        assertInvalidThrowMessage("vectors may only have positive dimentions; 
given 0",
+        assertInvalidThrowMessage("vectors may only have positive dimensions; 
given 0",
                                   InvalidRequestException.class,
                                   "CREATE FUNCTION %s (x vector<int, 0>) " +
                                   "CALLED ON NULL INPUT " +
                                   "RETURNS vector<int, 2> " +
                                   "LANGUAGE java " +
                                   "AS 'return x;'");
-        assertInvalidThrowMessage("vectors may only have positive dimentions; 
given 0",
+        assertInvalidThrowMessage("vectors may only have positive dimensions; 
given 0",
                                   InvalidRequestException.class,
                                   "CREATE FUNCTION %s (x vector<int, 2>) " +
                                   "CALLED ON NULL INPUT " +
diff --git a/test/unit/org/apache/cassandra/utils/AbstractTypeGenerators.java 
b/test/unit/org/apache/cassandra/utils/AbstractTypeGenerators.java
index d9e3f80508..0b23a51c5d 100644
--- a/test/unit/org/apache/cassandra/utils/AbstractTypeGenerators.java
+++ b/test/unit/org/apache/cassandra/utils/AbstractTypeGenerators.java
@@ -520,15 +520,15 @@ public final class AbstractTypeGenerators
         return vectorTypeGen(typeGen, SourceDSL.integers().between(1, 100));
     }
 
-    public static Gen<VectorType<?>> vectorTypeGen(Gen<AbstractType<?>> 
typeGen, Gen<Integer> dimentionGen)
+    public static Gen<VectorType<?>> vectorTypeGen(Gen<AbstractType<?>> 
typeGen, Gen<Integer> dimensionGen)
     {
         return rnd -> {
-            int dimention = dimentionGen.generate(rnd);
+            int dimension = dimensionGen.generate(rnd);
             AbstractType<?> element = typeGen.generate(rnd);
             // empty type not supported
             while (element == EmptyType.instance)
                 element = typeGen.generate(rnd);
-            return VectorType.getInstance(element.freeze(), dimention);
+            return VectorType.getInstance(element.freeze(), dimension);
         };
     }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to