This is an automated email from the ASF dual-hosted git repository.

corgy pushed a commit to branch dev
in repository https://gitbox.apache.org/repos/asf/seatunnel.git


The following commit(s) were added to refs/heads/dev by this push:
     new c115613de1 [Feature][Transform V2] Add vector dimension reduction 
transform (#9783)
c115613de1 is described below

commit c115613de1c62af680a43363a766f72165be9bfc
Author: CosmosNi <[email protected]>
AuthorDate: Tue Sep 2 15:25:42 2025 +0800

    [Feature][Transform V2] Add vector dimension reduction transform (#9783)
---
 docs/en/transform-v2/sql-functions.md              |  42 +++-
 docs/zh/transform-v2/sql-functions.md              |  41 +++-
 .../apache/seatunnel/e2e/transform/TestSQLIT.java  |  12 ++
 .../test/resources/sql_transform/func_vector.conf  | 142 +++++++++++++
 .../transform/sql/zeta/ZetaSQLFunction.java        |   8 +
 .../seatunnel/transform/sql/zeta/ZetaSQLType.java  |   5 +
 .../sql/zeta/functions/VectorFunction.java         | 156 ++++++++++++++
 .../transform/sql/SQLVectorFunctionTest.java       | 235 +++++++++++++++++++++
 8 files changed, 639 insertions(+), 2 deletions(-)

diff --git a/docs/en/transform-v2/sql-functions.md 
b/docs/en/transform-v2/sql-functions.md
index 7fcaf0344c..857932c2bd 100644
--- a/docs/en/transform-v2/sql-functions.md
+++ b/docs/en/transform-v2/sql-functions.md
@@ -1221,4 +1221,44 @@ Calculates the Euclidean (L2) distance between two 
vectors.
 
 Example:
 
-L2_DISTANCE(vector1, vector2)
\ No newline at end of file
+L2_DISTANCE(vector1, vector2)
+
+### VECTOR_REDUCE
+
+```VECTOR_REDUCE(vector_field, target_dimension, method)```
+
+Generic vector dimension reduction function that supports multiple reduction 
methods.
+
+**Parameters:**
+- `vector_field`: The vector field to reduce (VECTOR type)
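+- `target_dimension`: The target dimension (positive INTEGER). It should be smaller than the source dimension; if it is not, the vector is returned with its original dimension.
+- `method`: The reduction method (STRING, case-insensitive):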
+  - **'TRUNCATE'**: Truncates the vector by keeping only the first N elements. 
This is the simplest and fastest dimension reduction method, but may lose 
important information in the truncated dimensions.
+  - **'RANDOM_PROJECTION'**: Uses Gaussian random projection with a normally distributed random matrix. This method approximately preserves pairwise distances between vectors while reducing dimensionality, as described by the Johnson-Lindenstrauss lemma.
+  - **'SPARSE_RANDOM_PROJECTION'**: Uses sparse random projection where matrix 
elements are mostly zero (±√3, 0). This is more computationally efficient than 
regular random projection while maintaining similar distance preservation 
properties.
+
+**Returns:** VECTOR type with reduced dimensions
+
+**Example:**
+```sql
+SELECT id, VECTOR_REDUCE(embedding, 256, 'TRUNCATE') as reduced_embedding FROM 
table
+SELECT id, VECTOR_REDUCE(embedding, 128, 'RANDOM_PROJECTION') as 
reduced_embedding FROM table
+SELECT id, VECTOR_REDUCE(embedding, 64, 'SPARSE_RANDOM_PROJECTION') as 
reduced_embedding FROM table
+```
+
+### VECTOR_NORMALIZE
+
+```VECTOR_NORMALIZE(vector_field)```
+
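+Normalizes a vector to unit length (magnitude = 1). A zero vector is returned unchanged. Normalization is useful for cosine similarity: the inner product of two unit vectors equals their cosine similarity.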
+
+**Parameters:**
+- `vector_field`: The vector field to normalize (VECTOR type)
+
+**Returns:** VECTOR type - the normalized vector
+
+**Example:**
+```sql
+SELECT id, VECTOR_NORMALIZE(embedding) as normalized_embedding FROM table
+```
+
diff --git a/docs/zh/transform-v2/sql-functions.md 
b/docs/zh/transform-v2/sql-functions.md
index ad47beeb4a..188d12f665 100644
--- a/docs/zh/transform-v2/sql-functions.md
+++ b/docs/zh/transform-v2/sql-functions.md
@@ -1215,4 +1215,43 @@ L1_DISTANCE(vector1, vector2)
 
 示例:
 
-L2_DISTANCE(vector1, vector2)
\ No newline at end of file
+L2_DISTANCE(vector1, vector2)
+
+### VECTOR_REDUCE
+
+```VECTOR_REDUCE(vector_field, target_dimension, method)```
+
+通用向量降维函数,支持多种降维方法。
+
+**参数:**
+- `vector_field`: 要降维的向量字段 (VECTOR 类型)
+- `target_dimension`: 目标维度 (正整数 INTEGER)。应小于源维度;若不小于源维度,则原样返回向量。
+- `method`: 降维方法 (STRING,不区分大小写):
+  - **'TRUNCATE'**: 截断法,通过保留前N个元素来缩减向量维度。这是最简单、最快速的降维方法,但可能会丢失被截断维度中的重要信息。
+  - **'RANDOM_PROJECTION'**: 
随机投影法,使用高斯随机投影和正态分布的随机矩阵。该方法在降维的同时保持向量间的相对距离,遵循Johnson-Lindenstrauss引理。
+  - **'SPARSE_RANDOM_PROJECTION'**: 稀疏随机投影法,矩阵元素大多为零(±√3, 
0)。比常规随机投影在计算上更高效,同时保持相似的距离保持特性。
+
+**返回值:** 降维后的 VECTOR 类型
+
+**示例:**
+```sql
+SELECT id, VECTOR_REDUCE(embedding, 256, 'TRUNCATE') as reduced_embedding FROM 
table
+SELECT id, VECTOR_REDUCE(embedding, 128, 'RANDOM_PROJECTION') as 
reduced_embedding FROM table
+SELECT id, VECTOR_REDUCE(embedding, 64, 'SPARSE_RANDOM_PROJECTION') as 
reduced_embedding FROM table
+```
+
+### VECTOR_NORMALIZE
+
+```VECTOR_NORMALIZE(vector_field)```
+
+将向量归一化为单位长度(模长 = 1)。零向量将原样返回。归一化后,两个单位向量的内积即等于它们的余弦相似度,因此这对于计算余弦相似度很有用。
+
+**参数:**
+- `vector_field`: 要归一化的向量字段 (VECTOR 类型)
+
+**返回值:** VECTOR 类型 - 归一化后的向量
+
+**示例:**
+```sql
+SELECT id, VECTOR_NORMALIZE(embedding) as normalized_embedding FROM table
+```
\ No newline at end of file
diff --git 
a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-2/src/test/java/org/apache/seatunnel/e2e/transform/TestSQLIT.java
 
b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-2/src/test/java/org/apache/seatunnel/e2e/transform/TestSQLIT.java
index cb588d0aef..6d9dd11856 100644
--- 
a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-2/src/test/java/org/apache/seatunnel/e2e/transform/TestSQLIT.java
+++ 
b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-2/src/test/java/org/apache/seatunnel/e2e/transform/TestSQLIT.java
@@ -86,6 +86,18 @@ public class TestSQLIT extends TestSuiteBase {
         Assertions.assertEquals(0, multiIfSql.getExitCode());
     }
 
+    @TestTemplate
+    @DisabledOnContainer(
+            value = {},
+            type = {EngineType.SPARK},
+            disabledReason = "Vector functions are not supported in Spark 
engine")
+    public void testVectorFunctions(TestContainer container)
+            throws IOException, InterruptedException {
+        Container.ExecResult vectorFunctionResult =
+                container.executeJob("/sql_transform/func_vector.conf");
+        Assertions.assertEquals(0, vectorFunctionResult.getExitCode());
+    }
+
     @TestTemplate
     public void testSQLTransformMultiTable(TestContainer container)
             throws IOException, InterruptedException {
diff --git 
a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-2/src/test/resources/sql_transform/func_vector.conf
 
b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-2/src/test/resources/sql_transform/func_vector.conf
new file mode 100644
index 0000000000..daaaa3a8d5
--- /dev/null
+++ 
b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-2/src/test/resources/sql_transform/func_vector.conf
@@ -0,0 +1,142 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+######
+###### This config file is a demonstration of vector functions in SQL transform
+######
+
+env {
+  parallelism = 1
+  job.mode = "BATCH"
+  checkpoint.interval = 10000
+}
+
+source {
+  FakeSource {
+    plugin_output = "fake"
+    schema = {
+      fields {
+        id = "int"
+        name = "string"
+        vector_field = "array<float>"
+        vector_field2 = "array<float>"
+      }
+    }
+    rows = [
+      {
+        fields = [1, "test1", [1.0, 2.0, 3.0, 4.0, 5.0], [1.0, 2.0, 3.0, 4.0, 
5.0]]
+        kind = INSERT
+      },
+      {
+        fields = [2, "test2", [2.0, 4.0, 6.0, 8.0, 10.0], [0.6, 0.8, 0.0, 0.0, 
0.0]]
+        kind = INSERT
+      },
+      {
+        fields = [3, "test3", [3.0, 4.0, 0.0, 0.0, 0.0], [3.0, 4.0, 0.0, 0.0, 
0.0]]
+        kind = INSERT
+      }
+    ]
+  }
+}
+
+transform {
+  Sql {
+    plugin_input = "fake"
+    plugin_output = "fake1"
+    query = """SELECT
+      id,
+      name,
+      VECTOR_DIMS(vector_field) as original_dim,
+      VECTOR_DIMS(VECTOR_REDUCE(vector_field, 3, 'TRUNCATE')) as truncated_dim,
+      VECTOR_DIMS(VECTOR_REDUCE(vector_field, 3, 'RANDOM_PROJECTION')) as 
projected_dim,
+      VECTOR_DIMS(VECTOR_REDUCE(vector_field, 3, 'SPARSE_RANDOM_PROJECTION')) 
as sparse_projected_dim,
+      VECTOR_DIMS(VECTOR_NORMALIZE(vector_field)) as normalized_dim
+    FROM dual"""
+  }
+}
+
+sink {
+  Assert {
+    plugin_input = "fake1"
+    rules = {
+      field_rules = [
+        {
+          field_name = "id"
+          field_type = "int"
+          field_value = [
+            {
+              rule_type = NOT_NULL
+            }
+          ]
+        },
+        {
+          field_name = "name"
+          field_type = "string"
+          field_value = [
+            {
+              rule_type = NOT_NULL
+            }
+          ]
+        },
+        {
+          field_name = "original_dim"
+          field_type = "int"
+          field_value = [
+            {equals_to = 5}
+          ]
+        },
+        {
+          field_name = "truncated_dim"
+          field_type = "int"
+          field_value = [
+            {equals_to = 3}
+          ]
+        },
+        {
+          field_name = "projected_dim"
+          field_type = "int"
+          field_value = [
+            {equals_to = 3}
+          ]
+        },
+        {
+          field_name = "sparse_projected_dim"
+          field_type = "int"
+          field_value = [
+            {equals_to = 3}
+          ]
+        },
+        {
+          field_name = "normalized_dim"
+          field_type = "int"
+          field_value = [
+            {equals_to = 5}
+          ]
+        }
+      ]
+      row_rules = [
+        {
+          rule_type = MAX_ROW
+          rule_value = 3
+        },
+        {
+          rule_type = MIN_ROW
+          rule_value = 3
+        }
+      ]
+    }
+  }
+}
\ No newline at end of file
diff --git 
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/ZetaSQLFunction.java
 
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/ZetaSQLFunction.java
index fef526d0d7..86eaa9ccad 100644
--- 
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/ZetaSQLFunction.java
+++ 
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/ZetaSQLFunction.java
@@ -212,6 +212,9 @@ public class ZetaSQLFunction {
     public static final String VECTOR_NORM = "VECTOR_NORM";
     public static final String INNER_PRODUCT = "INNER_PRODUCT";
 
+    public static final String VECTOR_REDUCE = "VECTOR_REDUCE";
+    public static final String VECTOR_NORMALIZE = "VECTOR_NORMALIZE";
+
     private final SeaTunnelRowType inputRowType;
 
     private final ZetaSQLType zetaSQLType;
@@ -619,6 +622,11 @@ public class ZetaSQLFunction {
                 return VectorFunction.vectorNorm(args);
             case INNER_PRODUCT:
                 return VectorFunction.innerProduct(args);
+            case VECTOR_REDUCE:
+                return VectorFunction.vectorReduce(
+                        args.get(0), (Integer) args.get(1), (String) 
args.get(2));
+            case VECTOR_NORMALIZE:
+                return VectorFunction.vectorNormalize(args.get(0));
             default:
                 for (ZetaUDF udf : udfList) {
                     if (udf.functionName().equalsIgnoreCase(functionName)) {
diff --git 
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/ZetaSQLType.java
 
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/ZetaSQLType.java
index 067fab2481..83c9550bed 100644
--- 
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/ZetaSQLType.java
+++ 
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/ZetaSQLType.java
@@ -25,6 +25,7 @@ import org.apache.seatunnel.api.table.type.MapType;
 import org.apache.seatunnel.api.table.type.SeaTunnelDataType;
 import org.apache.seatunnel.api.table.type.SeaTunnelRowType;
 import org.apache.seatunnel.api.table.type.SqlType;
+import org.apache.seatunnel.api.table.type.VectorType;
 import org.apache.seatunnel.common.exception.CommonErrorCodeDeprecated;
 import org.apache.seatunnel.transform.exception.TransformException;
 import org.apache.seatunnel.transform.sql.zeta.functions.ArrayFunction;
@@ -489,6 +490,10 @@ public class ZetaSQLType {
             case ZetaSQLFunction.MOD:
                 // Result has the same type as second argument
                 return 
getExpressionType(function.getParameters().getExpressions().get(1));
+                // Vector functions
+            case ZetaSQLFunction.VECTOR_REDUCE:
+            case ZetaSQLFunction.VECTOR_NORMALIZE:
+                return VectorType.VECTOR_FLOAT_TYPE;
             default:
                 for (ZetaUDF udf : udfList) {
                     if 
(udf.functionName().equalsIgnoreCase(function.getName())) {
diff --git 
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/functions/VectorFunction.java
 
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/functions/VectorFunction.java
index 7b37acdfbd..e10688702d 100644
--- 
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/functions/VectorFunction.java
+++ 
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/functions/VectorFunction.java
@@ -25,9 +25,11 @@ import java.nio.ByteBuffer;
 import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
+import java.util.Random;
 import java.util.stream.IntStream;
 
 public class VectorFunction {
+    private static final Random random = new Random(42);
 
     public static Object cosineDistance(List<Object> args) {
         if (args.size() != 2) {
@@ -199,4 +201,158 @@ public class VectorFunction {
                     String.format("Unsupported vector type: %s", 
obj.getClass().getName()));
         }
     }
+
+    /** Truncate vector to target dimension Usage: VECTOR_REDUCE(embedding, 
256, 'TRUNCATE') */
+    public static Object vectorTruncate(Object vectorData, Integer 
targetDimension) {
+        if (vectorData == null || targetDimension == null) {
+            return null;
+        }
+
+        Float[] sourceVector = convertToFloatArray(vectorData);
+        if (sourceVector.length <= targetDimension) {
+            return vectorData; // No need to truncate
+        }
+
+        Float[] result = new Float[targetDimension];
+        System.arraycopy(sourceVector, 0, result, 0, targetDimension);
+        return VectorUtils.toByteBuffer(result);
+    }
+
+    /**
+     * Random projection for dimension reduction Usage: 
VECTOR_REDUCE(embedding, 128,
+     * 'RANDOM_PROJECTION')
+     */
+    public static Object vectorRandomProjection(Object vectorData, Integer 
targetDimension) {
+        if (vectorData == null || targetDimension == null) {
+            return null;
+        }
+
+        Float[] sourceVector = convertToFloatArray(vectorData);
+        if (sourceVector.length <= targetDimension) {
+            return vectorData; // No need to reduce
+        }
+
+        float[][] projectionMatrix =
+                createGaussianProjectionMatrix(sourceVector.length, 
targetDimension);
+        Float[] result = applyProjection(sourceVector, projectionMatrix, 
targetDimension);
+        return VectorUtils.toByteBuffer(result);
+    }
+
+    /**
+     * Sparse random projection for dimension reduction Usage: 
VECTOR_REDUCE(embedding, 64,
+     * 'SPARSE_RANDOM_PROJECTION')
+     */
+    public static Object vectorSparseProjection(Object vectorData, Integer 
targetDimension) {
+        if (vectorData == null || targetDimension == null) {
+            return null;
+        }
+
+        Float[] sourceVector = convertToFloatArray(vectorData);
+        if (sourceVector.length <= targetDimension) {
+            return vectorData; // No need to reduce
+        }
+
+        float[][] projectionMatrix =
+                createSparseProjectionMatrix(sourceVector.length, 
targetDimension);
+        Float[] result = applyProjection(sourceVector, projectionMatrix, 
targetDimension);
+        return VectorUtils.toByteBuffer(result);
+    }
+
+    /**
+     * Generic vector dimension reduction function Usage: 
VECTOR_REDUCE(vector_field,
+     * target_dimension, method) method: 'TRUNCATE', 'RANDOM_PROJECTION', 
'SPARSE_RANDOM_PROJECTION'
+     */
+    public static Object vectorReduce(Object vectorData, Integer 
targetDimension, String method) {
+        if (vectorData == null || targetDimension == null || method == null) {
+            return null;
+        }
+
+        switch (method.toUpperCase()) {
+            case "TRUNCATE":
+                return vectorTruncate(vectorData, targetDimension);
+            case "RANDOM_PROJECTION":
+                return vectorRandomProjection(vectorData, targetDimension);
+            case "SPARSE_RANDOM_PROJECTION":
+                return vectorSparseProjection(vectorData, targetDimension);
+            default:
+                throw new IllegalArgumentException("Unknown reduction method: 
" + method);
+        }
+    }
+
+    /** Normalize vector to unit length Usage: VECTOR_NORMALIZE(vector_field) 
*/
+    public static Object vectorNormalize(Object vectorData) {
+        if (vectorData == null) {
+            return null;
+        }
+
+        Float[] vector = convertToFloatArray(vectorData);
+        double magnitude = 0.0;
+        for (Float value : vector) {
+            if (value != null) {
+                magnitude += value * value;
+            }
+        }
+        magnitude = Math.sqrt(magnitude);
+
+        if (magnitude == 0.0) {
+            return vectorData; // Return original if zero vector
+        }
+
+        Float[] normalized = new Float[vector.length];
+        for (int i = 0; i < vector.length; i++) {
+            normalized[i] = vector[i] == null ? null : (float) (vector[i] / 
magnitude);
+        }
+
+        return VectorUtils.toByteBuffer(normalized);
+    }
+
+    private static Float[] applyProjection(
+            Float[] sourceVector, float[][] projectionMatrix, int 
targetDimension) {
+        Float[] result = new Float[targetDimension];
+        for (int i = 0; i < targetDimension; i++) {
+            float sum = 0.0f;
+            for (int j = 0; j < sourceVector.length; j++) {
+                if (projectionMatrix[i][j] != 0 && sourceVector[j] != null) {
+                    sum += sourceVector[j] * projectionMatrix[i][j];
+                }
+            }
+            result[i] = sum;
+        }
+        return result;
+    }
+
+    private static float[][] createGaussianProjectionMatrix(
+            int sourceDimension, int targetDimension) {
+        float[][] matrix = new float[targetDimension][sourceDimension];
+        float scale = (float) Math.sqrt(1.0 / targetDimension);
+
+        for (int i = 0; i < targetDimension; i++) {
+            for (int j = 0; j < sourceDimension; j++) {
+                matrix[i][j] = (float) random.nextGaussian() * scale;
+            }
+        }
+        return matrix;
+    }
+
+    private static float[][] createSparseProjectionMatrix(
+            int sourceDimension, int targetDimension) {
+        float[][] matrix = new float[targetDimension][sourceDimension];
+        float scale = (float) Math.sqrt(3.0);
+        double p1 = 1.0 / 6.0;
+        double p2 = 2.0 / 6.0;
+
+        for (int i = 0; i < targetDimension; i++) {
+            for (int j = 0; j < sourceDimension; j++) {
+                double rand = random.nextDouble();
+                if (rand < p1) {
+                    matrix[i][j] = scale;
+                } else if (rand < p2) {
+                    matrix[i][j] = -scale;
+                } else {
+                    matrix[i][j] = 0;
+                }
+            }
+        }
+        return matrix;
+    }
 }
diff --git 
a/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/sql/SQLVectorFunctionTest.java
 
b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/sql/SQLVectorFunctionTest.java
new file mode 100644
index 0000000000..e006331ad9
--- /dev/null
+++ 
b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/sql/SQLVectorFunctionTest.java
@@ -0,0 +1,235 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.seatunnel.transform.sql;
+
+import org.apache.seatunnel.api.configuration.ReadonlyConfig;
+import org.apache.seatunnel.api.table.catalog.CatalogTable;
+import org.apache.seatunnel.api.table.catalog.PhysicalColumn;
+import org.apache.seatunnel.api.table.catalog.TableIdentifier;
+import org.apache.seatunnel.api.table.catalog.TableSchema;
+import org.apache.seatunnel.api.table.type.BasicType;
+import org.apache.seatunnel.api.table.type.SeaTunnelDataType;
+import org.apache.seatunnel.api.table.type.SeaTunnelRow;
+import org.apache.seatunnel.api.table.type.SeaTunnelRowType;
+import org.apache.seatunnel.api.table.type.VectorType;
+import org.apache.seatunnel.common.utils.VectorUtils;
+
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+
+public class SQLVectorFunctionTest {
+
+    private static final String TEST_NAME = "vector_test";
+    private static final String[] FIELD_NAMES =
+            new String[] {"id", "vector_field", "vector_field2"};
+    private CatalogTable catalogTable;
+
+    @BeforeEach
+    void setUp() {
+        SeaTunnelRowType rowType =
+                new SeaTunnelRowType(
+                        FIELD_NAMES,
+                        new SeaTunnelDataType[] {
+                            BasicType.INT_TYPE,
+                            VectorType.VECTOR_FLOAT_TYPE,
+                            VectorType.VECTOR_FLOAT_TYPE
+                        });
+
+        TableSchema.Builder schemaBuilder = TableSchema.builder();
+        for (int i = 0; i < rowType.getTotalFields(); i++) {
+            PhysicalColumn column =
+                    PhysicalColumn.of(
+                            rowType.getFieldName(i), rowType.getFieldType(i), 
0, true, null, null);
+            schemaBuilder.column(column);
+        }
+
+        catalogTable =
+                CatalogTable.of(
+                        TableIdentifier.of(TEST_NAME, TEST_NAME, null, 
TEST_NAME),
+                        schemaBuilder.build(),
+                        new HashMap<>(),
+                        new ArrayList<>(),
+                        "Vector function test table");
+    }
+
+    @Test
+    public void testVectorTruncate() {
+        ReadonlyConfig config =
+                ReadonlyConfig.fromMap(
+                        Collections.singletonMap(
+                                "query",
+                                "SELECT id, VECTOR_REDUCE(vector_field, 
3,'TRUNCATE') as truncated_vector FROM dual"));
+
+        SQLTransform sqlTransform = new SQLTransform(config, catalogTable);
+        TableSchema tableSchema = sqlTransform.transformTableSchema();
+
+        // Create test data
+        Float[] sourceVector = new Float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
+        ByteBuffer vectorBuffer = VectorUtils.toByteBuffer(sourceVector);
+
+        SeaTunnelRow inputRow = new SeaTunnelRow(new Object[] {1, 
vectorBuffer, null});
+        List<SeaTunnelRow> result = sqlTransform.transformRow(inputRow);
+
+        Assertions.assertNotNull(result);
+        Assertions.assertEquals(1, result.size());
+
+        SeaTunnelRow outputRow = result.get(0);
+        Assertions.assertEquals(1, outputRow.getField(0));
+
+        ByteBuffer resultVector = (ByteBuffer) outputRow.getField(1);
+        Float[] resultArray = VectorUtils.toFloatArray(resultVector);
+        Assertions.assertEquals(3, resultArray.length);
+        Assertions.assertEquals(1.0f, resultArray[0], 0.001f);
+        Assertions.assertEquals(2.0f, resultArray[1], 0.001f);
+        Assertions.assertEquals(3.0f, resultArray[2], 0.001f);
+    }
+
+    @Test
+    public void testVectorNormalize() {
+        ReadonlyConfig config =
+                ReadonlyConfig.fromMap(
+                        Collections.singletonMap(
+                                "query",
+                                "SELECT id, VECTOR_NORMALIZE(vector_field) as 
normalized_vector FROM dual"));
+
+        SQLTransform sqlTransform = new SQLTransform(config, catalogTable);
+
+        // Create test data: [3, 4] normalized should be [0.6, 0.8]
+        Float[] sourceVector = new Float[] {3.0f, 4.0f};
+        ByteBuffer vectorBuffer = VectorUtils.toByteBuffer(sourceVector);
+
+        SeaTunnelRow inputRow = new SeaTunnelRow(new Object[] {1, 
vectorBuffer, null});
+        List<SeaTunnelRow> result = sqlTransform.transformRow(inputRow);
+
+        Assertions.assertNotNull(result);
+        Assertions.assertEquals(1, result.size());
+
+        SeaTunnelRow outputRow = result.get(0);
+        Assertions.assertEquals(1, outputRow.getField(0));
+
+        ByteBuffer resultVector = (ByteBuffer) outputRow.getField(1);
+        Float[] resultArray = VectorUtils.toFloatArray(resultVector);
+        Assertions.assertEquals(2, resultArray.length);
+        Assertions.assertEquals(0.6f, resultArray[0], 0.001f);
+        Assertions.assertEquals(0.8f, resultArray[1], 0.001f);
+    }
+
+    @Test
+    public void testVectorReduce() {
+        ReadonlyConfig config =
+                ReadonlyConfig.fromMap(
+                        Collections.singletonMap(
+                                "query",
+                                "SELECT id, VECTOR_REDUCE(vector_field, 3, 
'TRUNCATE') as reduced_vector FROM dual"));
+
+        SQLTransform sqlTransform = new SQLTransform(config, catalogTable);
+
+        // Create test data
+        Float[] sourceVector = new Float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
+        ByteBuffer vectorBuffer = VectorUtils.toByteBuffer(sourceVector);
+
+        SeaTunnelRow inputRow = new SeaTunnelRow(new Object[] {1, 
vectorBuffer, null});
+        List<SeaTunnelRow> result = sqlTransform.transformRow(inputRow);
+
+        Assertions.assertNotNull(result);
+        Assertions.assertEquals(1, result.size());
+
+        SeaTunnelRow outputRow = result.get(0);
+        Assertions.assertEquals(1, outputRow.getField(0));
+
+        ByteBuffer resultVector = (ByteBuffer) outputRow.getField(1);
+        Float[] resultArray = VectorUtils.toFloatArray(resultVector);
+        Assertions.assertEquals(3, resultArray.length);
+        Assertions.assertEquals(1.0f, resultArray[0], 0.001f);
+        Assertions.assertEquals(2.0f, resultArray[1], 0.001f);
+        Assertions.assertEquals(3.0f, resultArray[2], 0.001f);
+    }
+
+    @Test
+    public void testVectorRandomProjection() {
+        ReadonlyConfig config =
+                ReadonlyConfig.fromMap(
+                        Collections.singletonMap(
+                                "query",
+                                "SELECT id, VECTOR_REDUCE(vector_field, 
3,'RANDOM_PROJECTION') as projected_vector FROM dual"));
+
+        SQLTransform sqlTransform = new SQLTransform(config, catalogTable);
+
+        // Create test data
+        Float[] sourceVector = new Float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
+        ByteBuffer vectorBuffer = VectorUtils.toByteBuffer(sourceVector);
+
+        SeaTunnelRow inputRow = new SeaTunnelRow(new Object[] {1, 
vectorBuffer, null});
+        List<SeaTunnelRow> result = sqlTransform.transformRow(inputRow);
+
+        Assertions.assertNotNull(result);
+        Assertions.assertEquals(1, result.size());
+
+        SeaTunnelRow outputRow = result.get(0);
+        Assertions.assertEquals(1, outputRow.getField(0));
+
+        ByteBuffer resultVector = (ByteBuffer) outputRow.getField(1);
+        Float[] resultArray = VectorUtils.toFloatArray(resultVector);
+        Assertions.assertEquals(3, resultArray.length);
+
+        // Just verify that we got a result with the expected dimension
+        for (Float value : resultArray) {
+            Assertions.assertNotNull(value);
+        }
+    }
+
+    @Test
+    public void testVectorSparseProjection() {
+        ReadonlyConfig config =
+                ReadonlyConfig.fromMap(
+                        Collections.singletonMap(
+                                "query",
+                                "SELECT id, VECTOR_REDUCE(vector_field, 
3,'SPARSE_RANDOM_PROJECTION') as sparse_projected_vector FROM dual"));
+
+        SQLTransform sqlTransform = new SQLTransform(config, catalogTable);
+
+        // Create test data
+        Float[] sourceVector = new Float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
+        ByteBuffer vectorBuffer = VectorUtils.toByteBuffer(sourceVector);
+
+        SeaTunnelRow inputRow = new SeaTunnelRow(new Object[] {1, 
vectorBuffer, null});
+        List<SeaTunnelRow> result = sqlTransform.transformRow(inputRow);
+
+        Assertions.assertNotNull(result);
+        Assertions.assertEquals(1, result.size());
+
+        SeaTunnelRow outputRow = result.get(0);
+        Assertions.assertEquals(1, outputRow.getField(0));
+
+        ByteBuffer resultVector = (ByteBuffer) outputRow.getField(1);
+        Float[] resultArray = VectorUtils.toFloatArray(resultVector);
+        Assertions.assertEquals(3, resultArray.length);
+
+        // Just verify that we got a result with the expected dimension
+        for (Float value : resultArray) {
+            Assertions.assertNotNull(value);
+        }
+    }
+}

Reply via email to