This is an automated email from the ASF dual-hosted git repository.

wanghailin pushed a commit to branch dev
in repository https://gitbox.apache.org/repos/asf/seatunnel.git


The following commit(s) were added to refs/heads/dev by this push:
     new 51ffc5a97e [Feature][Transform-v2] Add support for Zhipu AI in 
Embedding and LLM module (#8790)
51ffc5a97e is described below

commit 51ffc5a97e4f7a75d1df47ecba9039a52d59abae
Author: xiaochen <598457...@qq.com>
AuthorDate: Tue Feb 25 10:47:29 2025 +0800

    [Feature][Transform-v2] Add support for Zhipu AI in Embedding and LLM 
module (#8790)
---
 docs/en/transform-v2/embedding.md                  | 29 ++++++-----
 docs/en/transform-v2/llm.md                        |  2 +-
 docs/zh/transform-v2/embedding.md                  | 29 ++++++-----
 docs/zh/transform-v2/llm.md                        |  2 +-
 seatunnel-transforms-v2/pom.xml                    |  4 +-
 .../transform/nlpmodel/ModelProvider.java          |  3 ++
 .../transform/nlpmodel/ModelTransformConfig.java   |  3 ++
 .../nlpmodel/embedding/EmbeddingTransform.java     | 13 +++++
 .../embedding/EmbeddingTransformFactory.java       |  4 ++
 .../nlpmodel/embedding/remote/AbstractModel.java   | 14 ++---
 .../embedding/remote/custom/CustomModel.java       |  6 +--
 .../embedding/remote/doubao/DoubaoModel.java       | 10 ++--
 .../embedding/remote/openai/OpenAIModel.java       | 10 ++--
 .../embedding/remote/qianfan/QianfanModel.java     | 10 ++--
 .../DoubaoModel.java => zhipu/ZhipuModel.java}     | 60 ++++++++++++++--------
 .../transform/nlpmodel/llm/LLMTransform.java       |  1 +
 .../embedding/EmbeddingRequestJsonTest.java        | 25 ++++++++-
 17 files changed, 147 insertions(+), 78 deletions(-)

diff --git a/docs/en/transform-v2/embedding.md 
b/docs/en/transform-v2/embedding.md
index 350a23fc55..cbebd535b1 100644
--- a/docs/en/transform-v2/embedding.md
+++ b/docs/en/transform-v2/embedding.md
@@ -10,20 +10,21 @@ different API endpoints.
 
 ## Options
 
-| Name                           | Type   | Required | Default Value | 
Description                                                                     
                            |
-|--------------------------------|--------|----------|---------------|-------------------------------------------------------------------------------------------------------------|
-| model_provider                 | enum   | yes      | -             | The 
model provider for embedding. Options may include `QIANFAN`, `OPENAI`, etc.     
                        |
-| api_key                        | string | yes      | -             | The API 
key required to authenticate with the embedding service.                        
                    |
-| secret_key                     | string | yes      | -             | The 
secret key required for additional authentication with the embedding service.   
                        |
-| single_vectorized_input_number | int    | no       | 1             | The 
number of inputs vectorized in one request. Default is 1.                       
                        |
-| vectorization_fields           | map    | yes      | -             | A 
mapping between input fields and their corresponding output vector fields.      
                          |
-| model                          | string | yes      | -             | The 
specific model to use for embedding (e.g: `text-embedding-3-small` for OPENAI). 
                        |
-| api_path                       | string | no       | -             | The API 
endpoint for the embedding service. Typically provided by the model provider.   
                    |
-| oauth_path                     | string | no       | -             | The API 
endpoint for the oauth service.                                                 
                    |
-| custom_config                  | map    | no       |               | Custom 
configurations for the model.                                                   
                     |
-| custom_response_parse          | string | no       |               | 
Specifies how to parse the response from the model using JsonPath. Example: 
`$.choices[*].message.content`. |
-| custom_request_headers         | map    | no       |               | Custom 
headers for the request to the model.                                           
                     |
-| custom_request_body            | map    | no       |               | Custom 
body for the request. Supports placeholders like `${model}`, `${input}`.        
                     |
+| Name                             | Type   | Required | Default Value | 
Description                                                                     
                                                                                
        |
+|----------------------------------|--------|----------|---------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| model_provider                   | enum   | yes      | -             | The 
model provider for embedding. Options may include `QIANFAN`, `OPENAI`, etc.     
                                                                                
    |
+| api_key                          | string | yes      | -             | The 
API key required to authenticate with the embedding service.                    
                                                                                
    |
+| secret_key                       | string | yes      | -             | The 
secret key required for additional authentication with the embedding service.   
                                                                                
    |
+| single_vectorized_input_number   | int    | no       | 1             | The 
number of inputs vectorized in one request. Default is 1.                       
                                                                                
    |
+| vectorization_fields             | map    | yes      | -             | A 
mapping between input fields and their corresponding output vector fields.      
                                                                                
      |
+| model                            | string | yes      | -             | The 
specific model to use for embedding (e.g: `text-embedding-3-small` for OPENAI). 
                                                                                
    |
+| api_path                         | string | no       | -             | The 
API endpoint for the embedding service. Typically provided by the model 
provider.                                                                       
            |
+| dimension                        | int    | no       | 2048          | The 
vector dimension defaults to 2048. The Embedding-3 model supports custom vector 
dimensions, and it is recommended to choose dimensions of 256, 512, 1024, or 
2048. |
+| oauth_path                       | string | no       | -             | The 
API endpoint for the oauth service.                                             
                                                                                
    |
+| custom_config                    | map    | no       |               | 
Custom configurations for the model.                                            
                                                                                
        |
+| custom_response_parse            | string | no       |               | 
Specifies how to parse the response from the model using JsonPath. Example: 
`$.choices[*].message.content`.                                                 
            |
+| custom_request_headers           | map    | no       |               | 
Custom headers for the request to the model.                                    
                                                                                
        |
+| custom_request_body              | map    | no       |               | 
Custom body for the request. Supports placeholders like `${model}`, `${input}`. 
                                                                                
        |
 
 ### model_provider
 
diff --git a/docs/en/transform-v2/llm.md b/docs/en/transform-v2/llm.md
index 680121cb4d..0bc137bded 100644
--- a/docs/en/transform-v2/llm.md
+++ b/docs/en/transform-v2/llm.md
@@ -28,7 +28,7 @@ more.
 ### model_provider
 
 The model provider to use. The available options are:
-OPENAI, DOUBAO, DEEPSEEK, KIMIAI, MICROSOFT, CUSTOM
+OPENAI, DOUBAO, DEEPSEEK, KIMIAI, MICROSOFT, ZHIPU, CUSTOM
 
 > tips: If you use Microsoft, please make sure api_path cannot be empty
 
diff --git a/docs/zh/transform-v2/embedding.md 
b/docs/zh/transform-v2/embedding.md
index e05c9c2442..8ea2b68f70 100644
--- a/docs/zh/transform-v2/embedding.md
+++ b/docs/zh/transform-v2/embedding.md
@@ -8,20 +8,21 @@
 
 ## 配置选项
 
-| 名称                             | 类型     | 是否必填 | 默认值 | 描述                    
                                           |
-|--------------------------------|--------|------|-----|------------------------------------------------------------------|
-| model_provider                 | enum   | 是    | -   | embedding模型的提供商。可选项包括 
`QIANFAN`、`OPENAI` 等。                      |
-| api_key                        | string | 是    | -   | 
用于验证embedding服务的API密钥。                                           |
-| secret_key                     | string | 是    | -   | 
用于额外验证的密钥。一些提供商可能需要此密钥进行安全的API请求。                                |
-| single_vectorized_input_number | int    | 否    | 1   | 单次请求向量化的输入数量。默认值为1。   
                                           |
-| vectorization_fields           | map    | 是    | -   | 输入字段和相应的输出向量字段之间的映射。  
                                           |
-| model                          | string | 是    | -   | 
要使用的具体embedding模型。例如,如果提供商为OPENAI,可以指定 `text-embedding-3-small`。 |
-| api_path                       | string | 否    | -   | 
embedding服务的API。通常由模型提供商提供。                                      |
-| oauth_path                     | string | 否    | -   | oauth 服务的 API 。       
                                           |
-| custom_config                  | map    | 否    |     | 模型的自定义配置。             
                                           |
-| custom_response_parse          | string | 否    |     | 使用 JsonPath 
解析模型响应的方式。示例:`$.choices[*].message.content`。         |
-| custom_request_headers         | map    | 否    |     | 发送到模型的请求的自定义头信息。      
                                           |
-| custom_request_body            | map    | 否    |     | 请求体的自定义配置。支持占位符如 
`${model}`、`${input}`。                          |
+| 名称                               | 类型     | 是否必填 | 默认值    | 描述               
                                                  |
+|----------------------------------|--------|------|--------|--------------------------------------------------------------------|
+| model_provider                   | enum   | 是    | -      | 
embedding模型的提供商。可选项包括 `QIANFAN`、`OPENAI` 等。                        |
+| api_key                          | string | 是    | -      | 
用于验证embedding服务的API密钥。                                             |
+| secret_key                       | string | 是    | -      | 
用于额外验证的密钥。一些提供商可能需要此密钥进行安全的API请求。                                  |
+| single_vectorized_input_number   | int    | 否    | 1      | 
单次请求向量化的输入数量。默认值为1。                                                |
+| vectorization_fields             | map    | 是    | -      | 
输入字段和相应的输出向量字段之间的映射。                                               |
+| model                            | string | 是    | -      | 
要使用的具体embedding模型。例如,如果提供商为OPENAI,可以指定 `text-embedding-3-small`。   |
+| api_path                         | string | 否    | -      | 
embedding服务的API。通常由模型提供商提供。                                        |
+| dimension                        | int    | 否    | 2048   | 向量维度默认为 
2048,Embedding-3模型支持自定义向量维度,建议选择256、512、1024或2048维度。       |
+| oauth_path                       | string | 否    | -      | oauth 服务的 API 。  
                                                  |
+| custom_config                    | map    | 否    |        | 模型的自定义配置。        
                                                  |
+| custom_response_parse            | string | 否    |        | 使用 JsonPath 
解析模型响应的方式。示例:`$.choices[*].message.content`。           |
+| custom_request_headers           | map    | 否    |        | 发送到模型的请求的自定义头信息。 
                                                  |
+| custom_request_body              | map    | 否    |        | 请求体的自定义配置。支持占位符如 
`${model}`、`${input}`。                            |
 
 ### embedding_model_provider
 
diff --git a/docs/zh/transform-v2/llm.md b/docs/zh/transform-v2/llm.md
index c1d05d59a3..c6cead3dfd 100644
--- a/docs/zh/transform-v2/llm.md
+++ b/docs/zh/transform-v2/llm.md
@@ -26,7 +26,7 @@
 ### model_provider
 
 要使用的模型提供者。可用选项为:
-OPENAI,DOUBAO,DEEPSEEK,KIMIAI,MICROSOFT, CUSTOM
+OPENAI,DOUBAO,DEEPSEEK,KIMIAI,MICROSOFT, ZHIPU, CUSTOM
 
 > tips: 如果使用 Microsoft, 请确保 api_path 配置不能为空
 
diff --git a/seatunnel-transforms-v2/pom.xml b/seatunnel-transforms-v2/pom.xml
index f15c1aae40..5f74ad156b 100644
--- a/seatunnel-transforms-v2/pom.xml
+++ b/seatunnel-transforms-v2/pom.xml
@@ -32,6 +32,8 @@
     <properties>
         <httpclient.version>4.5.13</httpclient.version>
         <httpcore.version>4.4.4</httpcore.version>
+        <mockwebserver.version>3.6.0</mockwebserver.version>
+        <zhipu.version>release-V4-2.3.0</zhipu.version>
     </properties>
 
     <dependencyManagement>
@@ -95,7 +97,7 @@
         <dependency>
             <groupId>com.squareup.okhttp3</groupId>
             <artifactId>mockwebserver</artifactId>
-            <version>3.6.0</version>
+            <version>${mockwebserver.version}</version>
             <scope>test</scope>
         </dependency>
     </dependencies>
diff --git 
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/ModelProvider.java
 
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/ModelProvider.java
index f18ffdfc8e..aaeaee90ad 100644
--- 
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/ModelProvider.java
+++ 
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/ModelProvider.java
@@ -28,6 +28,9 @@ public enum ModelProvider {
     KIMIAI("https://api.moonshot.cn/v1/chat/completions", ""),
     DEEPSEEK("https://api.deepseek.com/chat/completions", ""),
     MICROSOFT("", ""),
+    ZHIPU(
+            "https://open.bigmodel.cn/api/paas/v4/chat/completions",
+            "https://open.bigmodel.cn/api/paas/v4/embeddings"),
     CUSTOM("", ""),
     LOCAL("", "");
 
diff --git 
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/ModelTransformConfig.java
 
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/ModelTransformConfig.java
index b123459750..c3709c70b6 100644
--- 
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/ModelTransformConfig.java
+++ 
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/ModelTransformConfig.java
@@ -79,6 +79,9 @@ public class ModelTransformConfig implements Serializable {
                     .withFallbackKeys("inference_batch_size")
                     .withDescription("The row batch size of each process");
 
+    public static final Option<Integer> DIMENSION =
+            
Options.key("dimension").intType().defaultValue(2048).withDescription("dimension");
+
     public static class CustomRequestConfig {
 
         // Custom response parsing
diff --git 
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/EmbeddingTransform.java
 
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/EmbeddingTransform.java
index c699c6bfe8..6a8729a198 100644
--- 
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/EmbeddingTransform.java
+++ 
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/EmbeddingTransform.java
@@ -33,6 +33,7 @@ import 
org.apache.seatunnel.transform.nlpmodel.embedding.remote.custom.CustomMod
 import 
org.apache.seatunnel.transform.nlpmodel.embedding.remote.doubao.DoubaoModel;
 import 
org.apache.seatunnel.transform.nlpmodel.embedding.remote.openai.OpenAIModel;
 import 
org.apache.seatunnel.transform.nlpmodel.embedding.remote.qianfan.QianfanModel;
+import 
org.apache.seatunnel.transform.nlpmodel.embedding.remote.zhipu.ZhipuModel;
 import org.apache.seatunnel.transform.nlpmodel.llm.LLMTransformConfig;
 
 import lombok.NonNull;
@@ -136,6 +137,18 @@ public class EmbeddingTransform extends 
MultipleFieldOutputTransform {
                                             EmbeddingTransformConfig
                                                     
.SINGLE_VECTORIZED_INPUT_NUMBER));
                     break;
+                case ZHIPU:
+                    model =
+                            new ZhipuModel(
+                                    config.get(ModelTransformConfig.API_KEY),
+                                    config.get(ModelTransformConfig.MODEL),
+                                    provider.usedEmbeddingPath(
+                                            
config.get(ModelTransformConfig.API_PATH)),
+                                    config.get(ModelTransformConfig.DIMENSION),
+                                    config.get(
+                                            EmbeddingTransformConfig
+                                                    
.SINGLE_VECTORIZED_INPUT_NUMBER));
+                    break;
                 case LOCAL:
                 default:
                     throw new IllegalArgumentException("Unsupported model 
provider: " + provider);
diff --git 
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/EmbeddingTransformFactory.java
 
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/EmbeddingTransformFactory.java
index 56e252e1bb..5f8e397e69 100644
--- 
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/EmbeddingTransformFactory.java
+++ 
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/EmbeddingTransformFactory.java
@@ -62,6 +62,10 @@ public class EmbeddingTransformFactory implements 
TableTransformFactory {
                         LLMTransformConfig.MODEL_PROVIDER,
                         ModelProvider.CUSTOM,
                         LLMTransformConfig.CustomRequestConfig.CUSTOM_CONFIG)
+                .conditional(
+                        EmbeddingTransformConfig.MODEL_PROVIDER,
+                        ModelProvider.ZHIPU,
+                        EmbeddingTransformConfig.DIMENSION)
                 .optional(TransformCommonOptions.MULTI_TABLES)
                 .optional(TransformCommonOptions.TABLE_MATCH_REGEX)
                 .build();
diff --git 
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/AbstractModel.java
 
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/AbstractModel.java
index 0803dfd7ad..a53fa4684c 100644
--- 
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/AbstractModel.java
+++ 
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/AbstractModel.java
@@ -42,23 +42,23 @@ public abstract class AbstractModel implements Model {
     public List<ByteBuffer> vectorization(Object[] fields) throws IOException {
         List<ByteBuffer> result = new ArrayList<>();
 
-        List<List<Float>> vectors = batchProcess(fields, 
singleVectorizedInputNumber);
-        for (List<Float> vector : vectors) {
-            result.add(BufferUtils.toByteBuffer(vector.toArray(new Float[0])));
+        List<List<Double>> vectors = batchProcess(fields, 
singleVectorizedInputNumber);
+        for (List<Double> vector : vectors) {
+            result.add(BufferUtils.toByteBuffer(vector.toArray(new 
Double[0])));
         }
         return result;
     }
 
-    protected abstract List<List<Float>> vector(Object[] fields) throws 
IOException;
+    protected abstract List<List<Double>> vector(Object[] fields) throws 
IOException;
 
-    public List<List<Float>> batchProcess(Object[] array, int batchSize) 
throws IOException {
-        List<List<Float>> merged = new ArrayList<>();
+    public List<List<Double>> batchProcess(Object[] array, int batchSize) 
throws IOException {
+        List<List<Double>> merged = new ArrayList<>();
         if (array == null || array.length == 0) {
             return merged;
         }
         for (int i = 0; i < array.length; i += batchSize) {
             Object[] batch = ArrayUtils.subarray(array, i, i + batchSize);
-            List<List<Float>> vector = vector(batch);
+            List<List<Double>> vector = vector(batch);
             merged.addAll(vector);
         }
         if (array.length != merged.size()) {
diff --git 
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/custom/CustomModel.java
 
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/custom/CustomModel.java
index 179315f956..ea39f15462 100644
--- 
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/custom/CustomModel.java
+++ 
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/custom/CustomModel.java
@@ -67,7 +67,7 @@ public class CustomModel extends AbstractModel {
     }
 
     @Override
-    protected List<List<Float>> vector(Object[] fields) throws IOException {
+    protected List<List<Double>> vector(Object[] fields) throws IOException {
         return vectorGeneration(fields);
     }
 
@@ -76,7 +76,7 @@ public class CustomModel extends AbstractModel {
         return vectorGeneration(new Object[] {DIMENSION_EXAMPLE}).size();
     }
 
-    private List<List<Float>> vectorGeneration(Object[] fields) throws 
IOException {
+    private List<List<Double>> vectorGeneration(Object[] fields) throws 
IOException {
         HttpPost post = new HttpPost(apiPath);
         // Construct a request with custom parameters
         for (Map.Entry<String, String> entry : header.entrySet()) {
@@ -96,7 +96,7 @@ public class CustomModel extends AbstractModel {
         }
 
         return OBJECT_MAPPER.convertValue(
-                parseResponse(responseStr), new 
TypeReference<List<List<Float>>>() {});
+                parseResponse(responseStr), new 
TypeReference<List<List<Double>>>() {});
     }
 
     @VisibleForTesting
diff --git 
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/doubao/DoubaoModel.java
 
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/doubao/DoubaoModel.java
index f2b1e348c7..1591cd587e 100644
--- 
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/doubao/DoubaoModel.java
+++ 
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/doubao/DoubaoModel.java
@@ -54,7 +54,7 @@ public class DoubaoModel extends AbstractModel {
     }
 
     @Override
-    protected List<List<Float>> vector(Object[] fields) throws IOException {
+    protected List<List<Double>> vector(Object[] fields) throws IOException {
         return vectorGeneration(fields);
     }
 
@@ -63,7 +63,7 @@ public class DoubaoModel extends AbstractModel {
         return vectorGeneration(new Object[] {DIMENSION_EXAMPLE}).size();
     }
 
-    private List<List<Float>> vectorGeneration(Object[] fields) throws 
IOException {
+    private List<List<Double>> vectorGeneration(Object[] fields) throws 
IOException {
         HttpPost post = new HttpPost(apiPath);
         post.setHeader("Authorization", "Bearer " + apiKey);
         post.setHeader("Content-Type", "application/json");
@@ -82,14 +82,14 @@ public class DoubaoModel extends AbstractModel {
         }
 
         JsonNode data = OBJECT_MAPPER.readTree(responseStr).get("data");
-        List<List<Float>> embeddings = new ArrayList<>();
+        List<List<Double>> embeddings = new ArrayList<>();
 
         if (data.isArray()) {
             for (JsonNode node : data) {
                 JsonNode embeddingNode = node.get("embedding");
-                List<Float> embedding =
+                List<Double> embedding =
                         OBJECT_MAPPER.readValue(
-                                embeddingNode.traverse(), new 
TypeReference<List<Float>>() {});
+                                embeddingNode.traverse(), new 
TypeReference<List<Double>>() {});
                 embeddings.add(embedding);
             }
         }
diff --git 
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/openai/OpenAIModel.java
 
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/openai/OpenAIModel.java
index 2a45cc829f..467d6cb406 100644
--- 
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/openai/OpenAIModel.java
+++ 
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/openai/OpenAIModel.java
@@ -53,7 +53,7 @@ public class OpenAIModel extends AbstractModel {
     }
 
     @Override
-    protected List<List<Float>> vector(Object[] fields) throws IOException {
+    protected List<List<Double>> vector(Object[] fields) throws IOException {
         if (fields.length > 1) {
             throw new IllegalArgumentException("OpenAI model only supports 
single input");
         }
@@ -65,7 +65,7 @@ public class OpenAIModel extends AbstractModel {
         return vectorGeneration(new Object[] {DIMENSION_EXAMPLE}).size();
     }
 
-    private List<List<Float>> vectorGeneration(Object[] fields) throws 
IOException {
+    private List<List<Double>> vectorGeneration(Object[] fields) throws 
IOException {
         HttpPost post = new HttpPost(apiPath);
         post.setHeader("Authorization", "Bearer " + apiKey);
         post.setHeader("Content-Type", "application/json");
@@ -84,14 +84,14 @@ public class OpenAIModel extends AbstractModel {
         }
 
         JsonNode data = OBJECT_MAPPER.readTree(responseStr).get("data");
-        List<List<Float>> embeddings = new ArrayList<>();
+        List<List<Double>> embeddings = new ArrayList<>();
 
         if (data.isArray()) {
             for (JsonNode node : data) {
                 JsonNode embeddingNode = node.get("embedding");
-                List<Float> embedding =
+                List<Double> embedding =
                         OBJECT_MAPPER.readValue(
-                                embeddingNode.traverse(), new 
TypeReference<List<Float>>() {});
+                                embeddingNode.traverse(), new 
TypeReference<List<Double>>() {});
                 embeddings.add(embedding);
             }
         }
diff --git 
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/qianfan/QianfanModel.java
 
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/qianfan/QianfanModel.java
index f85619eb3e..67c1a8147a 100644
--- 
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/qianfan/QianfanModel.java
+++ 
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/qianfan/QianfanModel.java
@@ -100,7 +100,7 @@ public class QianfanModel extends AbstractModel {
     }
 
     @Override
-    public List<List<Float>> vector(Object[] fields) throws IOException {
+    public List<List<Double>> vector(Object[] fields) throws IOException {
         return vectorGeneration(fields);
     }
 
@@ -109,7 +109,7 @@ public class QianfanModel extends AbstractModel {
         return vectorGeneration(new Object[] 
{DIMENSION_EXAMPLE}).get(0).size();
     }
 
-    private List<List<Float>> vectorGeneration(Object[] fields) throws 
IOException {
+    private List<List<Double>> vectorGeneration(Object[] fields) throws 
IOException {
         String formattedApiPath =
                 String.format(
                         (apiPath.endsWith("/") ? apiPath : apiPath + "/") + 
"%s?access_token=%s",
@@ -143,14 +143,14 @@ public class QianfanModel extends AbstractModel {
                     "Failed to get vector from qianfan, response: " + 
result.get("error_msg"));
         }
 
-        List<List<Float>> embeddings = new ArrayList<>();
+        List<List<Double>> embeddings = new ArrayList<>();
         JsonNode data = result.get("data");
         if (data.isArray()) {
             for (JsonNode node : data) {
-                List<Float> embedding =
+                List<Double> embedding =
                         OBJECT_MAPPER.readValue(
                                 node.get("embedding").traverse(),
-                                new TypeReference<List<Float>>() {});
+                                new TypeReference<List<Double>>() {});
                 embeddings.add(embedding);
             }
         }
diff --git 
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/doubao/DoubaoModel.java
 
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/zhipu/ZhipuModel.java
similarity index 68%
copy from 
seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/doubao/DoubaoModel.java
copy to 
seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/zhipu/ZhipuModel.java
index f2b1e348c7..df72261bb5 100644
--- 
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/doubao/DoubaoModel.java
+++ 
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/zhipu/ZhipuModel.java
@@ -15,7 +15,7 @@
  * limitations under the License.
  */
 
-package org.apache.seatunnel.transform.nlpmodel.embedding.remote.doubao;
+package org.apache.seatunnel.transform.nlpmodel.embedding.remote.zhipu;
 
 import 
org.apache.seatunnel.shade.com.fasterxml.jackson.core.type.TypeReference;
 import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.JsonNode;
@@ -25,6 +25,7 @@ import 
org.apache.seatunnel.shade.com.google.common.annotations.VisibleForTestin
 
 import org.apache.seatunnel.transform.nlpmodel.embedding.remote.AbstractModel;
 
+import org.apache.http.HttpHeaders;
 import org.apache.http.client.config.RequestConfig;
 import org.apache.http.client.methods.CloseableHttpResponse;
 import org.apache.http.client.methods.HttpPost;
@@ -34,62 +35,77 @@ import org.apache.http.impl.client.HttpClients;
 import org.apache.http.util.EntityUtils;
 
 import java.io.IOException;
+import java.nio.charset.StandardCharsets;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
 
-public class DoubaoModel extends AbstractModel {
+/** Zhipu model. Refer <a 
href="https://bigmodel.cn/dev/api/vector/embedding">embedding api </a> */
+public class ZhipuModel extends AbstractModel {
 
     private final CloseableHttpClient client;
-    private final String apiKey;
     private final String model;
-    private final String apiPath;
-
-    public DoubaoModel(String apiKey, String model, String apiPath, Integer 
vectorizedNumber) {
+    private final String apiKey;
+    private final String apiPath;;
+    private final Integer dimension;
+    private final Integer MAX_INPUT_SIZE = 64;
+
+    public ZhipuModel(
+            String apiKey,
+            String model,
+            String apiPath,
+            Integer dimension,
+            Integer vectorizedNumber)
+            throws IOException {
         super(vectorizedNumber);
-        this.apiKey = apiKey;
         this.model = model;
+        this.apiKey = apiKey;
         this.apiPath = apiPath;
+        this.dimension = dimension;
         this.client = HttpClients.createDefault();
     }
 
     @Override
-    protected List<List<Float>> vector(Object[] fields) throws IOException {
+    public List<List<Double>> vector(Object[] fields) throws IOException {
         return vectorGeneration(fields);
     }
 
     @Override
     public Integer dimension() throws IOException {
-        return vectorGeneration(new Object[] {DIMENSION_EXAMPLE}).size();
+        return dimension;
     }
 
-    private List<List<Float>> vectorGeneration(Object[] fields) throws 
IOException {
+    private List<List<Double>> vectorGeneration(Object[] fields) throws 
IOException {
+
+        if (fields == null || fields.length > MAX_INPUT_SIZE) {
+            throw new IOException(
+                    "Zhipu input text for vectorization, with a maximum limit 
of 64 entries.");
+        }
         HttpPost post = new HttpPost(apiPath);
-        post.setHeader("Authorization", "Bearer " + apiKey);
-        post.setHeader("Content-Type", "application/json");
+        post.setHeader(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey);
+        post.setHeader(HttpHeaders.CONTENT_TYPE, "application/json");
         post.setConfig(
                 
RequestConfig.custom().setConnectTimeout(20000).setSocketTimeout(20000).build());
 
         post.setEntity(
                 new StringEntity(
-                        
OBJECT_MAPPER.writeValueAsString(createJsonNodeFromData(fields)), "UTF-8"));
+                        
OBJECT_MAPPER.writeValueAsString(createJsonNodeFromData(fields)),
+                        StandardCharsets.UTF_8.name()));
 
         CloseableHttpResponse response = client.execute(post);
         String responseStr = EntityUtils.toString(response.getEntity());
-
         if (response.getStatusLine().getStatusCode() != 200) {
-            throw new IOException("Failed to get vector from doubao, response: 
" + responseStr);
+            throw new IOException("Failed to get vector from zhipu, response: 
" + responseStr);
         }
-
         JsonNode data = OBJECT_MAPPER.readTree(responseStr).get("data");
-        List<List<Float>> embeddings = new ArrayList<>();
+        List<List<Double>> embeddings = new ArrayList<>();
 
         if (data.isArray()) {
             for (JsonNode node : data) {
                 JsonNode embeddingNode = node.get("embedding");
-                List<Float> embedding =
+                List<Double> embedding =
                         OBJECT_MAPPER.readValue(
-                                embeddingNode.traverse(), new 
TypeReference<List<Float>>() {});
+                                embeddingNode.traverse(), new 
TypeReference<List<Double>>() {});
                 embeddings.add(embedding);
             }
         }
@@ -99,7 +115,11 @@ public class DoubaoModel extends AbstractModel {
     @VisibleForTesting
     public ObjectNode createJsonNodeFromData(Object[] fields) {
         ArrayNode arrayNode = OBJECT_MAPPER.valueToTree(Arrays.asList(fields));
-        return OBJECT_MAPPER.createObjectNode().put("model", 
model).set("input", arrayNode);
+        return OBJECT_MAPPER
+                .createObjectNode()
+                .put("model", model)
+                .put("dimensions", dimension)
+                .set("input", arrayNode);
     }
 
     @Override
diff --git 
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransform.java
 
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransform.java
index 048e1bffce..8160cdc647 100644
--- 
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransform.java
+++ 
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransform.java
@@ -109,6 +109,7 @@ public class LLMTransform extends 
SingleFieldOutputTransform {
             case DEEPSEEK:
             case OPENAI:
             case DOUBAO:
+            case ZHIPU:
                 model =
                         new OpenAIModel(
                                 inputCatalogTable.getSeaTunnelRowType(),
diff --git 
a/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/EmbeddingRequestJsonTest.java
 
b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/EmbeddingRequestJsonTest.java
index dc43cdeb23..46c893d182 100644
--- 
a/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/EmbeddingRequestJsonTest.java
+++ 
b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/EmbeddingRequestJsonTest.java
@@ -26,6 +26,7 @@ import 
org.apache.seatunnel.transform.nlpmodel.embedding.remote.custom.CustomMod
 import 
org.apache.seatunnel.transform.nlpmodel.embedding.remote.doubao.DoubaoModel;
 import 
org.apache.seatunnel.transform.nlpmodel.embedding.remote.openai.OpenAIModel;
 import 
org.apache.seatunnel.transform.nlpmodel.embedding.remote.qianfan.QianfanModel;
+import 
org.apache.seatunnel.transform.nlpmodel.embedding.remote.zhipu.ZhipuModel;
 
 import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
@@ -93,6 +94,26 @@ public class EmbeddingRequestJsonTest {
         model.close();
     }
 
+    @Test
+    void testZhipuRequestJson() throws IOException {
+        ZhipuModel model =
+                new ZhipuModel(
+                        "apikey",
+                        "modelName",
+                        "https://open.bigmodel.cn/api/paas/v4/embeddings",
+                        64,
+                        1);
+        ObjectNode node =
+                model.createJsonNodeFromData(
+                        new Object[] {
+                            "Determine whether someone is Chinese or American 
by their name"
+                        });
+        Assertions.assertEquals(
+                
"{\"model\":\"modelName\",\"dimensions\":64,\"input\":[\"Determine whether 
someone is Chinese or American by their name\"]}",
+                OBJECT_MAPPER.writeValueAsString(node));
+        model.close();
+    }
+
     @Test
     void testCustomRequestJson() throws IOException {
         Map<String, String> header = new HashMap<>();
@@ -131,11 +152,11 @@ public class EmbeddingRequestJsonTest {
                         new HashMap<>(),
                         "$.data[*].embedding",
                         1);
-        List<List<Float>> lists =
+        List<List<Double>> lists =
                 OBJECT_MAPPER.convertValue(
                         customModel.parseResponse(
                                 
"{\"created\":1725001256,\"id\":\"02172500125677376580aba8475a41c550bbf05104842f0405ef5\",\"data\":[{\"embedding\":[-1.625,0.07958984375,-1.5703125,-3.03125,-1.4609375,3.46875,-0.73046875,-2.578125,-0.66796875,1.71875,0.361328125,2,5.125,2.25,4.6875,1.4921875,-0.77734375,-0.466796875,0.0439453125,-2.46875,3.59375,4.96875,2.34375,-5.34375,0.11083984375,-5.875,3.0625,4.09375,3.4375,0.2265625,9,-1.9296875,2.25,0.765625,3.671875,-2.484375,-1.171875,-1.6171875,
 [...]
-                        new TypeReference<List<List<Float>>>() {});
+                        new TypeReference<List<List<Double>>>() {});
         Assertions.assertEquals(2, lists.size());
         Assertions.assertEquals(2560, lists.get(0).size());
     }


Reply via email to